从另一个线程访问线程本地变量

25
如何从另一个线程读取/写入线程本地变量?也就是说,在线程A中,我想访问线程B的线程本地存储区域中的变量。我知道另一个线程的ID。在GCC中,该变量声明为__thread。目标平台是Linux,但独立性可能很好(特定于GCC也可以)。
由于缺少线程启动钩子,因此我无法简单地跟踪每个线程的值。所有线程都需要以这种方式进行跟踪(而不仅仅是特别启动的线程)。
像boost thread_local_storage或使用pthread键这样的更高级别封装不是选项。我需要使用真正的__thread本地变量的性能。
首个答案是错误的:对于我想要做的事情,不能使用全局变量。每个线程必须有自己的变量副本。此外,出于性能原因,这些变量必须是__thread变量(同样有效的解决方案也可以,但我不知道)。我还无法控制线程入口点,因此这些线程没有可能注册任何类型的结构。
线程本地不是私有的:关于线程本地变量的另一个误解。它们绝不是某种线程私有变量。它们是全局可寻址的内存,其生命周期限于线程。任何函数都可以修改它们,来自任何线程,如果给定这些变量的指针。上面的问题本质上是关于如何获取该指针地址的问题。

3
当然,总的想法是你不能这样做。为什么不使用非本地数据结构让每个线程报告其私有值呢? - Bo Persson
据推测,__thread 告诉编译器它可以使用 CPU 寄存器。如果是这样的话,由于硬件限制,直接访问将变得不可能。 - SpliFF
2
@SpliFF,__thread本地变量最终只是普通内存中的位置。您可以获取其地址并将其提供给另一个线程以进行访问。 - edA-qa mort-ora-y
@Bo,这是我的评论,关于缺少start-thread hook。 我没有拦截所有线程创建并注册变量的可能性。同样,我不能为从所属线程对变量的任何读取访问添加函数调用的开销。 - edA-qa mort-ora-y
从技术上讲也不可能。例如在Windows中,它很可能使用TLS线程存储,该存储在操作系统级别不可从另一个线程访问。 - Michael Chourdakis
显示剩余3条评论
5个回答

17

如果您想要非线程本地的线程本地变量,为什么不使用全局变量来代替呢?

重要澄清!

我并不建议您使用单个全局变量来替换线程本地变量。我建议使用单个全局数组或其他适当的值集合来替换一个线程本地变量。

当然,您需要提供同步,但是既然您想要将A线程中修改的值暴露给B线程,那就无法绕过同步。

更新:

GCC关于__thread的文档说:

当对线程本地变量应用取地址运算符时,它会在运行时被计算,并返回该变量当前线程实例的地址。所获得的地址可以被任何线程使用。当线程终止时,该线程中指向线程本地变量的任何指针都将失效。

因此,如果您坚持这种方式,我想从属于它的线程获取线程本地变量的地址是可能的,就在线程被生成后。然后,您可以将该内存位置的指针存储到映射(线程ID => 指针)中,并让其他线程通过这种方式访问变量。这假设您拥有生成线程的代码。

如果您真的大胆尝试,可以尝试挖掘___tls_get_addr的信息(从此PDF开始,该PDF链接在上述GCC文档中)。但是,这种方法与编译器和平台高度相关,并且缺乏文档,因此应该引起任何人的警惕。


6
这并没有回答我的问题。 - edA-qa mort-ora-y
5
@edA-qa mort-ora-y:据我理解,这个问题的意思是“如何用锤子挖一个洞?”。因此,我建议使用更适合这项工作的工具。 - Jon
6
@edA-qa mort-ora-y: 在我看来,那个回答完美地回答了你的问题。TLS的定义是“我不想在我的线程之间共享这个变量”,这需要编译器额外的工作来确保这一属性。建议使用普通的全局变量而不是尝试绕过TLS的建议是有道理的。否则,这就像将盐加入咖啡中,然后将其倒在下水道中,因为你不喜欢咖啡里的盐。 - Damon
3
TLS可以在线程之间共享:它像其他任何可寻址内存一样。我正在寻找一种方法,可以在没有源线程通信的情况下发现这些变量的地址。 - edA-qa mort-ora-y
1
@edA-qa mort-ora-y:我知道线程本地意味着“这个变量的多个副本”,同时也知道原则上没有什么阻止你“公开”一个变量。我已经更新了答案,明确表明了这一点;我仍然认为你在试图逆流而上。 - Jon
显示剩余3条评论

