C++使用自定义类类型作为键的unordered_map

410

我正在尝试将自定义类用作unordered_map的键,如下所示:

#include <iostream>
#include <algorithm>
#include <unordered_map>

using namespace std;

class node;
class Solution;

class Node {
public:
    int a;
    int b; 
    int c;
    Node(){}
    Node(vector<int> v) {
        sort(v.begin(), v.end());
        a = v[0];       
        b = v[1];       
        c = v[2];       
    }

    bool operator==(Node i) {
        if ( i.a==this->a && i.b==this->b &&i.c==this->c ) {
            return true;
        } else {
            return false;
        }
    }
};

int main() {
    unordered_map<Node, int> m;    

    vector<int> v;
    v.push_back(3);
    v.push_back(8);
    v.push_back(9);
    Node n(v);

    m[n] = 0;

    return 0;
}

然而,g++ 给了我以下错误:

In file included from /usr/include/c++/4.6/string:50:0,
                 from /usr/include/c++/4.6/bits/locale_classes.h:42,
                 from /usr/include/c++/4.6/bits/ios_base.h:43,
                 from /usr/include/c++/4.6/ios:43,
                 from /usr/include/c++/4.6/ostream:40,
                 from /usr/include/c++/4.6/iostream:40,
                 from 3sum.cpp:4:
/usr/include/c++/4.6/bits/stl_function.h: In member function ‘bool std::equal_to<_Tp>::operator()(const _Tp&, const _Tp&) const [with _Tp = Node]’:
/usr/include/c++/4.6/bits/hashtable_policy.h:768:48:   instantiated from ‘bool std::__detail::_Hash_code_base<_Key, _Value, _ExtractKey, _Equal, _H1, _H2, std::__detail::_Default_ranged_hash, false>::_M_compare(const _Key&, std::__detail::_Hash_code_base<_Key, _Value, _ExtractKey, _Equal, _H1, _H2, std::__detail::_Default_ranged_hash, false>::_Hash_code_type, std::__detail::_Hash_node<_Value, false>*) const [with _Key = Node, _Value = std::pair<const Node, int>, _ExtractKey = std::_Select1st<std::pair<const Node, int> >, _Equal = std::equal_to<Node>, _H1 = std::hash<Node>, _H2 = std::__detail::_Mod_range_hashing, std::__detail::_Hash_code_base<_Key, _Value, _ExtractKey, _Equal, _H1, _H2, std::__detail::_Default_ranged_hash, false>::_Hash_code_type = long unsigned int]’
/usr/include/c++/4.6/bits/hashtable.h:897:2:   instantiated from ‘std::_Hashtable<_Key, _Value, _Allocator, _ExtractKey, _Equal, _H1, _H2, _Hash, _RehashPolicy, __cache_hash_code, __constant_iterators, __unique_keys>::_Node* std::_Hashtable<_Key, _Value, _Allocator, _ExtractKey, _Equal, _H1, _H2, _Hash, _RehashPolicy, __cache_hash_code, __constant_iterators, __unique_keys>::_M_find_node(std::_Hashtable<_Key, _Value, _Allocator, _ExtractKey, _Equal, _H1, _H2, _Hash, _RehashPolicy, __cache_hash_code, __constant_iterators, __unique_keys>::_Node*, const key_type&, typename std::_Hashtable<_Key, _Value, _Allocator, _ExtractKey, _Equal, _H1, _H2, _Hash, _RehashPolicy, __cache_hash_code, __constant_iterators, __unique_keys>::_Hash_code_type) const [with _Key = Node, _Value = std::pair<const Node, int>, _Allocator = std::allocator<std::pair<const Node, int> >, _ExtractKey = std::_Select1st<std::pair<const Node, int> >, _Equal = std::equal_to<Node>, _H1 = std::hash<Node>, _H2 = std::__detail::_Mod_range_hashing, _Hash = std::__detail::_Default_ranged_hash, _RehashPolicy = std::__detail::_Prime_rehash_policy, bool __cache_hash_code = false, bool __constant_iterators = false, bool __unique_keys = true, std::_Hashtable<_Key, _Value, _Allocator, _ExtractKey, _Equal, _H1, _H2, _Hash, _RehashPolicy, __cache_hash_code, __constant_iterators, __unique_keys>::_Node = std::__detail::_Hash_node<std::pair<const Node, int>, false>, std::_Hashtable<_Key, _Value, _Allocator, _ExtractKey, _Equal, _H1, _H2, _Hash, _RehashPolicy, __cache_hash_code, __constant_iterators, __unique_keys>::key_type = Node, typename std::_Hashtable<_Key, _Value, _Allocator, _ExtractKey, _Equal, _H1, _H2, _Hash, _RehashPolicy, __cache_hash_code, __constant_iterators, __unique_keys>::_Hash_code_type = long unsigned int]’
/usr/include/c++/4.6/bits/hashtable_policy.h:546:53:   instantiated from ‘std::__detail::_Map_base<_Key, _Pair, std::_Select1st<_Pair>, true, _Hashtable>::mapped_type& std::__detail::_Map_base<_Key, _Pair, std::_Select1st<_Pair>, true, _Hashtable>::operator[](const _Key&) [with _Key = Node, _Pair = std::pair<const Node, int>, _Hashtable = std::_Hashtable<Node, std::pair<const Node, int>, std::allocator<std::pair<const Node, int> >, std::_Select1st<std::pair<const Node, int> >, std::equal_to<Node>, std::hash<Node>, std::__detail::_Mod_range_hashing, std::__detail::_Default_ranged_hash, std::__detail::_Prime_rehash_policy, false, false, true>, std::__detail::_Map_base<_Key, _Pair, std::_Select1st<_Pair>, true, _Hashtable>::mapped_type = int]’
3sum.cpp:149:5:   instantiated from here
/usr/include/c++/4.6/bits/stl_function.h:209:23: error: passing ‘const Node’ as ‘this’ argument of ‘bool Node::operator==(Node)’ discards qualifiers [-fpermissive]
make: *** [threeSum] Error 1

