这基本上可以满足你的需求,如果不行的话可以根据你的要求进行修改。
在Linux上,它使用`pthread_key_create`,而在Windows上则使用`TlsAlloc`。两者都是通过一个“键”来存取当前线程的线程局部数据。不同之处在于:这个类还会把每个线程创建的对象登记到一个共享列表中,因此你可以在其他线程(例如主线程)中访问所有线程的数据。
EnumerableThreadLocal的思想是在你的线程中执行本地操作,然后在主线程中将结果合并起来。
tbb 中有一个类似的类模板叫做 `enumerable_thread_specific`,关于它的设计动机,可以参考:
https://oneapi-src.github.io/oneTBB/main/tbb_userguide/design_patterns/Divide_and_Conquer.html
下面的代码尝试在不依赖 tbb 的情况下模仿 tbb 的这一功能。它的缺点是:每个实例都要占用一个 TLS 键,而 Windows 上每个进程最多只能分配 1088 个 TLS 索引,因此同时存在的实例数量受此限制。
/// A per-thread value that can also be enumerated from any thread.
///
/// Each thread that calls Get() (or operator* / operator->) lazily receives
/// its own T, produced by the factory given to the constructor. Every value
/// is also recorded in an internal list together with the owning thread's
/// std::thread::id, so Enumerate() can visit the values of all threads, e.g.
/// to combine per-thread partial results on the main thread.
///
/// A per-instance TLS key (pthread key / Windows TLS index) caches the calling
/// thread's pointer so repeated Get() calls avoid the mutex. If no key could be
/// allocated (for example once the process has run out of TLS indexes) or the
/// object has been moved from, Get() still works by falling back to a
/// mutex-protected lookup by thread id.
///
/// Notes:
///  - Entries are matched by std::thread::id; if an id is reused by a later
///    thread, that thread sees the value left by the earlier one.
///  - Enumerate() holds the internal mutex, but does not synchronize with
///    threads that are concurrently modifying their own T.
///  - A moved-from object is empty and its factory has been moved out; use
///    only Enumerate(), assignment or destruction on it.
template <typename T>
class EnumerableThreadLocal
{
#if _WIN32 || _WIN64
    using tls_key_t = DWORD;
    // Returns false when the process has no TLS index left.
    bool create_key()
    {
        my_key = TlsAlloc();
        return my_key != TLS_OUT_OF_INDEXES;
    }
    void destroy_key() { TlsFree(my_key); }
    void set_tls(void *value) const { TlsSetValue(my_key, value); }
    void *get_tls() const { return TlsGetValue(my_key); }
#else
    using tls_key_t = pthread_key_t;
    // Returns false when pthread_key_create fails (e.g. no keys left).
    bool create_key() { return pthread_key_create(&my_key, nullptr) == 0; }
    void destroy_key() { pthread_key_delete(my_key); }
    void set_tls(void *value) const { pthread_setspecific(my_key, value); }
    void *get_tls() const { return pthread_getspecific(my_key); }
#endif

    using Factory = std::function<std::unique_ptr<T>()>;

    Factory m_factory;
    // (thread id, value) for every thread that has called Get(). Mutable so
    // that the const accessors can lazily create the caller's entry without
    // casting away const.
    mutable std::vector<std::pair<std::thread::id, std::unique_ptr<T>>> m_thread_locals;
    mutable std::mutex m_mtx;
    tls_key_t my_key{};
    // True only while my_key was allocated by this object and must be freed.
    // A moved-from object clears it so it never frees a key it no longer owns.
    bool m_has_key = false;

    static std::unique_ptr<T> DefaultFactory()
    {
        return std::make_unique<T>();
    }

    // Frees the TLS key if this object owns one.
    void release_key() noexcept
    {
        if (m_has_key)
        {
            destroy_key();
            m_has_key = false;
        }
    }

    // Deep-copies one stored value; a null value (factory returned nullptr)
    // stays null instead of being dereferenced.
    static std::unique_ptr<T> clone(const std::unique_ptr<T> &src)
    {
        return src ? std::make_unique<T>(*src) : std::unique_ptr<T>();
    }

