我有一个原子计数器 (std::atomic<uint32_t> count
),它可以给多个线程分配顺序递增的值。
uint32_t my_val = ++count;
在获取
my_val
之前,我希望确保增量不会溢出(即返回到0)。if (count == std::numeric_limits<uint32_t>::max())
throw std::runtime_error("count overflow");
我认为这是一个天真的检查,因为如果在两个线程都增加计数器之前执行检查,第二个增加的线程将会得到0。
if (count == std::numeric_limits<uint32_t>::max()) // if 2 threads execute this
throw std::runtime_error("count overflow");
uint32_t my_val = ++count; // before either gets here - possible overflow
因此,我猜我需要使用CAS操作来确保当我递增计数器时,确实防止了可能的溢出。
所以我的问题是:
- 我的实现是否正确?
- 它是否尽可能高效(具体而言,我是否需要两次检查
max
)?
以下是我的代码(带有工作示例):
#include <iostream>
#include <atomic>
#include <limits>
#include <stdexcept>
#include <thread>
std::atomic<uint16_t> count;
uint16_t get_val() // called by multiple threads
{
uint16_t my_val;
do
{
my_val = count;
// make sure I get the next value
if (count.compare_exchange_strong(my_val, my_val + 1))
{
// if I got the next value, make sure we don't overflow
if (my_val == std::numeric_limits<uint16_t>::max())
{
count = std::numeric_limits<uint16_t>::max() - 1;
throw std::runtime_error("count overflow");
}
break;
}
// if I didn't then check if there are still numbers available
if (my_val == std::numeric_limits<uint16_t>::max())
{
count = std::numeric_limits<uint16_t>::max() - 1;
throw std::runtime_error("count overflow");
}
// there are still numbers available, so try again
}
while (1);
return my_val + 1;
}
void run()
try
{
while (1)
{
if (get_val() == 0)
exit(1);
}
}
catch(const std::runtime_error& e)
{
// overflow
}
int main()
{
while (1)
{
count = 1;
std::thread a(run);
std::thread b(run);
std::thread c(run);
std::thread d(run);
a.join();
b.join();
c.join();
d.join();
std::cout << ".";
}
return 0;
}