5
我正在寻找相同的事情。根据我的搜索结果,似乎没有人回答了你的问题:假设在Linux(Ubuntu)上使用-m64编译gcc,并且gs段寄存器保持值为0,则段的隐藏部分(保存线性地址)指向特定于线程的本地区域。该区域包含该地址处的地址(64位)。所有线程本地变量都存储在较低的地址中。该地址是native_handle()。因此,要访问线程的本地数据,您应该通过该指针进行操作。
换句话说:(char*)&variable-(char*)myThread.native_handle()+(char*)theOtherThread.native_handle() 以下代码演示了上述内容,假设使用g ++、linux和pthread:
#include <iostream>
#include <thread>
#include <sstream>

thread_local int B=0x11111111,A=0x22222222;

bool shouldContinue=false;

void code(){
    while(!shouldContinue);
    std::stringstream ss;
    ss<<" A:"<<A<<" B:"<<B<<std::endl;
    std::cout<<ss.str();
}

//#define ot(th,variable) 
//(*( (char*)&variable-(char*)(pthread_self())+(char*)(th.native_handle()) ))

int& ot(std::thread& th,int& v){
    auto p=pthread_self();
    intptr_t d=(intptr_t)&v-(intptr_t)p;
    return *(int*)((char*)th.native_handle()+d);
}

int main(int argc, char **argv)
{       

        std::thread th1(code),th2(code),th3(code),th4(code);

        ot(th1,A)=100;ot(th1,B)=110;
        ot(th2,A)=200;ot(th2,B)=210;
        ot(th3,A)=300;ot(th3,B)=310;
        ot(th4,A)=400;ot(th4,B)=410;

        shouldContinue=true;

        th1.join();
        th2.join();
        th3.join();
        th4.join();

    return 0;
}

1
唉...太不可移植了。[如果gcc的std :: thread停止使用pthread作为native_handle呢?或者结构发生变化呢?这可能会在任何下一个gcc更新中发生。] - tower120

3

这是一个老问题了,但既然没有给出答案,为什么不使用具有自己静态注册的类呢?

#include <mutex>
#include <thread>
#include <unordered_map>

struct foo;

static std::unordered_map<std::thread::id, foo*> foos;
static std::mutex foos_mutex;

struct foo
{
    foo()
    {
        std::lock_guard<std::mutex> lk(foos_mutex);
        foos[std::this_thread::get_id()] = this;
    }
};

static thread_local foo tls_foo;


当然,你需要在线程之间进行某种同步以确保线程已经注册了指针,但是你可以从任何你知道线程id的线程中从map中获取它。

我所给出的答案是,似乎无法实现我所要求的内容。虽然有许多其他方法可以做到其他事情,但问题的严格要求似乎无法满足。 - edA-qa mort-ora-y
我想我不明白为什么这不能满足您的要求。您无需知道线程入口点,只需要定义一个结构,在构造时将指针注册到自身上,然后将该结构作为您的__thread变量。除了在线程启动时进行的初始注册外,这样做不会有任何开销。 - Zoltan Dewitt
@edA-qa mort-ora-y :: 这个答案对我非常有用。现在,我猜测OP不满意的原因。OP提到我们想要“在每个线程的开头跟踪这个值”。然而,在这个答案中,对foos的注册甚至可能不会发生,因为有些线程可能永远不会访问tls_foo。线程本地变量只有在被引用时才会初始化。 - cppBeginner

2

很遗憾,我从未能找到一种方法来实现这个。

如果没有某种类型的线程初始化钩子,似乎无法访问该指针(除非使用依赖于平台的ASM黑客攻击)。