    // Returns the calling thread's value, creating it with m_factory on first
    // use. Returns nullptr only if the factory itself returned nullptr.
    T *lookup() const
    {
        // Fast path: pointer cached in this thread's TLS slot for our key.
        if (m_has_key)
        {
            if (void *cached = get_tls())
                return static_cast<T *>(cached);
        }

        // Caches ptr in the TLS slot (when we own a key) and returns it.
        const auto remember = [this](T *ptr)
        {
            if (m_has_key)
                set_tls(ptr);
            return ptr;
        };

        const auto self = std::this_thread::get_id();
        const std::scoped_lock lock(m_mtx);
        // An entry may already exist without a TLS hit, e.g. after copying
        // (new key) or when no key could be allocated.
        for (const auto &[thread_id, uptr] : m_thread_locals)
        {
            if (thread_id == self)
                return remember(uptr.get());
        }
        m_thread_locals.emplace_back(self, m_factory());
        return remember(m_thread_locals.back().second.get());
    }

public:
    /// @param factory Creates each thread's value on that thread's first
    ///        access. Defaults to value-initialising a T via std::make_unique.
    EnumerableThreadLocal(Factory factory = &DefaultFactory)
        : m_factory(std::move(factory))
    {
        m_has_key = create_key();
    }

    ~EnumerableThreadLocal()
    {
        release_key();
    }

    /// Deep-copies every thread's value (requires T to be copy-constructible)
    /// and the factory. The copy gets its own TLS key.
    EnumerableThreadLocal(const EnumerableThreadLocal &other)
    {
        {
            const std::scoped_lock lock(other.m_mtx);
            m_factory = other.m_factory;
            m_thread_locals.reserve(other.m_thread_locals.size());
            for (const auto &[thread_id, uptr] : other.m_thread_locals)
                m_thread_locals.emplace_back(thread_id, clone(uptr));
        }
        // Allocate the key last so a throwing T copy cannot leak it.
        m_has_key = create_key();
    }

    /// Replaces this object's contents with a deep copy of other's.
    EnumerableThreadLocal &operator=(const EnumerableThreadLocal &other)
    {
        if (this != &other)
        {
            EnumerableThreadLocal copy(other);
            *this = std::move(copy);
        }
        return *this;
    }

    /// Takes over other's values, factory and TLS key; other is left empty
    /// and owning no key. Stored values keep their addresses, so pointers
    /// already cached in threads' TLS slots stay valid.
    EnumerableThreadLocal(EnumerableThreadLocal &&other) noexcept
        : m_factory(std::move(other.m_factory)),
          m_thread_locals(std::move(other.m_thread_locals)),
          my_key(other.my_key),
          m_has_key(other.m_has_key)
    {
        other.m_has_key = false;
        other.m_thread_locals.clear();
    }

    /// Frees this object's key and values, then takes over other's.
    EnumerableThreadLocal &operator=(EnumerableThreadLocal &&other) noexcept
    {
        if (this != &other)
        {
            release_key();
            m_factory = std::move(other.m_factory);
            m_thread_locals = std::move(other.m_thread_locals);
            other.m_thread_locals.clear();
            my_key = other.my_key;
            m_has_key = other.m_has_key;
            other.m_has_key = false;
        }
        return *this;
    }

    /// Returns the calling thread's value, creating it on first use.
    T *Get()
    {
        return lookup();
    }

    T const *Get() const
    {
        return lookup();
    }

    T &operator*()
    {
        return *Get();
    }

    T const &operator*() const
    {
        return *Get();
    }

    T *operator->()
    {
        return Get();
    }

    T const *operator->() const
    {
        return Get();
    }

    /// Calls fn(T&) on every thread's value while holding the internal
    /// mutex. Null values (from a factory returning nullptr) are skipped.
    template <typename F>
    void Enumerate(F fn)
    {
        const std::scoped_lock lock(m_mtx);
        for (auto &entry : m_thread_locals)
        {
            if (entry.second)
                fn(*entry.second);
        }
    }

