// https://github.com/crossbeam-rs/crossbeam/blob/abf24e1f31a76ef2590688d0c2bb55f82c7410a9/crossbeam-queue/src/seg_queue.rs
const std = @import("std");
const Atomic = std.atomic.Value;
const Allocator = std.mem.Allocator;
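
/// `send` and `receive` below use a `Backoff` helper that the original snippet does not
/// define. This is a minimal sketch modeled after crossbeam's `Backoff`: `spin` busy-waits
/// with exponentially increasing spin counts, and `snooze` additionally falls back to
/// yielding the thread once spinning stops being productive. The exact limits here are
/// assumptions, not taken from the original code.
const Backoff = struct {
    step: u4 = 0,

    const SPIN_LIMIT = 6;
    const YIELD_LIMIT = 10;

    /// Busy-wait, roughly doubling the number of spin-loop hints each call (capped at SPIN_LIMIT).
    fn spin(b: *Backoff) void {
        const limit = @min(b.step, SPIN_LIMIT);
        for (0..(@as(usize, 1) << limit)) |_| std.atomic.spinLoopHint();
        if (b.step < YIELD_LIMIT) b.step += 1;
    }

    /// Spin while contention is expected to be short, then start yielding to the scheduler.
    fn snooze(b: *Backoff) void {
        if (b.step <= SPIN_LIMIT) {
            for (0..(@as(usize, 1) << b.step)) |_| std.atomic.spinLoopHint();
        } else {
            std.Thread.yield() catch {};
        }
        if (b.step < YIELD_LIMIT) b.step += 1;
    }
};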
pub fn Channel(T: type) type {
    return struct {
        head: Position,
        tail: Position,
        allocator: Allocator,

        const Self = @This();
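
        // Index layout (mirrors crossbeam's SegQueue): each block holds BLOCK_CAP slots and
        // indices advance by one LAP per block. The bottom SHIFT bit of an index is reserved
        // for flags; the slot offset within a block is `(index >> SHIFT) % LAP`, and
        // offset == BLOCK_CAP means "the next block is currently being installed".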
        const BLOCK_CAP = 31;
        const SHIFT = 1;
        const LAP = 32;
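
        // Per-slot state bits: WRITTEN_TO is set by the producer once the value is stored,
        // READ_FROM is set by the consumer once the value has been taken, and DESTROYED marks
        // a slot whose block is waiting to be freed. HAS_NEXT lives in the low bit of the
        // *head index* and signals that the head's block already has a successor.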
        const WRITTEN_TO: usize = 0b01;
        const READ_FROM: usize = 0b10;
        const DESTROYED: usize = 0b100;
        const HAS_NEXT: usize = 0b01;
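
        /// A head or tail position in the channel: a packed index (see the comments on the
        /// constants above) plus a pointer to the block that index currently falls into.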
        const Position = struct {
            index: Atomic(usize),
            block: Atomic(?*Buffer),

            fn deinit(pos: *Position, allocator: Allocator) void {
                if (pos.block.load(.monotonic)) |block| {
                    block.deinit(allocator);
                    allocator.destroy(block);
                }
            }
        };
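
        /// A fixed-size block of BLOCK_CAP slots. Blocks are linked through `next` to form
        /// the unbounded queue backing the channel.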
        const Buffer = struct {
            next: Atomic(?*Buffer),
            slots: [BLOCK_CAP]Slot,

            fn create(allocator: Allocator) !*Buffer {
                const new = try allocator.create(Buffer);
                @memset(&new.slots, Slot.uninit);
                new.next = Atomic(?*Buffer).init(null);
                return new;
            }

            /// Attempts to free the block: every unread slot in `start..BLOCK_CAP - 1` is
            /// marked DESTROYED, and whichever thread finishes reading last completes the
            /// destruction.
            fn destroy(block: *Buffer, start: usize, allocator: Allocator) void {
                for (start..BLOCK_CAP - 1) |i| {
                    const slot = &block.slots[i];
                    // If another thread is still reading from this slot, let that thread
                    // continue the destruction of the block instead.
                    if (slot.state.load(.acquire) & READ_FROM == 0 and
                        slot.state.fetchOr(DESTROYED, .acq_rel) & READ_FROM == 0)
                    {
                        return;
                    }
                }
                // No thread is using the block anymore, so it is safe to free it.
                allocator.destroy(block);
            }

            fn deinit(block: *Buffer, allocator: Allocator) void {
                if (block.next.load(.monotonic)) |n| {
                    n.deinit(allocator);
                    allocator.destroy(n);
                }
            }
        };
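
        /// A single value cell together with its state bits (WRITTEN_TO, READ_FROM, DESTROYED).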
        const Slot = struct {
            value: T,
            state: Atomic(usize),

            const uninit: Slot = .{
                .value = undefined,
                .state = Atomic(usize).init(0),
            };
        };
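
        /// Initializes a channel consisting of a single empty block. `initial_capacity` is
        /// currently unused (see the TODO).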
        pub fn init(allocator: Allocator, initial_capacity: usize) !Self {
            _ = initial_capacity; // TODO: do something with this
            const first_block = try Buffer.create(allocator);
            const first_position: Position = .{
                .index = Atomic(usize).init(0),
                .block = Atomic(?*Buffer).init(first_block),
            };
            return .{
                .head = first_position,
                .tail = first_position,
                .allocator = allocator,
            };
        }
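
        /// Heap-allocates a `Channel` and initializes it. The caller is responsible for
        /// calling `deinit` and then freeing the pointer with the same allocator.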
        pub fn create(allocator: Allocator, initial_capacity: usize) !*Self {
            const channel = try allocator.create(Self);
            errdefer allocator.destroy(channel);
            channel.* = try Self.init(allocator, initial_capacity);
            return channel;
        }
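
        /// Appends `value` to the tail of the channel. Safe to call from multiple producer
        /// threads; the only failure is an allocation failure when a new block is needed.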
        pub fn send(channel: *Self, value: T) !void {
            var backoff: Backoff = .{};
            var tail = channel.tail.index.load(.acquire);
            var block = channel.tail.block.load(.acquire);
            var next_block: ?*Buffer = null;
            while (true) {
                const offset = (tail >> SHIFT) % LAP;
                if (offset == BLOCK_CAP) {
                    // Another producer has incremented the tail index past the end of the
                    // block before us; wait for it to install the next block.
                    backoff.snooze();
                    tail = channel.tail.index.load(.acquire);
                    block = channel.tail.block.load(.acquire);
                    continue;
                }
                // If we're about to claim the last slot in this block, pre-allocate the next
                // block before the CAS (see the comment further down).
                if (offset + 1 == BLOCK_CAP and next_block == null) {
                    next_block = try Buffer.create(channel.allocator);
                }
                const new_tail = tail + (1 << SHIFT);
                // Try to increment the tail index by one to block all future producers
                // until we install the next_block and next_index.
                if (channel.tail.index.cmpxchgWeak(tail, new_tail, .seq_cst, .acquire)) |t| {
                    // We failed the CAS, another thread has installed a new tail index.
                    tail = t;
                    block = channel.tail.block.load(.acquire);
                    backoff.spin();
                } else {
                    // We won the race, now we install the next_block and next_index for other
                    // threads to see and unblock.
                    if (offset + 1 == BLOCK_CAP) {
                        // We're now one over the block cap and the next slot we write to
                        // will be inside of the next block. Wrap the offset around the block,
                        // and shift to get the next index.
                        const next_index = new_tail +% (1 << SHIFT);
                        channel.tail.block.store(next_block, .release);
                        channel.tail.index.store(next_index, .release);
                        block.?.next.store(next_block, .release);
                    } else if (next_block) |b| {
                        // Other producers snooze from the moment our CAS succeeds until the
                        // tail.index store above. Allocating the next block *after* the CAS
                        // would lengthen that window, so we allocate before the CAS and pay
                        // for it here: if we ended up winning at an offset that doesn't need
                        // the pre-allocated block, free it again.
                        channel.allocator.destroy(b);
                    }
                    const slot = &block.?.slots[offset];
                    slot.value = value;
                    // Release the exclusive lock on the slot's value, which allows a consumer
                    // to read the data we've just assigned.
                    _ = slot.state.fetchOr(WRITTEN_TO, .release);
                    return;
                }
            }
        }
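
        /// Removes and returns the value at the head of the channel, or `null` if the channel
        /// is currently empty. Safe to call from multiple consumer threads.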
        pub fn receive(channel: *Self) ?T {
            var backoff: Backoff = .{};
            var head = channel.head.index.load(.acquire);
            var block = channel.head.block.load(.acquire);
            while (true) {
                // Shift away the meta-data bits and get the index into whichever block we're in.
                const offset = (head >> SHIFT) % LAP;
                // This means another thread has begun the process of installing a new head
                // block and index; we just need to wait until that's done.
                if (offset == BLOCK_CAP) {
                    backoff.snooze();
                    head = channel.head.index.load(.acquire);
                    block = channel.head.block.load(.acquire);
                    continue;
                }
                // After we consume this slot, this will be our next head.
                var new_head = head + (1 << SHIFT);
                // A bit confusing, but this checks if the current head *doesn't* have a next
                // block linked. It's encoded as a bit in order to be able to tell without
                // dereferencing it.
                if (new_head & HAS_NEXT == 0) {
                    // A rare use case for a fence: we need a barrier before anything else
                    // modifies the index. This is just easier than creating an
                    // acquire-release pair.
                    channel.tail.index.fence(.seq_cst);
                    const tail = channel.tail.index.load(.monotonic);
                    // If the indices are the same, the channel is empty and there's nothing
                    // to receive.
                    if (head >> SHIFT == tail >> SHIFT) {
                        return null;
                    }
                    // The head index must always be less than or equal to the tail index.
                    // Using this invariant, we know that if the head is in a different block
                    // than the tail, the head's block *must* have a "next" block. Hence we
                    // set the HAS_NEXT bit in the index.
                    if ((head >> SHIFT) / LAP != (tail >> SHIFT) / LAP) {
                        new_head |= HAS_NEXT;
                    }
                }
                // Try to install a new head index.
                if (channel.head.index.cmpxchgWeak(head, new_head, .seq_cst, .acquire)) |h| {
                    // We lost the install race to another consumer: adopt the head index it
                    // installed and re-load the block, which may have changed as well.
                    head = h;
                    block = channel.head.block.load(.acquire);
                    backoff.spin();
                } else {
                    // We just claimed the last slot in this block, so a producer is (or soon
                    // will be) installing the next block.
                    if (offset + 1 == BLOCK_CAP) {
                        // Wait until it installs the next block and update the references.
                        const next = while (true) {
                            backoff.snooze();
                            if (block.?.next.load(.acquire)) |n| break n;
                        };
                        var next_index = (new_head & ~HAS_NEXT) +% (1 << SHIFT);
                        if (next.next.load(.monotonic) != null) {
                            next_index |= HAS_NEXT;
                        }
                        channel.head.block.store(next, .release);
                        channel.head.index.store(next_index, .release);
                    }
                    // Now we should have a stable reference to a slot. Spin while a producer
                    // is still writing to this slot.
                    const slot = &block.?.slots[offset];
                    while (slot.state.load(.acquire) & WRITTEN_TO == 0) {
                        backoff.snooze();
                    }
                    const value = slot.value;
                    if (offset + 1 == BLOCK_CAP) {
                        // We just read the last slot in this block, so the whole block can be
                        // destroyed.
                        block.?.destroy(0, channel.allocator);
                    } else if (slot.state.fetchOr(READ_FROM, .acq_rel) & DESTROYED != 0) {
                        // Mark the slot as READ_FROM; if DESTROYED was already set, finish
                        // destroying the block on behalf of the thread that requested it.
                        block.?.destroy(offset + 1, channel.allocator);
                    }
                    return value;
                }
            }
        }
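
        /// Returns the number of values currently in the channel. The count is computed from
        /// a consistent snapshot of the indices and may be stale by the time it is returned.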
        pub fn len(channel: *Self) usize {
            while (true) {
                var tail = channel.tail.index.load(.seq_cst);
                var head = channel.head.index.load(.seq_cst);
                // Make sure `tail` wasn't modified while we were loading `head`.
                if (channel.tail.index.load(.seq_cst) == tail) {
                    // Shift out the bottom bit, which is used to indicate whether
                    // there is a next link in the block.
                    tail &= ~((@as(usize, 1) << SHIFT) - 1);
                    head &= ~((@as(usize, 1) << SHIFT) - 1);
                    // If an index sits on a block boundary, another thread is in the middle
                    // of installing the next_block and next_index, so we "mock" increment
                    // the index as if the install had already happened.
                    if ((tail >> SHIFT) % LAP == LAP - 1) {
                        tail +%= (1 << SHIFT);
                    }
                    if ((head >> SHIFT) % LAP == LAP - 1) {
                        head +%= (1 << SHIFT);
                    }
                    // Calculate which lap the head is on: slot indices 0-31 are lap 0,
                    // 32-63 are lap 1, and so on.
                    const lap = (head >> SHIFT) / LAP;
                    // Rotate the indices so that the head falls into the first lap;
                    // (lap * LAP) << SHIFT is the first index of the block the head is in.
                    tail -%= (lap * LAP) << SHIFT;
                    head -%= (lap * LAP) << SHIFT;
                    // Remove the lower bits.
                    tail >>= SHIFT;
                    head >>= SHIFT;
                    // Return the difference minus the number of blocks between tail and head.
                    return tail - head - tail / LAP;
                }
            }
        }
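
        /// Returns true if the head and tail indices point at the same slot. Like `len`,
        /// this is only a snapshot.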
        pub fn isEmpty(channel: *Self) bool {
            const head = channel.head.index.load(.seq_cst);
            const tail = channel.tail.index.load(.seq_cst);
            // The channel is empty if the indices are pointing at the same slot.
            return (head >> SHIFT) == (tail >> SHIFT);
        }
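
        /// Frees every block still owned by the channel. Not thread-safe: the indices and
        /// block pointers are read non-atomically (via `.raw`), so no other thread may be
        /// using the channel concurrently.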
        pub fn deinit(channel: *Self) void {
            var head = channel.head.index.raw;
            var tail = channel.tail.index.raw;
            var block = channel.head.block.raw;
            // Erase the lower (flag) bits of both indices.
            head &= ~((@as(usize, 1) << SHIFT) - 1);
            tail &= ~((@as(usize, 1) << SHIFT) - 1);
            // Walk from head to tail, freeing each block as we step past its end.
            while (head != tail) {
                const offset = (head >> SHIFT) % LAP;
                if (offset >= BLOCK_CAP) {
                    const next = block.?.next.raw;
                    channel.allocator.destroy(block.?);
                    block = next;
                }
                head +%= (1 << SHIFT);
            }
            // Free the last remaining block.
            if (block) |b| {
                channel.allocator.destroy(b);
            }
        }
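
        /// Closing is not implemented yet; this is currently a no-op.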
        pub fn close(channel: *Self) void {
            _ = channel;
        }
    };
}
const expect = std.testing.expect;
const assert = std.debug.assert;
test "smoke" {
var ch = BetterChannel(u32).init(std.testing.allocator);
defer ch.deinit();
try ch.push(7);
try expect(ch.pop() == 7);
try ch.push(8);
try expect(ch.pop() == 8);
try expect(ch.pop() == null);
}
test "len_empty_full" {
var ch = BetterChannel(u32).init(std.testing.allocator);
defer ch.deinit();
try expect(ch.len() == 0);
try expect(ch.isEmpty());
try ch.push(0);
try expect(ch.len() == 1);
try expect(!ch.isEmpty());
_ = ch.pop().?;
try expect(ch.len() == 0);
try expect(ch.isEmpty());
}
test "len" {
var ch = BetterChannel(u64).init(std.testing.allocator);
defer ch.deinit();
try expect(ch.len() == 0);
for (0..50) |i| {
try ch.push(i);
try expect(ch.len() == i + 1);
}
for (0..50) |i| {
_ = ch.pop().?;
try expect(ch.len() == 50 - i - 1);
}
try expect(ch.len() == 0);
}
test "spsc" {
const COUNT = 100;
const S = struct {
fn producer(ch: *BetterChannel(u64)) !void {
for (0..COUNT) |i| {
try ch.push(i);
}
}
fn consumer(ch: *BetterChannel(u64)) void {
for (0..COUNT) |i| {
while (true) {
if (ch.pop()) |x| {
assert(x == i);
break;
}
}
}
}
};
var ch = BetterChannel(u64).init(std.testing.allocator);
defer ch.deinit();
const consumer = try std.Thread.spawn(.{}, S.consumer, .{&ch});
const producer = try std.Thread.spawn(.{}, S.producer, .{&ch});
consumer.join();
producer.join();
}
test "mpmc" {
const COUNT = 100;
const THREADS = 4;
const S = struct {
fn producer(ch: *BetterChannel(u64)) !void {
for (0..COUNT) |i| {
try ch.push(i);
}
}
fn consumer(ch: *BetterChannel(u64), v: *[COUNT]AtomicValue(usize)) void {
for (0..COUNT) |_| {
const n = while (true) {
if (ch.pop()) |x| break x;
};
_ = v[n].fetchAdd(1, .seq_cst);
}
}
};
var v: [COUNT]AtomicValue(usize) = .{AtomicValue(usize).init(0)} ** COUNT;
var ch = BetterChannel(u64).init(std.testing.allocator);
defer ch.deinit();
var c_threads: [THREADS]std.Thread = undefined;
var p_threads: [THREADS]std.Thread = undefined;
for (&c_threads) |*c_thread| {
c_thread.* = try std.Thread.spawn(.{}, S.consumer, .{ &ch, &v });
}
for (&p_threads) |*p_thread| {
p_thread.* = try std.Thread.spawn(.{}, S.producer, .{&ch});
}
for (c_threads, p_threads) |c_thread, p_thread| {
c_thread.join();
p_thread.join();
}
for (v) |c| try expect(c.load(.seq_cst) == THREADS);
}