防止std::atomic溢出

9

我有一个原子计数器 (std::atomic<uint32_t> count),它可以给多个线程分配顺序递增的值。

uint32_t my_val = ++count;

在获取my_val之前,我希望确保增量不会溢出(即返回到0)。
if (count == std::numeric_limits<uint32_t>::max())
    throw std::runtime_error("count overflow");

我认为这是一个天真的检查,因为如果在两个线程都增加计数器之前执行检查,第二个增加的线程将会得到0。
if (count == std::numeric_limits<uint32_t>::max()) // if 2 threads execute this
    throw std::runtime_error("count overflow");
uint32_t my_val = ++count;       // before either gets here - possible overflow

因此,我猜我需要使用CAS操作来确保当我递增计数器时,确实防止了可能的溢出。

所以我的问题是:

  • 我的实现是否正确?
  • 它是否尽可能高效(具体而言,我是否需要两次检查max)?

以下是我的代码(带有工作示例):

#include <iostream>
#include <atomic>
#include <limits>
#include <stdexcept>
#include <thread>

std::atomic<uint16_t> count;

uint16_t get_val() // called by multiple threads
{
    uint16_t my_val;
    do
    {
        my_val = count;

        // make sure I get the next value

        if (count.compare_exchange_strong(my_val, my_val + 1))
        {
            // if I got the next value, make sure we don't overflow

            if (my_val == std::numeric_limits<uint16_t>::max())
            {
                count = std::numeric_limits<uint16_t>::max() - 1;
                throw std::runtime_error("count overflow");
            }
            break;
        }

        // if I didn't then check if there are still numbers available

        if (my_val == std::numeric_limits<uint16_t>::max())
        {
            count = std::numeric_limits<uint16_t>::max() - 1;
            throw std::runtime_error("count overflow");
        }

        // there are still numbers available, so try again
    }
    while (1);
    return my_val + 1;
}

void run()
try
{
    while (1)
    {
        if (get_val() == 0)
            exit(1);
    }

}
catch(const std::runtime_error& e)
{
    // overflow
}

int main()
{
    while (1)
    {
        count = 1;
        std::thread a(run);
        std::thread b(run);
        std::thread c(run);
        std::thread d(run);
        a.join();
        b.join();
        c.join();
        d.join();
        std::cout << ".";
    }
    return 0;
}

1
uint64_t,问题解决了吗? :) - NoSenseEtAl
3个回答

8
是的,你需要使用CAS操作。
std::atomic<uint16_t> g_count;

uint16_t get_next() {
   uint16_t new_val = 0;
   do {
      uint16_t cur_val = g_count;                                            // 1
      if (cur_val == std::numeric_limits<uint16_t>::max()) {                 // 2
          throw std::runtime_error("count overflow");
      }
      new_val = cur_val + 1;                                                 // 3
   } while(!std::atomic_compare_exchange_weak(&g_count, &cur_val, new_val)); // 4

   return new_val;
}

思路如下:一旦 g_count == std::numeric_limits<uint16_t>::max()get_next() 函数将始终抛出异常。
步骤:
  1. 获取计数器的当前值
  2. 如果它是最大值,则抛出异常(没有可用数字了)
  3. 将新值作为当前值的增量获取
  4. 尝试原子地设置新值。如果我们未能设置它(它已被另一个线程完成),则重试。

2
如果效率是一个重要的问题,那么我建议在检查时不要太严格。我猜在正常使用情况下,溢出不会成为问题,但你真的需要完整的65K范围吗(你的示例使用uint16)?
如果你假设你运行的线程数量有一个最大值,那么这将更容易。这是一个合理的限制,因为没有程序可以拥有无限数量的并发性。所以,如果你有N个线程,你可以简单地将你的溢出限制减少到"65K - N"。为了比较是否溢出,你不需要使用CAS。
uint16_t current = count.load(std::memory_order_relaxed);
if( current >= (std::numeric_limits<uint16_t>::max() - num_threads - 1) )
    throw std::runtime_error("count overflow");
count.fetch_add(1,std::memory_order_relaxed);

这会创建一个软溢出的情况。如果两个线程同时到达这里,它们都有可能通过,但这没关系,因为计数变量本身从未溢出。任何以后到达此处的内容都将逻辑性地溢出(直到计数再次减少)。


我只使用 uint16_t ,以便更早地发生潜在的溢出。我喜欢你的想法,但线程数现在是已知的,我最不想做的就是选择一个魔术上限,然后在将来某个时候超过它。 - Steve Lorimer
@lori,可能不知道具体数字,但必须有一些合理的上限。您可以指定1000个线程,甚至可以是10K。或者使用uint_32,然后只需指定一百万即可。 - edA-qa mort-ora-y

1

在我看来,仍存在竞争条件,其中count会短暂地被设置为0,以至于另一个线程将看到该值为0。

假设count位于std::numeric_limits<uint16_t>::max(),并且有两个线程尝试获取递增的值。当线程1执行count.compare_exchange_strong(my_val, my_val + 1)时,在此时将count设置为0,如果线程2在线程1有机会将count恢复为max()之前调用和完成get_val(),则线程2将看到count为0。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接