我猜,我需要告诉C++如何散列类Node,但我不是很确定该怎么做。我该如何完成这个任务?


3
第三个模板参数是您需要提供的哈希函数。参考链接 - chrisaycock
6
cppreference提供了一个简单实用的例子来演示如何做到这一点:http://en.cppreference.com/w/cpp/container/unordered_map/unordered_map - jogojapan
8个回答

667
为了能够使用 std::unordered_map (或其他无序关联容器)与自定义键类型一起使用,需要定义两个东西:
  1. 一个哈希函数;这必须是一个覆盖 operator() 并计算给定键类型的对象的哈希值的类。其中一种特别简单的方法是为您的键类型专门化 std::hash 模板。

  2. 一个相等比较函数;这是必需的,因为哈希不能依赖于哈希函数总是为每个不同的键提供唯一的哈希值(即,它需要能够处理冲突),因此它需要一种方法来比较两个给定的键是否完全匹配。您可以将其实现为覆盖 operator() 的类,也可以将其实现为 std::equal 的专业化,或者 - 最简单的 - 通过为键类型重载 operator==()(如您已经做过的那样)。

哈希函数的难点在于,如果您的键类型由多个成员组成,通常会让哈希函数计算每个成员的哈希值,然后以某种方式将它们合并为整个对象的一个哈希值。为了获得良好的性能(即少冲突),您应该仔细考虑如何组合单个哈希值,以确保避免不同对象太经常地获得相同的输出。

一个相当不错的哈希函数的起点是使用位移和按位异或来组合单个哈希值。例如,假设有这样一种键类型:

struct Key
{
  std::string first;
  std::string second;
  int         third;

  bool operator==(const Key &other) const
  { return (first == other.first
            && second == other.second
            && third == other.third);
  }
};

这里是一个简单的哈希函数(改编自 cppreference用户定义哈希函数示例 中使用的函数):
template <>
struct std::hash<Key>
{
  std::size_t operator()(const Key& k) const
  {
    using std::size_t;
    using std::hash;
    using std::string;

    // Compute individual hash values for first,
    // second and third and combine them using XOR
    // and bit shifting:

    return ((hash<string>()(k.first)
             ^ (hash<string>()(k.second) << 1)) >> 1)
             ^ (hash<int>()(k.third) << 1);
  }
};

有了这个,你可以为 key-type 实例化一个 std::unordered_map

int main()
{
  std::unordered_map<Key,std::string> m6 = {
    { {"John", "Doe", 12}, "example"},
    { {"Mary", "Sue", 21}, "another"}
  };
}