0
这基本上可以满足你的需求,如果不行的话可以根据你的要求进行修改。
在Linux上,它使用`pthread_key_create`,而在Windows上则使用`TlsAlloc`。它们都是通过“键”来检索线程本地数据的一种方式。但是,如果你注册了这些键,你就可以在其他线程上访问数据。
EnumerableThreadLocal的思想是在你的线程中执行本地操作,然后在主线程中将结果合并起来。
tbb有一个类似的函数叫做enumerable_thread_specific,关于它的动机可以在https://oneapi-src.github.io/oneTBB/main/tbb_userguide/design_patterns/Divide_and_Conquer.html找到。
下面的代码是在不依赖tbb的情况下模仿tbb代码的尝试。但是,下面的代码的缺点是在Windows上只能使用1088个键。
    template <typename T>
    class EnumerableThreadLocal
    {

#if _WIN32 || _WIN64
        using tls_key_t = DWORD;
        void create_key() { my_key = TlsAlloc(); }
        void destroy_key() { TlsFree(my_key); }
        void set_tls(void *value) { TlsSetValue(my_key, (LPVOID)value); }
        void *get_tls() { return (void *)TlsGetValue(my_key); }
#else
        using tls_key_t = pthread_key_t;
        void create_key() { pthread_key_create(&my_key, nullptr); }
        void destroy_key() { pthread_key_delete(my_key); }
        void set_tls(void *value) const { pthread_setspecific(my_key, value); }
        void *get_tls() const { return pthread_getspecific(my_key); }
#endif
        std::vector<std::pair<std::thread::id, std::unique_ptr<T>>> m_thread_locals;
        std::mutex m_mtx;
        tls_key_t my_key;

        using Factory = std::function<std::unique_ptr<T>()>;
        Factory m_factory;

        static auto DefaultFactory()
        {
            return std::make_unique<T alignas(hardware_constructive_interference_size)>();
        }

    public:

        EnumerableThreadLocal(Factory factory = &DefaultFactory ) : m_factory(factory)
        {
            create_key();
        }

        ~EnumerableThreadLocal()
        {
            destroy_key();
        }

        EnumerableThreadLocal(const EnumerableThreadLocal &other)
        {
            create_key();
            // deep copy the m_thread_locals
            m_thread_locals.reserve(other.m_thread_locals.size());
            for (const auto &pair : other.m_thread_locals)
            {
                m_thread_locals.emplace_back(pair.first, std::make_unique<T>(*pair.second));
            }
        }

        EnumerableThreadLocal &operator=(const EnumerableThreadLocal &other)
        {
            if (this != &other)
            {
                destroy_key();
                create_key();
                m_thread_locals.clear();
                m_thread_locals.reserve(other.m_thread_locals.size());
                for (const auto &pair : other.m_thread_locals)
                {
                    m_thread_locals.emplace_back(pair.first, std::make_unique<T>(*pair.second));
                }
            }
            return *this;
        }

        EnumerableThreadLocal(EnumerableThreadLocal &&other) noexcept
        {
            // deep move
            my_key = other.my_key;
            // deep move the m_thread_locals
            m_thread_locals = std::move(other.m_thread_locals);
            other.my_key = 0;

        }

        EnumerableThreadLocal &operator=(EnumerableThreadLocal &&other) noexcept
        {
            if (this != &other)
            {
                destroy_key();
                my_key = other.my_key;
                m_thread_locals = std::move(other.m_thread_locals);
                other.my_key = 0;
            }
            return *this;
        }

        T *Get ()
        {
            void *v = get_tls();
            if (v)
            {
                return reinterpret_cast<T *>(v);
            }
            else
            {
                const std::scoped_lock l(m_mtx);
                for (const auto &[thread_id, uptr] : m_thread_locals)
                {
                    // This search is necessary for the case if we run out of TLS indicies in customer's process, and we do at least slow lookup
                    if (thread_id == std::this_thread::get_id())
                    {
                        set_tls(reinterpret_cast<void *>(uptr.get()));
                        return uptr.get();
                    }
                }

                m_thread_locals.emplace_back(std::this_thread::get_id(), m_factory());
                T *ptr = m_thread_locals.back().second.get();
                set_tls(reinterpret_cast<void *>(ptr));
                return ptr;
            }
        }

        T const * Get() const
        {
            return const_cast<EnumerableThreadLocal *>(this)->Get();
        }

        T & operator *()
        {
            return *Get();
        }

        T const & operator *() const
        {
            return *Get();
        }

        T * operator ->()
        {
            return Get();
        }

        T const * operator ->() const
        {
            return Get();
        }

        template <typename F>
        void Enumerate(F fn)
        {
            const std::scoped_lock lock(m_mtx);
            for (auto &[thread_id, ptr] : m_thread_locals)
                fn(*ptr);
        }
    };

一个测试套件,展示给你它的工作原理。
#include <thread>
#include <string>
#include "gtest/gtest.h"
#include "EnumerableThreadLocal.hpp"

TEST(EnumerableThreadLocal, BasicTest)
{
    const int N = 10;
    v31::EnumerableThreadLocal<std::string> tls;

    // Create N threads and assign a string including the thread ID to the tls
    std::vector<std::thread> threads;
    for (int i = 0; i < N; ++i)
    {
        threads.emplace_back([&tls, i]()
                             { *tls = "Thread " + std::to_string(i); });
    }

    // Wait for all threads to finish
    for (auto &thread : threads)
        thread.join();

    std::vector<std::string> expected;
    tls.Enumerate([&](std::string &s)
                  { expected.push_back(s); });

    // Sort the expected vector
    std::sort(expected.begin(), expected.end());

    // check the expected vector
    for (int i = 0; i < N; ++i)
    {
        ASSERT_EQ(expected[i], "Thread " + std::to_string(i));
    }


}

// Create a non copyable type, non moveable type
struct NonCopyable
{
    int i=0;
    NonCopyable() = default;
    NonCopyable(const NonCopyable &) = delete;
    NonCopyable(NonCopyable &&) = delete;
    NonCopyable &operator=(const NonCopyable &) = delete;
    NonCopyable &operator=(NonCopyable &&) = delete;
};