    /// Read-only variant: calls fn(const T&) on every thread's value.
    template <typename F>
    void Enumerate(F fn) const
    {
        const std::scoped_lock lock(m_mtx);
        for (const auto &entry : m_thread_locals)
        {
            if (entry.second)
                fn(static_cast<const T &>(*entry.second));
        }
    }
};
下面是一个测试套件,用来演示它的工作原理。
#include <algorithm>
#include <string>
#include <thread>
#include <vector>
#include "gtest/gtest.h"
#include "EnumerableThreadLocal.hpp"
// Each of N threads writes a distinct string into its own slot; afterwards
// Enumerate() must visit exactly those N strings (in any order).
TEST(EnumerableThreadLocal, BasicTest)
{
    const int N = 10;
    v31::EnumerableThreadLocal<std::string> tls;
    std::vector<std::thread> threads;
    for (int i = 0; i < N; ++i)
    {
        threads.emplace_back([&tls, i]()
                             { *tls = "Thread " + std::to_string(i); });
    }
    for (auto &thread : threads)
        thread.join();

    std::vector<std::string> actual;
    tls.Enumerate([&](std::string &s)
                  { actual.push_back(s); });

    std::vector<std::string> expected;
    for (int i = 0; i < N; ++i)
        expected.push_back("Thread " + std::to_string(i));

    // Sort both sides: lexicographic order puts "Thread 10" before
    // "Thread 2", so only sorting `actual` would break for N > 10.
    // Comparing whole vectors also reports a wrong entry count instead of
    // indexing past the end of `actual`.
    std::sort(actual.begin(), actual.end());
    std::sort(expected.begin(), expected.end());
    ASSERT_EQ(actual, expected);
}
// Payload that can be neither copied nor moved; used to check that
// EnumerableThreadLocal works with such types as long as it is not copied.
struct NonCopyable
{
    NonCopyable() = default;

    NonCopyable(NonCopyable &&) = delete;
    NonCopyable(const NonCopyable &) = delete;
    NonCopyable &operator=(NonCopyable &&) = delete;
    NonCopyable &operator=(const NonCopyable &) = delete;

    int i{0};
};
// Each thread writes its index through operator-> into its own NonCopyable
// (created by the default factory); Enumerate() must then see exactly 0..N-1.
TEST(EnumerableThreadLocal, NonCopyableTest)
{
    const int N = 10;
    v31::EnumerableThreadLocal<NonCopyable> tls;
    std::vector<std::thread> threads;
    for (int i = 0; i < N; ++i)
    {
        threads.emplace_back([&tls, i]()
                             { tls->i = i; });
    }
    for (auto &thread : threads)
        thread.join();

    std::vector<int> actual;
    tls.Enumerate([&](NonCopyable &s)
                  { actual.push_back(s.i); });
    std::sort(actual.begin(), actual.end());

    std::vector<int> expected;
    for (int i = 0; i < N; ++i)
        expected.push_back(i);

    // Comparing whole vectors also reports a wrong number of entries, which
    // indexing actual[i] directly could not do safely.
    ASSERT_EQ(actual, expected);
}
constexpr int N = 10;