它将自动使用上面定义的std::hash<Key>进行哈希值计算,并使用Key的成员函数作为operator==进行相等性检查。

如果您不想在std命名空间内专门化模板(尽管在这种情况下完全合法),您可以将哈希函数定义为单独的类并将其添加到映射的模板参数列表中:

struct KeyHasher
{
  std::size_t operator()(const Key& k) const
  {
    using std::size_t;
    using std::hash;
    using std::string;

    return ((hash<string>()(k.first)
             ^ (hash<string>()(k.second) << 1)) >> 1)
             ^ (hash<int>()(k.third) << 1);
  }
};

int main()
{
  std::unordered_map<Key,std::string,KeyHasher> m6 = {
    { {"John", "Doe", 12}, "example"},
    { {"Mary", "Sue", 21}, "another"}
  };
}

如何定义更好的哈希函数?如上所述,定义一个良好的哈希函数对于避免冲突和获得良好的性能非常重要。要想得到一个真正好的哈希函数,您需要考虑所有字段可能值的分布,并定义一个将该分布投影到可能结果空间尽可能广泛和均匀分布的哈希函数。
这可能很困难;上面的XOR/位移方法可能不是一个坏的开始。为了稍微更好的开始,您可以使用Boost库中的hash_valuehash_combine函数模板。前者的作用类似于标准类型(最近还包括元组和其他有用的标准类型)的std::hash;后者帮助您将单个哈希值组合成一个。以下是使用Boost辅助函数重写的哈希函数:
#include <boost/functional/hash.hpp>

struct KeyHasher
{
  std::size_t operator()(const Key& k) const
  {
      using boost::hash_value;
      using boost::hash_combine;

      // Start with a hash value of 0    .
      std::size_t seed = 0;

      // Modify 'seed' by XORing and bit-shifting in
      // one member of 'Key' after the other:
      hash_combine(seed,hash_value(k.first));
      hash_combine(seed,hash_value(k.second));
      hash_combine(seed,hash_value(k.third));

      // Return the result.
      return seed;
  }
};

这里有一个重写版本,它不使用boost,但使用了一种很好的方法来组合哈希值:

template <>
struct std::hash<Key>
{
    std::size_t operator()( const Key& k ) const
    {
        // Compute individual hash values for first, second and third
        // https://dev59.com/enI-5IYBdhLWcg3w18V3#1646913
        std::size_t res = 17;
        res = res * 31 + hash<string>()( k.first );
        res = res * 31 + hash<string>()( k.second );
        res = res * 31 + hash<int>()( k.third );
        return res;
    }
};

14
请问为什么需要在 KeyHasher 中移位(shift the bits)呢? - Chani
67
如果没有对位进行移动,两个字符串相同时,异或运算会使它们互相抵消。因此,hash(“a”,“a”,1)与hash(“b”,“b”,1)相同。同时,顺序也不重要,所以hash(“a”,“b”,1)与hash(“b”,“a”,1)相同。 - Buge
1
我正在学习C++,但有一件事情一直困扰着我:代码应该放在哪里?就像你所做的那样,我为我的键编写了一个专门的std::hash方法。我把它放在了Key.cpp文件的底部,但是我遇到了以下错误:Error 57 error C2440: 'type cast' : cannot convert from 'const Key' to 'size_t' c:\program files (x86)\microsoft visual studio 10.0\vc\include\xfunctional。我猜测编译器没有找到我的哈希方法?我应该在我的Key.h文件中添加什么吗? - Ben
7
把它放到.h文件中是正确的。std::hash实际上不是一个结构体,而是一个结构体的模板(特化)。因此它不是一种实现--当编译器需要时,它将被转换为一种实现。模板应该总是放在头文件中。另请参阅https://dev59.com/O3RB5IYBdhLWcg3w1Kr0 - jogojapan
3
find()返回一个迭代器,该迭代器指向map的一个"entry"。一个entry是一个std::pair类型,包含键和值。因此,如果您执行auto iter = m6.find({"John","Doe",12});,您将在iter->first中得到键,以及在iter->second中得到值(即字符串"example")。如果您想直接获取字符串,则可以使用m6.at({"John","Doe",12})(如果该键不存在则会抛出异常),或者使用m6[{"John","Doe",12}](如果该键不存在,则会创建一个空的值)。 - jogojapan
显示剩余14条评论

34

