Skip to content

Commit

Permalink
coro: refactor api, fallback to aio if not in task
Browse files Browse the repository at this point in the history
  • Loading branch information
Cloudef committed Jun 19, 2024
1 parent df2d84f commit 759312a
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 62 deletions.
4 changes: 2 additions & 2 deletions examples/aio_static.zig
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub fn main() !void {
var buf2: [4096]u8 = undefined;
var len2: usize = 0;

const ret = try aio.complete(.{
const num_errors = try aio.complete(.{
aio.Read{
.file = f,
.buffer = &buf,
Expand All @@ -28,5 +28,5 @@ pub fn main() !void {

log.info("{s}", .{buf[0..len]});
log.info("{s}", .{buf2[0..len2]});
log.info("{}", .{ret});
log.info("{}", .{num_errors});
}
17 changes: 11 additions & 6 deletions examples/coro.zig
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ pub const aio_coro_options: coro.Options = .{
.debug = false, // set to true to enable debug logs
};

fn server() !void {
/// Custom yield states for this example.
/// The client yields with `server_ready`, and the server wakes the client
/// from that state once it has bound and started listening on its socket.
const Yield = enum {
server_ready,
};

fn server(client_task: coro.Task) !void {
var socket: std.posix.socket_t = undefined;
try coro.io.single(aio.Socket{
.domain = std.posix.AF.INET,
Expand All @@ -22,6 +26,8 @@ fn server() !void {
try std.posix.bind(socket, &address.any, address.getOsSockLen());
try std.posix.listen(socket, 128);

coro.wakeupFromState(client_task, Yield.server_ready);

var client_sock: std.posix.socket_t = undefined;
try coro.io.single(aio.Accept{ .socket = socket, .out_socket = &client_sock });

Expand All @@ -43,9 +49,6 @@ fn server() !void {
}

fn client() !void {
log.info("waiting 2 secs, to give time for the server to spin up", .{});
try coro.io.single(aio.Timeout{ .ts = .{ .sec = 2, .nsec = 0 } });

var socket: std.posix.socket_t = undefined;
try coro.io.single(aio.Socket{
.domain = std.posix.AF.INET,
Expand All @@ -54,6 +57,8 @@ fn client() !void {
.out_socket = &socket,
});

coro.yield(Yield.server_ready);

const address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 1327);
try coro.io.single(aio.Connect{
.socket = socket,
Expand Down Expand Up @@ -83,7 +88,7 @@ pub fn main() !void {
defer _ = gpa.deinit();
var scheduler = try coro.Scheduler.init(gpa.allocator(), .{});
defer scheduler.deinit();
_ = try scheduler.spawn(server, .{}, .{});
_ = try scheduler.spawn(client, .{}, .{});
const client_task = try scheduler.spawn(client, .{}, .{});
_ = try scheduler.spawn(server, .{client_task}, .{});
try scheduler.run();
}
6 changes: 3 additions & 3 deletions src/aio.zig
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ pub const Dynamic = struct {

/// Completes a list of operations immediately, blocks until complete
/// For error handling you must check the `out_error` field in the operation
pub inline fn complete(operations: anytype) ImmediateError!CompletionResult {
/// Returns the number of errors that occurred, 0 if there were no errors
pub inline fn complete(operations: anytype) ImmediateError!u16 {
const ti = @typeInfo(@TypeOf(operations));
if (comptime ti == .Struct and ti.Struct.is_tuple) {
if (comptime operations.len == 0) @compileError("no work to be done");
Expand All @@ -96,8 +97,7 @@ pub inline fn complete(operations: anytype) ImmediateError!CompletionResult {
/// Completes a list of operations immediately, blocks until complete
/// Returns `error.SomeOperationFailed` if any operation failed
pub inline fn multi(operations: anytype) (ImmediateError || error{SomeOperationFailed})!void {
const res = try complete(operations);
if (res.num_errors > 0) return error.SomeOperationFailed;
if (try complete(operations) > 0) return error.SomeOperationFailed;
}

/// Completes a single operation immediately, blocks until complete
Expand Down
10 changes: 4 additions & 6 deletions src/aio/linux.zig
Original file line number Diff line number Diff line change
Expand Up @@ -62,27 +62,25 @@ pub fn complete(self: *@This(), mode: aio.Dynamic.CompletionMode) aio.Completion
return result;
}

pub fn immediate(comptime len: u16, work: anytype) aio.ImmediateError!aio.CompletionResult {
pub fn immediate(comptime len: u16, work: anytype) aio.ImmediateError!u16 {
var io = try uring_init(try std.math.ceilPowerOfTwo(u16, len));
defer io.deinit();
inline for (&work.ops, 0..) |*op, idx| try uring_queue(&io, op, idx);
var num = try uring_submit(&io);
const submitted = num;
var result: aio.CompletionResult = .{};
var num_errors: u16 = 0;
var cqes: [len]std.os.linux.io_uring_cqe = undefined;
while (num > 0) {
const n = try uring_copy_cqes(&io, &cqes, num);
for (cqes[0..n]) |*cqe| {
inline for (&work.ops, 0..) |*op, idx| if (idx == cqe.user_data) {
uring_handle_completion(op, cqe) catch {
result.num_errors += 1;
num_errors += 1;
};
};
}
num -= n;
}
result.num_completed = submitted - num;
return result;
return num_errors;
}

inline fn uring_init(n: u16) aio.InitError!std.os.linux.IoUring {
Expand Down
84 changes: 39 additions & 45 deletions src/coro.zig
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ pub const io = struct {
/// The IO operations can be cancelled by calling `wakeup`
/// For error handling you must check the `out_error` field in the operation
/// Returns the number of errors that occurred, 0 if there were no errors
pub inline fn complete(operations: anytype) aio.QueueError!u16 {
pub inline fn complete(operations: anytype) aio.ImmediateError!u16 {
if (Fiber.current()) |fiber| {
var task: *Scheduler.TaskState = @ptrFromInt(fiber.getUserDataPtr().*);
var task: *TaskState = @ptrFromInt(fiber.getUserDataPtr().*);

const State = struct { old_err: ?*anyerror, old_id: ?*aio.Id, id: aio.Id, err: anyerror };
var state: [operations.len]State = undefined;
Expand Down Expand Up @@ -78,20 +78,20 @@ pub const io = struct {
}
return num_errors;
} else {
unreachable; // this io function is only meant to be used in coroutines!
return aio.complete(operations);
}
}

/// Completes a list of operations immediately, blocks until complete
/// The IO operations can be cancelled by calling `wakeupFromIo`, or doing `aio.Cancel`
/// Returns `error.SomeOperationFailed` if any operation failed
pub inline fn multi(operations: anytype) (aio.QueueError || error{SomeOperationFailed})!void {
pub inline fn multi(operations: anytype) (aio.ImmediateError || error{SomeOperationFailed})!void {
if (try complete(operations) > 0) return error.SomeOperationFailed;
}

/// Completes a single operation immediately, blocks the coroutine until complete
/// The IO operation can be cancelled by calling `wakeupFromIo`, or doing `aio.Cancel`
pub fn single(operation: anytype) (aio.QueueError || aio.OperationError)!void {
pub fn single(operation: anytype) (aio.ImmediateError || aio.OperationError)!void {
var op: @TypeOf(operation) = operation;
var err: @TypeOf(operation).Error = error.Success;
op.out_error = &err;
Expand All @@ -106,21 +106,21 @@ pub inline fn yield(state: anytype) void {
}

/// Wakes up a task from a yielded state, no-op if `state` does not match the current yielding state
pub inline fn wakeupFromState(task: Scheduler.Task, state: anytype) void {
pub inline fn wakeupFromState(task: Task, state: anytype) void {
const node: *Scheduler.TaskNode = @ptrCast(task);
if (node.data.marked_for_reap) return;
privateWakeup(&node.data, @enumFromInt(std.meta.fields(YieldState).len + @intFromEnum(state)));
}

/// Wakes up a task from IO by canceling the current IO operations for that task
pub inline fn wakeupFromIo(task: Scheduler.Task) void {
pub inline fn wakeupFromIo(task: Task) void {
const node: *Scheduler.TaskNode = @ptrCast(task);
if (node.data.marked_for_reap) return;
privateWakeup(&node.data, .io);
}

/// Wakes up a task regardless of the current yielding state
pub inline fn wakeup(task: Scheduler.Task) void {
pub inline fn wakeup(task: Task) void {
const node: *Scheduler.TaskNode = @ptrCast(task);
if (node.data.marked_for_reap) return;
// do not wakeup from io_cancel state as that can potentially lead to memory corruption
Expand All @@ -137,74 +137,72 @@ const YieldState = enum(u8) {

inline fn privateYield(state: YieldState) void {
if (Fiber.current()) |fiber| {
var task: *Scheduler.TaskState = @ptrFromInt(fiber.getUserDataPtr().*);
var task: *TaskState = @ptrFromInt(fiber.getUserDataPtr().*);
std.debug.assert(task.yield_state == .not_yielding);
task.yield_state = state;
debug("yielding: {}", .{task});
Fiber.yield();
} else {
unreachable; // yield is only meant to be used in coroutines!
unreachable; // yield can only be used from a task
}
}

inline fn privateWakeup(task: *Scheduler.TaskState, state: YieldState) void {
inline fn privateWakeup(task: *TaskState, state: YieldState) void {
if (task.yield_state != state) return;
debug("waking up from yield: {}", .{task});
task.yield_state = .not_yielding;
task.fiber.switchTo();
}

/// Per-task bookkeeping record: the task's fiber, its optional stack,
/// a reap flag, the IO backend pointer, an IO operation counter and the
/// current yield state.
const TaskState = struct {
/// Fiber that runs this task; its address is what `format` prints as the task id.
fiber: *Fiber,
/// Stack buffer that `deinit` frees with the supplied allocator when non-null.
stack: ?Fiber.Stack = null,
/// When set, the wakeup functions return without touching the task.
marked_for_reap: bool = false,
/// IO backend used by this task.
/// NOTE(review): presumably points at the owning scheduler's `io` field — confirm in `spawn`.
io: *aio.Dynamic,
/// Reported by `format` as the number of "ops left" when greater than zero.
io_counter: u16 = 0,
/// Current yield state; `.not_yielding` is the default.
yield_state: YieldState = .not_yielding,

/// Writes the fiber address in hex followed by the yield state's tag name.
/// When `io_counter` is greater than zero, the counter is appended as "ops left".
pub fn format(self: @This(), comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
if (self.io_counter > 0) {
try writer.print("{x}: {s}, {} ops left", .{ @intFromPtr(self.fiber), @tagName(self.yield_state), self.io_counter });
} else {
try writer.print("{x}: {s}", .{ @intFromPtr(self.fiber), @tagName(self.yield_state) });
}
}

/// Frees the stack (if any) with `allocator` and invalidates the struct.
/// Asserts that the task is not in the `.io` or `.io_cancel` yield state.
fn deinit(self: *@This(), allocator: std.mem.Allocator) void {
// we can only safely deinit the task when it is not doing IO
// otherwise for example io_uring might write to invalid memory address
std.debug.assert(self.yield_state != .io and self.yield_state != .io_cancel);
if (self.stack) |stack| allocator.free(stack);
// mark the whole struct as undefined so that later use is not mistaken for valid state
self.* = undefined;
}
};

pub const Task = *align(@alignOf(Scheduler.TaskNode)) anyopaque;

/// Runtime for asynchronous IO tasks
pub const Scheduler = struct {
allocator: std.mem.Allocator,
io: aio.Dynamic,
tasks: std.DoublyLinkedList(TaskState) = .{},
pending_for_reap: bool = false,

const TaskState = struct {
fiber: *Fiber,
stack: ?Fiber.Stack = null,
marked_for_reap: bool = false,
io: *aio.Dynamic,
io_counter: u16 = 0,
yield_state: YieldState = .not_yielding,

pub fn format(self: @This(), comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
if (self.io_counter > 0) {
try writer.print("{x}: {s}, {} ops left", .{ @intFromPtr(self.fiber), @tagName(self.yield_state), self.io_counter });
} else {
try writer.print("{x}: {s}", .{ @intFromPtr(self.fiber), @tagName(self.yield_state) });
}
}

fn deinit(self: *@This(), allocator: std.mem.Allocator) void {
if (Fiber.current()) |_| unreachable; // do not call deinit from a task
// we can only safely deinit the task when it is not doing IO
// otherwise for example io_uring might write to invalid memory address
std.debug.assert(self.yield_state != .io and self.yield_state != .io_cancel);
if (self.stack) |stack| allocator.free(stack);
self.* = undefined;
}
};

const TaskNode = std.DoublyLinkedList(TaskState).Node;
pub const Task = *align(@alignOf(TaskNode)) anyopaque;

pub const InitOptions = struct {
/// This is a hint, the implementation makes the final call
io_queue_entries: u16 = options.io_queue_entries,
};

pub fn init(allocator: std.mem.Allocator, opts: InitOptions) !@This() {
if (Fiber.current()) |_| unreachable; // do not call init from a task
return .{
.allocator = allocator,
.io = try aio.Dynamic.init(allocator, opts.io_queue_entries),
};
}

pub fn reapAll(self: *@This()) void {
if (Fiber.current()) |_| unreachable; // do not call reapAll from a task
var maybe_node = self.tasks.first;
while (maybe_node) |node| {
node.data.marked_for_reap = true;
Expand All @@ -214,7 +212,6 @@ pub const Scheduler = struct {
}

fn privateReap(self: *@This(), node: *TaskNode) bool {
if (Fiber.current()) |_| unreachable; // do not call reap from a task
if (node.data.yield_state == .io or node.data.yield_state == .io_cancel) {
debug("task is pending on io, reaping later: {}", .{node.data});
if (node.data.yield_state == .io) privateWakeup(&node.data, .io); // cancel io
Expand All @@ -235,7 +232,6 @@ pub const Scheduler = struct {
}

pub fn deinit(self: *@This()) void {
if (Fiber.current()) |_| unreachable; // do not call deinit from a task
// destroy io backend first to make sure we can destroy the tasks safely
self.io.deinit(self.allocator);
while (self.tasks.pop()) |node| {
Expand All @@ -254,7 +250,7 @@ pub const Scheduler = struct {
} else {
@call(.auto, func, args);
}
var task: *Scheduler.TaskState = @ptrFromInt(Fiber.current().?.getUserDataPtr().*);
var task: *TaskState = @ptrFromInt(Fiber.current().?.getUserDataPtr().*);
task.marked_for_reap = true;
self.pending_for_reap = true;
debug("finished: {}", .{task});
Expand All @@ -271,7 +267,6 @@ pub const Scheduler = struct {

/// Spawns a new task, the task may do local IO operations which will not block the whole process using the `io` namespace functions
pub fn spawn(self: *@This(), comptime func: anytype, args: anytype, opts: SpawnOptions) SpawnError!Task {
if (Fiber.current()) |_| unreachable; // do not call spawn from a task
const stack = switch (opts.stack) {
.unmanaged => |buf| buf,
.managed => |sz| try self.allocator.alignedAlloc(u8, Fiber.stack_alignment, sz),
Expand Down Expand Up @@ -312,7 +307,6 @@ pub const Scheduler = struct {

/// Processes pending IO and reaps dead tasks
pub fn tick(self: *@This(), mode: aio.Dynamic.CompletionMode) !void {
if (Fiber.current()) |_| unreachable; // do not call tick from a task
try self.tickIo(mode);
if (self.pending_for_reap) {
var num_unreaped: usize = 0;
Expand Down

0 comments on commit 759312a

Please sign in to comment.