Skip to content

Commit

Permalink
coro: refactor quite a bit
Browse files Browse the repository at this point in the history
Wakeups and yields must now be matched with the expected state.
Reaps and deinits should now be memory safe.
And some other improvements.
  • Loading branch information
Cloudef committed Jun 18, 2024
1 parent 6894083 commit 2e6870b
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 96 deletions.
8 changes: 6 additions & 2 deletions src/aio.zig
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ pub const Dynamic = struct {
const ti = @typeInfo(@TypeOf(operations));
if (comptime ti == .Struct and ti.Struct.is_tuple) {
return self.io.queue(operations.len, &struct { ops: @TypeOf(operations) }{ .ops = operations });
} else if (comptime ti == .Array) {
return self.io.queue(operations.len, &struct { ops: @TypeOf(operations) }{ .ops = operations });
} else {
return self.io.queue(1, &struct { ops: @TypeOf(.{operations}) }{ .ops = .{operations} });
}
Expand All @@ -78,8 +80,10 @@ pub inline fn batch(operations: anytype) ImmediateError!CompletionResult {
const ti = @typeInfo(@TypeOf(operations));
if (comptime ti == .Struct and ti.Struct.is_tuple) {
return IO.immediate(operations.len, &struct { ops: @TypeOf(operations) }{ .ops = operations });
} else if (comptime ti == .Array) {
return IO.immediate(operations.len, &struct { ops: @TypeOf(operations) }{ .ops = operations });
} else {
@compileError("expected a tuple of operations");
@compileError("expected a tuple or array of operations");
}
}

Expand All @@ -93,7 +97,7 @@ pub inline fn multi(operations: anytype) (ImmediateError || error{SomeOperationF
/// Completes a single operation immediately, blocks until complete
pub inline fn single(operation: anytype) (ImmediateError || OperationError)!void {
var op: @TypeOf(operation) = operation;
var err: @TypeOf(op.out_error.?.*) = error.Success;
var err: @TypeOf(operation).Error = error.Success;
op.out_error = &err;
_ = try batch(.{op});
if (err != error.Success) return err;
Expand Down
204 changes: 110 additions & 94 deletions src/coro.zig
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,14 @@ pub const io = struct {

try task.io.queue(work.ops);
task.io_counter = operations.len;
task.status = .doing_io;
debug("yielding for io: {}", .{task});
Fiber.yield();
privateYield(.io);

if (task.io_counter > 0) {
// wakeup() was called, try cancel the io
inline for (&state) |*s| try task.io.queue(aio.Cancel{ .id = s.id });
task.status = .cancelling_io;
debug("yielding for io cancellation: {}", .{task});
Fiber.yield();
// woken up for io cancelation
var cancels: [operations.len]aio.Cancel = undefined;
inline for (&cancels, &state) |*op, *s| op.* = .{ .id = s.id };
try task.io.queue(cancels);
privateYield(.io_cancel);
}

var num_errors: u16 = 0;
Expand All @@ -85,106 +83,103 @@ pub const io = struct {
}

/// Completes a list of operations immediately, blocks until complete
/// The IO operations can be cancelled by calling `wakeup`
/// The IO operations can be cancelled by calling `wakeupFromIo`, or doing `aio.Cancel`
/// Returns `error.SomeOperationFailed` if any operation failed
pub inline fn multi(operations: anytype) (aio.QueueError || error{SomeOperationFailed})!void {
if (try batch(operations) > 0) return error.SomeOperationFailed;
}

/// Completes a single operation immediately, blocks the coroutine until complete
/// The IO operation can be cancelled by calling `wakeup`
/// TODO: combine this and multi to avoid differences/bugs in implementation
/// The IO operation can be cancelled by calling `wakeupFromIo`, or doing `aio.Cancel`
pub fn single(operation: anytype) (aio.QueueError || aio.OperationError)!void {
if (Fiber.current()) |fiber| {
var task: *Scheduler.TaskState = @ptrFromInt(fiber.getUserDataPtr().*);
var op: @TypeOf(operation) = operation;
var err: @TypeOf(operation).Error = error.Success;
op.out_error = &err;
_ = try batch(.{op});
if (err != error.Success) return err;
}
};

var op: @TypeOf(operation) = operation;
var err: @TypeOf(op.out_error.?.*) = error.Success;
var id: aio.Id = undefined;
op.counter = .{ .dec = &task.io_counter };
var old_id: ?*aio.Id = null;
if (@hasDecl(@TypeOf(op), "out_id")) {
old_id = op.out_id;
op.out_id = &id;
}
op.out_error = &err;
try task.io.queue(op);
task.io_counter = 1;
task.status = .doing_io;
debug("yielding for io: {}", .{task});
Fiber.yield();
/// Yields the current task; can only be called from inside a task.
/// `state` is a user-defined enum value; it is offset past the named
/// `YieldState` fields so custom states never collide with the internal ones.
pub inline fn yield(state: anytype) void {
privateYield(@enumFromInt(std.meta.fields(YieldState).len + @intFromEnum(state)));
}

if (task.io_counter > 0) {
// wakeup() was called, try cancel the io
try task.io.queue(aio.Cancel{ .id = id });
task.status = .cancelling_io;
debug("yielding for io cancellation: {}", .{task});
Fiber.yield();
}
/// Wakes up a task from a user-defined yielded state; no-op if `state`
/// does not match the task's current yield state. `state` is offset the
/// same way as in `yield`, so pass the same enum value used to yield.
pub inline fn wakeupFromState(task: Scheduler.Task, state: anytype) void {
const node: *Scheduler.TaskNode = @ptrCast(task);
// a task queued for reaping must not be resumed
if (node.data.marked_for_reap) return;
privateWakeup(&node.data, @enumFromInt(std.meta.fields(YieldState).len + @intFromEnum(state)));
}

if (old_id) |p| p.* = id;
if (err != error.Success) return err;
} else {
unreachable; // this io function is only meant to be used in coroutines!
}
}
/// Wakes up a task that is yielded in the `.io` state; the resumed task is
/// then responsible for cancelling its in-flight IO operations.
/// No-op if the task is not currently waiting on IO.
pub inline fn wakeupFromIo(task: Scheduler.Task) void {
const node: *Scheduler.TaskNode = @ptrCast(task);
// a task queued for reaping must not be resumed
if (node.data.marked_for_reap) return;
privateWakeup(&node.data, .io);
}

/// Wakes up a task regardless of its current yielding state.
/// Passing the task's own `yield_state` makes `privateWakeup`'s state
/// check always succeed (unless the task is not yielding at all).
pub inline fn wakeup(task: Scheduler.Task) void {
const node: *Scheduler.TaskNode = @ptrCast(task);
// a task queued for reaping must not be resumed
if (node.data.marked_for_reap) return;
privateWakeup(&node.data, node.data.yield_state);
}

/// Reason a task is currently suspended. Wakeups must specify the matching
/// state (see `privateWakeup`); values past the named fields are reserved
/// for user-defined states routed through `yield`/`wakeupFromState`.
const YieldState = enum(u8) {
not_yielding, // task is running (not suspended)
io, // waiting for io operations to complete
io_cancel, // waiting for queued io cancellations to complete
_, // fields after are reserved for custom use
};

/// Yields current task, can only be called from inside a task
pub inline fn yield() void {
inline fn privateYield(state: YieldState) void {
if (Fiber.current()) |fiber| {
var task: *Scheduler.TaskState = @ptrFromInt(fiber.getUserDataPtr().*);
if (task.status == .dead) unreachable; // race condition
if (task.status == .yield) unreachable; // race condition
task.status = .yield;
std.debug.assert(task.yield_state == .not_yielding);
task.yield_state = state;
debug("yielding: {}", .{task});
Fiber.yield();
} else {
unreachable; // yield is only meant to be used in coroutines!
}
}

/// Wakeups a task by either cancelling the io its doing or switching back to it from yielded state
pub inline fn wakeup(task: Scheduler.Task) void {
const node: *Scheduler.TaskNode = @ptrCast(task);
if (node.data.status == .dead) unreachable; // race condition
if (node.data.status == .running) return; // already awake
if (node.data.status == .cancelling_io) return; // can't wake up when cancelling
debug("waking up from yield: {}", .{node.data});
node.data.status = .running;
node.data.fiber.switchTo();
/// Resumes `task` only if it is yielded with exactly `state`; otherwise a
/// no-op. Clears the yield state before switching so the task observes
/// itself as not yielding when it continues.
inline fn privateWakeup(task: *Scheduler.TaskState, state: YieldState) void {
if (task.yield_state != state) return;
debug("waking up from yield: {}", .{task});
task.yield_state = .not_yielding;
task.fiber.switchTo();
}

/// Runtime for asynchronous IO tasks
pub const Scheduler = struct {
allocator: std.mem.Allocator,
io: aio.Dynamic,
tasks: std.DoublyLinkedList(TaskState) = .{},
num_dead: usize = 0,
pending_for_reap: bool = false,

const TaskState = struct {
fiber: *Fiber,
status: enum {
running,
doing_io,
cancelling_io,
yield,
dead,
} = .running,
stack: ?Fiber.Stack = null,
marked_for_reap: bool = false,
io: *aio.Dynamic,
io_counter: u16 = 0,
yield_state: YieldState = .not_yielding,

pub fn format(self: @This(), comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
if (self.status == .doing_io) {
try writer.print("{x}: {s}, {} ops left", .{ @intFromPtr(self.fiber), @tagName(self.status), self.io_counter });
if (self.io_counter > 0) {
try writer.print("{x}: {s}, {} ops left", .{ @intFromPtr(self.fiber), @tagName(self.yield_state), self.io_counter });
} else {
try writer.print("{x}: {s}", .{ @intFromPtr(self.fiber), @tagName(self.status) });
try writer.print("{x}: {s}", .{ @intFromPtr(self.fiber), @tagName(self.yield_state) });
}
}

/// Releases the task's resources. Must not be called from inside a task,
/// and never while the task still has in-flight IO (asserted below).
fn deinit(self: *@This(), allocator: std.mem.Allocator) void {
if (Fiber.current()) |_| unreachable; // do not call deinit from a task
// we can only safely deinit the task when it is not doing IO
// otherwise for example io_uring might write to invalid memory address
std.debug.assert(self.yield_state != .io and self.yield_state != .io_cancel);
if (self.stack) |stack| allocator.free(stack);
self.* = undefined; // poison: catch use-after-deinit in safe builds
}
Expand All @@ -208,26 +203,46 @@ pub const Scheduler = struct {

pub fn reapAll(self: *@This()) void {
if (Fiber.current()) |_| unreachable; // do not call reapAll from a task
while (self.tasks.pop()) |node| {
debug("reaping: {}", .{node.data});
node.data.deinit(self.allocator);
self.allocator.destroy(node);
var maybe_node = self.tasks.first;
while (maybe_node) |node| {
node.data.marked_for_reap = true;
maybe_node = node.next;
}
self.pending_for_reap = true;
}

pub fn reap(self: *@This(), task: Task) void {
fn privateReap(self: *@This(), node: *TaskNode) bool {
if (Fiber.current()) |_| unreachable; // do not call reap from a task
const node: *TaskNode = @ptrCast(task);
if (node.data.yield_state == .io or node.data.yield_state == .io_cancel) {
debug("task is pending on io, reaping later: {}", .{node.data});
if (node.data.yield_state == .io) privateWakeup(&node.data, .io); // cancel io
node.data.marked_for_reap = true;
self.pending_for_reap = true;
return false; // still pending
}
debug("reaping: {}", .{node.data});
self.tasks.remove(node);
node.data.deinit(self.allocator);
self.allocator.destroy(node);
return true;
}

/// Requests destruction of a single task. If the task still has IO
/// pending, `privateReap` only marks it for reaping and it is destroyed
/// on a later `tick`; the boolean result is intentionally discarded here.
pub fn reap(self: *@This(), task: Task) void {
const node: *TaskNode = @ptrCast(task);
_ = self.privateReap(node);
}

pub fn deinit(self: *@This()) void {
if (Fiber.current()) |_| unreachable; // do not call deinit from a task
self.reapAll();
// destroy io backend first to make sure we can destroy the tasks safely
self.io.deinit(self.allocator);
while (self.tasks.pop()) |node| {
// modify the yield state to avoid state consistency assert in deinit
// it's okay to deinit now since the io backend is dead
node.data.yield_state = .not_yielding;
node.data.deinit(self.allocator);
self.allocator.destroy(node);
}
self.* = undefined;
}

Expand All @@ -238,8 +253,8 @@ pub const Scheduler = struct {
@call(.auto, func, args);
}
var task: *Scheduler.TaskState = @ptrFromInt(Fiber.current().?.getUserDataPtr().*);
task.status = .dead;
self.num_dead += 1;
task.marked_for_reap = true;
self.pending_for_reap = true;
debug("finished: {}", .{task});
}

Expand Down Expand Up @@ -276,41 +291,42 @@ pub const Scheduler = struct {
return node;
}

/// Processes pending IO and reaps dead tasks
pub fn tick(self: *@This(), mode: aio.Dynamic.CompletionMode) !void {
if (Fiber.current()) |_| unreachable; // do not call tick from a task
fn tickIo(self: *@This(), mode: aio.Dynamic.CompletionMode) !void {
const res = try self.io.complete(mode);
if (res.num_completed > 0) {
var maybe_node = self.tasks.first;
while (maybe_node) |node| {
const next = node.next;
switch (node.data.status) {
.running, .dead, .yield => {},
.doing_io, .cancelling_io => if (node.data.io_counter == 0) {
debug("waking up from io: {}", .{node.data});
node.data.status = .running;
node.data.fiber.switchTo();
switch (node.data.yield_state) {
.io, .io_cancel => if (node.data.io_counter == 0) {
privateWakeup(&node.data, node.data.yield_state);
},
else => {},
}
maybe_node = next;
}
}
while (self.num_dead > 0) {
}

/// Processes pending IO and reaps dead tasks
pub fn tick(self: *@This(), mode: aio.Dynamic.CompletionMode) !void {
if (Fiber.current()) |_| unreachable; // do not call tick from a task
try self.tickIo(mode);
if (self.pending_for_reap) {
var num_unreaped: usize = 0;
var maybe_node = self.tasks.first;
while (maybe_node) |node| {
const next = node.next;
switch (node.data.status) {
.running, .doing_io, .cancelling_io, .yield => {},
.dead => {
debug("reaping: {}", .{node.data});
node.data.deinit(self.allocator);
self.tasks.remove(node);
self.allocator.destroy(node);
self.num_dead -= 1;
},
if (node.data.marked_for_reap) {
if (!self.privateReap(node)) {
num_unreaped += 1;
}
}
maybe_node = next;
}
if (num_unreaped == 0) {
self.pending_for_reap = false;
}
}
}

Expand Down

0 comments on commit 2e6870b

Please sign in to comment.