// A test to see if we can insert non moveable/ non copyable types to the tls
TEST(EnumerableThreadLocal, NonCopyableTest)
{
    const int N = 10;
    v31::EnumerableThreadLocal<NonCopyable> tls;

    // Create N threads and assign a string including the thread ID to the tls
    std::vector<std::thread> threads;
    for (int i = 0; i < N; ++i)
    {
        threads.emplace_back([&tls, i]()
                             { tls->i=i; });
    }

    // Wait for all threads to finish
    for (auto &thread : threads)
        thread.join();

    std::vector<int> expected;
    tls.Enumerate([&](NonCopyable &s)
                  { expected.push_back(s.i); });

    // Sort the expected vector
    std::sort(expected.begin(), expected.end());

    // check the expected vector
    for (int i = 0; i < N; ++i)
    {
        ASSERT_EQ(expected[i], i);
    }
}

const int N = 10;
v31::EnumerableThreadLocal<std::string> CreateFixture()
{
    v31::EnumerableThreadLocal<std::string> tls;

    // Create N threads and assign a string including the thread ID to the tls
    std::vector<std::thread> threads;
    for (int i = 0; i < N; ++i)
    {
        threads.emplace_back([&tls, i]()
                             { *tls = "Thread " + std::to_string(i); });
    }

    // Wait for all threads to finish
    for (auto &thread : threads)
        thread.join();

    return tls;
}

void CheckFixtureCopy(v31::EnumerableThreadLocal<std::string> & tls)
{
    std::vector<std::string> expected;

    tls.Enumerate([&](std::string &s)
                    { expected.push_back(s); });

    // Sort the expected vector
    std::sort(expected.begin(), expected.end());

    // check the expected vector
    for (int i = 0; i < N; ++i)
    {
        ASSERT_EQ(expected[i], "Thread " + std::to_string(i));
    }
}

void CheckFixtureEmpty(v31::EnumerableThreadLocal<std::string> & tls)
{
    std::vector<std::string> expected;

    tls.Enumerate([&](std::string &s)
                    { expected.push_back(s); });

    ASSERT_EQ(expected.size(), 0);
}

/// Test for copy construct of EnumerableThreadLocal
TEST(EnumerableThreadLocal, Copy)
{
    auto tls = CreateFixture();
    // Copy the tls
    auto tls_copy = tls;

    CheckFixtureCopy(tls_copy);
    CheckFixtureCopy(tls);
}

/// Test for move construct of EnumerableThreadLocal
TEST(EnumerableThreadLocal, Move)
{
    auto tls = CreateFixture();
    // Copy the tls
    auto tls_copy = std::move(tls);

    CheckFixtureCopy(tls_copy);
    CheckFixtureEmpty(tls);
}

/// Test for copy assign of EnumerableThreadLocal
TEST(EnumerableThreadLocal, CopyAssign)
{
    auto tls = CreateFixture();
    // Copy the tls
    v31::EnumerableThreadLocal<std::string> tls_copy;
    CheckFixtureEmpty(tls_copy);
    tls_copy = tls;

    CheckFixtureCopy(tls_copy);
    CheckFixtureCopy(tls);
}   

/// Test for move assign of EnumerableThreadLocal
TEST(EnumerableThreadLocal, MoveAssign)
{
    auto tls = CreateFixture();
    // Copy the tls
    v31::EnumerableThreadLocal<std::string> tls_copy;
    CheckFixtureEmpty(tls_copy);
    tls_copy = std::move(tls);

    CheckFixtureCopy(tls_copy);
    CheckFixtureEmpty(tls);
}

//class with no default constructor
struct NoDefaultConstructor
{
    int i;
    NoDefaultConstructor(int i) : i(i) {}
};

// Test for using objects with no default constructor
TEST(EnumerableThreadLocal, NoDefaultConstructor)
{
    const int N = 10;
    v31::EnumerableThreadLocal<NoDefaultConstructor> tls([]{return std::make_unique<NoDefaultConstructor>(0);});

    // Create N threads and assign a string including the thread ID to the tls
    std::vector<std::thread> threads;
    for (int i = 0; i < N; ++i)
    {
        threads.emplace_back([&tls, i]()
                             { tls->i = i; });
    }

    // Wait for all threads to finish
    for (auto &thread : threads)
        thread.join();

    // enumerate and sort and verify
    std::vector<int> expected;  
    tls.Enumerate([&](NoDefaultConstructor &s)
                    { expected.push_back(s.i); });

    // Sort the expected vector
    std::sort(expected.begin(), expected.end());

    // check the expected vector
    for (int i = 0; i < N; ++i)
    {
        ASSERT_EQ(expected[i], i);
    }

}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接