我认为,jogojapan给出了非常好而详尽的解答。在阅读我的帖子之前,你绝对应该先看一下它。

  1. 可以单独为unordered_map定义一个比较函数,而不是使用相等比较运算符(operator==)。例如,如果您想将后者用于将两个Node对象的所有成员相互比较,但只有一些特定成员作为unordered_map的键,则这可能很有帮助。
  2. 还可以使用lambda表达式来定义哈希和比较函数。

总的来说,对于您的Node类,代码可以编写如下:

using h = std::hash<int>;
auto hash = [](const Node& n){return ((17 * 31 + h()(n.a)) * 31 + h()(n.b)) * 31 + h()(n.c);};
auto equal = [](const Node& l, const Node& r){return l.a == r.a && l.b == r.b && l.c == r.c;};
std::unordered_map<Node, int, decltype(hash), decltype(equal)> m(8, hash, equal);

注意事项:

  • 我只是重用了 jogojapan 的答案结尾处的哈希方法,但你可以在这里找到更一般解决方案的思路(如果你不想使用 Boost)。
  • 我的代码可能有点过于压缩。为了更易读的版本,请参阅Ideone 上的这个代码

2
8从哪里来?它代表什么意思? - AndiChin
@WhalalalalalalaCHen:请查看unordered_map构造函数的文档。其中的8代表所谓的“桶计数”。一个桶是容器内部哈希表中的一个插槽,更多信息请参见unordered_map::bucket_count - honk
1
@WhalalalalalalaCHen:我随机选择了8。根据您想要存储在unordered_map中的内容,桶计数可以影响容器的性能。 - honk

12

使用自定义类作为 unordered_map 的键(稀疏矩阵的基本实现)的最基本的可复制/粘贴完整可运行示例:

// UnorderedMapObjectAsKey.cpp

#include <iostream>
#include <vector>
#include <unordered_map>

struct Pos
{
  int row;
  int col;

  Pos() { }
  Pos(int row, int col)
  {
    this->row = row;
    this->col = col;
  }

  bool operator==(const Pos& otherPos) const
  {
    if (this->row == otherPos.row && this->col == otherPos.col) return true;
    else return false;
  }

  struct HashFunction
  {
    size_t operator()(const Pos& pos) const
    {
      size_t rowHash = std::hash<int>()(pos.row);
      size_t colHash = std::hash<int>()(pos.col) << 1;
      return rowHash ^ colHash;
    }
  };
};

int main(void)
{
  std::unordered_map<Pos, int, Pos::HashFunction> umap;

  // at row 1, col 2, set value to 5
  umap[Pos(1, 2)] = 5;

  // at row 3, col 4, set value to 10
  umap[Pos(3, 4)] = 10;

  // print the umap
  std::cout << "\n";
  for (auto& element : umap)
  {
    std::cout << "( " << element.first.row << ", " << element.first.col << " ) = " << element.second << "\n";
  }
  std::cout << "\n";

  return 0;
}

你如何在不提供unordered_map的第三个参数的情况下完成这个操作? - rare77

3

对于枚举类型,我认为这是一种合适的方式,与类的区别在于如何计算哈希值。

template <typename T>
struct EnumTypeHash {
  std::size_t operator()(const T& type) const {
    return static_cast<std::size_t>(type);
  }
};

enum MyEnum {};
class MyValue {};

std::unordered_map<MyEnum, MyValue, EnumTypeHash<MyEnum>> map_;

这正是我一直在寻找的!!!非常感谢Jiaqi Ju! - Carlos Linares López

1
STL没有为pair提供哈希函数。您需要自己实现它,并将其指定为模板参数或放入std命名空间中,从那里它会被自动捕捉到。以下https://github.com/HowardHinnant/hash_append/blob/master/n3876.h非常适用于实现自定义哈希函数以用于结构体。更多细节在本问题的其他答案中有很好的解释,所以我不会重复。在Boost中也有类似的东西(hash_combine)。

0

0
这里的答案非常有帮助,但我仍然在尝试弄清楚这个问题时遇到了很大的困难,因此我的经验教训可能会有所裨益。与OP相比,我的情况有点独特;我的key是一个自定义的UUID类,而且我并不拥有它。在我看来,这个类存在一个错误/疏忽,它没有定义哈希函数,也没有重载operator()(它确实定义了operator==,所以我已经解决了这个问题)。是的,我有源代码,但它被广泛分发和控制,所以修改它是不可行的。我想将这个UUID作为std::unordered_map成员中的键使用,就像这样:

