如何在C++中计算两个STL集合的交集大小

7
我被给予两个集合(来自<set>的std::set),我想知道它们的交集大小。我可以使用<algorithm>中的std::set_intersection,但我必须提供一个输出迭代器来将交集复制到其他容器中。
一种直接的方法是:
  set<int> s1{1,2,3,4,5};
  set<int> s2{4,5,6,7,8,9,0,1};

  vector<int> v;

  set_intersection(
      s1.begin(), s1.end(), s2.begin(), s2.end(),
      inserter(v, v.begin()));

接下来使用v.size()可以得到交集的大小。然而,即使我们不做任何操作,交集也必须被存储。

为了避免这种情况,我尝试实现一个虚拟输出迭代器类,仅用于计数而不进行赋值:

template<typename T>
class CountingOutputIterator {
 private:
  int* counter_;
  T dummy_;
 public:
  explicit CountingOutputIterator(int* counter) :counter_(counter) {}
  T& operator*() {return dummy_;}
  CountingOutputIterator& operator++() { // ++t
    (*counter_)++;
    return *this;
  }
  CountingOutputIterator operator++(int) { // t++
    CountingOutputIterator ret(*this);
    (*counter_)++;
    return ret;
  }
  bool operator==(const CountingOutputIterator& c) {
    return counter_ == c.counter_; // same pointer
  }
  bool operator!=(const CountingOutputIterator& c) {
    return !operator==(c);
  }
};

使用这个工具,我们可以做到

  set<int> s1{1,2,3,4,5};
  set<int> s2{4,5,6,7,8,9,0,1};

  int counter = 0;
  CountingOutputIterator<int> counter_it(&counter);
  set_intersection(
      s1.begin(), s1.end(), s2.begin(), s2.end(), counter_it);

之后计数器将持有交集的大小。

这段代码要多得多。我的问题是:

1)是否有一种标准(库)方法或标准技巧可在不存储整个交集的情况下获取交集的大小? 2)无论是否存在,使用自定义虚拟迭代器的方法是否可行?


1
似乎为了仅识别共同元素的数量而过于复杂化了。为什么不只使用循环呢? - Aldehir
非常奇怪,当你永远不会真正使用交集时,知道大小有什么意义? 你是否清楚地考虑过这个问题? 阅读此内容 - Hans Passant
与其使用自定义迭代器,创建一个具有insert()成员的自定义“容器”,并使用该容器与insert_iterator更为简单。 - Jonathan Wakely
1
@HansPassant 你觉得这有什么奇怪的?我能想到很多情况。基本上它是重叠区域的范畴。 - doetoe
@JonathanWakely 谢谢,我会考虑一下的。 - doetoe
@Aldehir 我也这么认为,然而,如果你想编写具有相同空间和时间要求的自定义代码,那么你不会轻易得到更简单的东西。 - doetoe
4个回答

19

编写循环以遍历两个集合以查找匹配元素并不难,或者您可以使用比自定义迭代器简单得多的方法:

struct Counter
{
  struct value_type { template<typename T> value_type(const T&) { } };
  void push_back(const value_type&) { ++count; }
  size_t count = 0;
};

template<typename T1, typename T2>
size_t intersection_size(const T1& s1, const T2& s2)
{
  Counter c;
  set_intersection(s1.begin(), s1.end(), s2.begin(), s2.end(), std::back_inserter(c));
  return c.count;
}

为了编译它(g++ 4.8.4),我必须将Counter作为T的结构体模板,并在其中嵌套一个value_type的typedef:使用value_type = T; - doetoe
啊,好的观点。我已经更新了答案,并提供了一种替代方法,仍然意味着Counter不需要成为一个模板:定义一个可以从任何东西构造的value_type - Jonathan Wakely
如果我想并行化它呢?使用 std::execution::par 呢。 :| - undefined

3
你可以这样做:
auto common = count_if(begin(s1), end(s1), [&](const auto& x){ return s2.find(x) != end(s2); });

这不是最优化的效率,但对于大多数用途来说速度足够快。


难道不应该与 s2.end() 进行比较吗? - Cheers and hth. - Alf
谢谢。但是,这个算法的计算复杂度较高,即$n\log n$与$m + \log n$,其中$n$表示集合的大小,$m$表示它们的交集。 - doetoe
@doetoe 是的,我发布后进行了编辑以提及这一点。这个解决方案偏向于简洁而非效率。Mats Petersson的答案是最有效的方法,或者您的自定义迭代器解决方案。 - mattnewport
@doetoe 这正是Eric Niebler提出的Range v3库扩展所擅长的事情。 - mattnewport
1
如果s2比s1大很多,那么复杂度n1*log(n2)可能会比n1+n2小。 - Marc Glisse
显示剩余2条评论

3

写一个实现这个功能的函数并不太难。 这里 展示了如何实现 set_intersection [尽管实际实现可能略有不同]。

因此,我们可以直接拿那段代码,并稍微修改一下:

template <class InputIterator1, class InputIterator2>
  size_t set_intersection_size (InputIterator1 first1, InputIterator1 last1,
                                InputIterator2 first2, InputIterator2 last2)
{
  size_t result = 0;
  while (first1!=last1 && first2!=last2)
  {
    if (*first1<*first2) ++first1;
    else if (*first2<*first1) ++first2;
    else {
      result++;
      ++first1; ++first2;
    }
  }
  return result;
}

尽管根据我的经验,当你想知道交集中有多少个元素时,通常迟早会想知道哪些元素也在其中。


3
您可以简化您的方法使用:
struct counting_iterator
{
    size_t count;
    counting_iterator& operator++() { ++count; return *this; }

    struct black_hole { void operator=(T) {} };
    black_hole operator*() { return black_hole(); }

    // other iterator stuff may be needed
};

size_t count = set_intersection(
  s1.begin(), s1.end(), s2.begin(), s2.end(), counting_iterator()).count;

好的,谢谢。我一开始就是这样做的,但后来我想我可能无法访问传递给set_intersection的迭代器。但当然,只是返回一个副本。 - doetoe

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接