// Builds an EnumerableThreadLocal<std::string> in which each of N threads has
// stored "Thread <index>", and returns it by value.
v31::EnumerableThreadLocal<std::string> CreateFixture()
{
    v31::EnumerableThreadLocal<std::string> result;
    std::vector<std::thread> workers;
    for (int index = 0; index != N; ++index)
        workers.emplace_back([&result, index]
                             { *result = "Thread " + std::to_string(index); });
    for (std::thread &worker : workers)
        worker.join();
    return result;
}
// Asserts that `tls` holds exactly the N values written by CreateFixture():
// "Thread 0" .. "Thread N-1", in any order.
void CheckFixtureCopy(v31::EnumerableThreadLocal<std::string> & tls)
{
    std::vector<std::string> actual;
    tls.Enumerate([&](std::string &s)
                  { actual.push_back(s); });

    std::vector<std::string> expected;
    for (int i = 0; i < N; ++i)
        expected.push_back("Thread " + std::to_string(i));

    // Sort both sides: lexicographic order puts "Thread 10" before
    // "Thread 2", so only sorting `actual` would break for N > 10.
    // Comparing whole vectors also reports a wrong entry count instead of
    // indexing past the end of `actual`.
    std::sort(actual.begin(), actual.end());
    std::sort(expected.begin(), expected.end());
    ASSERT_EQ(actual, expected);
}
// Asserts that Enumerate() visits no values at all (e.g. after a move-from).
void CheckFixtureEmpty(v31::EnumerableThreadLocal<std::string> & tls)
{
    std::vector<std::string> actual;
    tls.Enumerate([&](std::string &s)
                  { actual.push_back(s); });
    // Compare against an empty vector rather than size() == 0: this avoids a
    // signed/unsigned comparison inside ASSERT_EQ and prints any unexpected
    // values on failure.
    ASSERT_EQ(actual, std::vector<std::string>{});
}
// Copy construction must reproduce every thread's value and leave the source
// intact.
TEST(EnumerableThreadLocal, Copy)
{
    auto source = CreateFixture();
    auto duplicate(source);
    CheckFixtureCopy(duplicate);
    CheckFixtureCopy(source);
}
// Move construction must transfer every thread's value and leave the source
// with nothing to enumerate.
TEST(EnumerableThreadLocal, Move)
{
    auto source = CreateFixture();
    auto destination(std::move(source));
    CheckFixtureCopy(destination);
    CheckFixtureEmpty(source);
}
// Copy assignment into an initially empty object must reproduce every
// thread's value and leave the source intact.
TEST(EnumerableThreadLocal, CopyAssign)
{
    auto source = CreateFixture();
    v31::EnumerableThreadLocal<std::string> target{};
    CheckFixtureEmpty(target);

    target = source;

    CheckFixtureCopy(target);
    CheckFixtureCopy(source);
}
// Move assignment into an initially empty object must transfer every
// thread's value and leave the source with nothing to enumerate.
TEST(EnumerableThreadLocal, MoveAssign)
{
    auto source = CreateFixture();
    v31::EnumerableThreadLocal<std::string> target{};
    CheckFixtureEmpty(target);

    target = std::move(source);

    CheckFixtureCopy(target);
    CheckFixtureEmpty(source);
}
// Payload without a default constructor: an EnumerableThreadLocal of this
// type must be given a custom factory.
struct NoDefaultConstructor
{
    NoDefaultConstructor(int value) : i{value} {}

    int i;
};
// With a custom factory, a type lacking a default constructor can be used;
// each thread overwrites its factory-created value with its index, and
// Enumerate() must then see exactly 0..N-1.
TEST(EnumerableThreadLocal, NoDefaultConstructor)
{
    const int N = 10;
    v31::EnumerableThreadLocal<NoDefaultConstructor> tls([]
                                                         { return std::make_unique<NoDefaultConstructor>(0); });
    std::vector<std::thread> threads;
    for (int i = 0; i < N; ++i)
    {
        threads.emplace_back([&tls, i]()
                             { tls->i = i; });
    }
    for (auto &thread : threads)
        thread.join();

    std::vector<int> actual;
    tls.Enumerate([&](NoDefaultConstructor &s)
                  { actual.push_back(s.i); });
    std::sort(actual.begin(), actual.end());

    std::vector<int> expected;
    for (int i = 0; i < N; ++i)
        expected.push_back(i);

    // Comparing whole vectors also reports a wrong number of entries, which
    // indexing actual[i] directly could not do safely.
    ASSERT_EQ(actual, expected);
}
`__thread` 告诉编译器它可以使用 CPU 寄存器来定位该变量。如果是这样的话,由于硬件限制,其他线程将无法直接访问这个变量。 —— SpliFF