std::unordered_map<UUID, MyObject> mapOfObjs_;

在Visual Studio中,最终我选择了这个解决方案:
// file MyClass.h

namespace myNamespace
{
   static auto staticUuidHashFunc = [](const UUID& n)
   {
      // XORed the most and least significant bits, not important
   }
   ... 
   class MyClass
   {
      ...
   private:
      std::unordered_map<UUID, std::unique_ptr<MyObject>, decltype(staticUuidHashFunc)> mapOfObjs_;
   };
}

在Windows中这个很好用。但是,当我最终将我的代码带到Linux的gcc时,我收到了警告(引述)

'MyClass'有一个字段'mapOfObjs_',其类型使用匿名命名空间

即使我禁用了所有警告,我仍然收到了这个警告,所以gcc必须认为它非常严重。我在Google上搜索并找到了this answer,它建议我需要将哈希函数代码移动到.cpp文件中。

此时,我还尝试从UUID类派生:

// file MyClass.h

namespace myNamespace
{
   struct myUuid : public UUID
   {
      // overload the operator()
   };
   ...
   // and change my map to use this type
   std::unordered_map<myUuid, std::unique_ptr<MyObject>> mapOfObjs_;
}

然而,这也带来了一系列问题。即,所有使用(现在是父类)UUID类的代码部分都与我的地图不兼容,例如:

void MyClass::FindUuid(const UUID& id)
{
   // doesn't work, can't convert `id` to a `myUuid` type
   auto it = mapOfObjs_.find(id);
   ...
}

现在出了问题。我不想改变所有的代码,所以我放弃了那个方法,回到了“将代码放入.cpp文件”的解决方案。然而,我还是固执地尝试了一些方法来保持哈希函数在.h文件中。我真正想避免的是从哈希函数定义中删除auto,因为我不知道也不想弄清楚类型是什么。所以我尝试了:

class MyClass
{
   ...
private:
   static auto staticUuidHashFunc = [](const UUID& n)
   {
      // my hash function
   }
};

但是这个(或者它的变体)会返回错误,比如“类中不能有静态初始化器”,“这里不能使用auto”等等(我需要严格的C++11要求)。所以我最终接受了我需要像对待一个static变量一样来处理它,在头文件中声明,在.cpp文件中初始化。一旦我弄清楚了它的类型,就很容易了:
// MyClass.h
namespace myNamespace
{
   class MyClass
   {
      ...
   private:
      static std::function<unsigned long long(const UUID&)> staticUuidHashFunc;

      std::unordered_map<UUID, std::unique_ptr<MyObject>, decltype(staticUuidHashFunc)> mapOfObjs_;
   };
}

最后,在 .cpp 文件中:
// MyClass.cpp

namespace myNamespace
{
   std::function<unsigned long long(const UUID&)> MyClass::staticUuidHashFunc = [](const UUID& n)
    {
        // the hash function
    };

    MyClass::MyClass()
       : mapOfObjs_{ std::unordered_map<UUID, std::unique_ptr<MyObject>, decltype(staticUuidHashFunc)> (MyClass::NUMBER_OF_MAP_BUCKETS, staticUuidHashFunc)}
    {  }

   ...
}

在.cpp文件中定义静态哈希函数是关键。之后,Visual Studio和gcc都很满意。

0
我们需要做两点:
1. 最小化碰撞 2. hash(a,b)!=hash(b,a)
我建议这样做:

enter image description here

我们在数值上有这样的算法:
a = hash1
b = hash2
p = a+b
result_hash = p*(p+1)/2+b

但是如果我们简化算法,这不会影响碰撞。
a = hash1
b = hash2
p = a+b
result_hash = p*p+b

我们可以用Python来检查这个(代码):
N=256 # number of different values of hash
d = {}
for a in range(N):
  for b in range(N):
    p = a+b
    x = (p*p+b)%N
    if x in d:
      d[x]+=1
    else:
      d[x]=1
    #print(format(x,'04b'),end='\t')
  #print()
print(max(v for k,v in d.items()))

终于我们有了这样的方法:
p = a+b
return p*p+b

return (a<<1)^b

return a*31+b

return (17*31+a)*31+b

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接