这个队列的实现可能存在ABA问题。具体来说:某个线程读取了头(或尾)指针的值A后被挂起;在此期间,另一个线程弹出并释放了节点A,随后新分配的节点恰好复用了同一地址A并被插入队列。第一个线程恢复执行后,它的CAS只比较指针值,因此会把这个新节点误认为是原来的节点并"成功"执行,从而破坏队列结构。
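解决方法是使用带有版本号的指针:每次成功修改头/尾指针时都把版本号加1,CAS时同时比较指针和版本号,这样即使地址被复用,版本号不同也会使CAS失败。需要注意,版本号只能防止ABA误判,并不能解决其他线程仍在访问已被delete的节点(use-after-free)的问题,因此还需要配合安全的内存回收机制(例如延迟释放、hazard pointer或epoch回收)。以下给出改进后的代码: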
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <utility>
template
LockFreeQueue() : head_(new Node), tail_(head_.load()) {}
~LockFreeQueue()
{
while (Node *old_head = head_.load())
{
Node *next = old_head->next_;
if (head_.compare_exchange_weak(old_head, next))
{
delete old_head;
break;
}
}
}
void push(const T &value)
{
Node *new_node = new Node(value);
while (true)
{
Tail tail = tail_.load();
Node *tail_ptr = tail.ptr;
uint64_t tail_version = tail.version;
Node *next = tail_ptr->next_.load(std::memory_order_relaxed);
if (!next)
{
if (tail_ptr->next_.compare_exchange_weak(next, new_node,
std::memory_order_release,
std::memory_order_relaxed))
{
Tail new_tail{new_node, tail_version + 1};
tail_.compare_exchange_weak(tail, new_tail,
std::memory_order_release,
std::memory_order_relaxed);
size_.fetch_add(1, std::memory_order_release);
return;
}
}
else
{
Tail new_tail{next, tail_version + 1};
tail_.compare_exchange_weak(tail, new_tail,
std::memory_order_release,
std::memory_order_relaxed);
}
}
}
bool try_pop(T &value)
{
while (true)
{
Head head = head_.load();
Tail tail = tail_.load();
Node *head_ptr = head.ptr;
uint64_t head_version = head.version;
uint64_t tail_version = tail.version;
Node *next = head_ptr->next_.load(std::memory_order_relaxed);
if (head.ptr == tail.ptr)
{
if (!next)
{
return false;
}
Tail new_tail{next, tail_version + 1};
tail_.compare_exchange_weak(tail, new_tail,
std::memory_order_release,
std::memory_order_relaxed);
}
else
{
value = next->value_;
Head new_head{next, head_version + 1};
if (head_.compare_exchange_weak(head, new_head,
std::memory_order_release,
std::memory_order_relaxed))
{
size_.fetch_sub(1, std::memory_order_release);
delete head_ptr;
return true;
}
}
}
}
bool empty() const
{
return size_.load(std::memory_order_acquire) == 0;
}
private:
struct Node
{
T value_;
std::atomic<Node *> next_;
Node() : value_(T()), next_(nullptr) {}
explicit Node(const T &value) : value_(value), next_(nullptr) {}
};
struct Head
{
Node *ptr;
uint64_t version;
};
struct Tail
{
Node *ptr;
uint64_t version;
};
std::atomic<Head> head_;
std::atomic<Tail> tail_;
std::atomic<size_t> size_{0};
};