Skip to content

Commit

Permalink
coro: add ThreadPool for mixing blocking code
Browse files Browse the repository at this point in the history
  • Loading branch information
Cloudef committed Jun 19, 2024
1 parent f095a59 commit 4cbbd38
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 19 deletions.
31 changes: 31 additions & 0 deletions docs/pages/coro-blocking-code.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# CORO API

## Mixing blocking code

Sometimes it's not feasible to rewrite blocking code so that it plays nice with the `coro.Scheduler`.
In that kind of scenario it is possible to use `coro.ThreadPool` to allow tasks to yield until blocking code
finishes on a worker thread.

### Example

```zig
fn blockingCode() u32 {
std.time.sleep(1 * std.time.ns_per_s);
return 69;
}
fn task(pool: *ThreadPool) !void {
const ret = try pool.yieldForCompletition(blocking, .{});
try std.testing.expectEqual(69, ret);
}
var pool: ThreadPool = .{};
defer pool.deinit(); // pool must always be destroyed before scheduler
try pool.start(std.testing.allocator, 0);
var scheduler = try Scheduler.init(std.testing.allocator, .{});
defer scheduler.deinit();
_ = try scheduler.spawn(task, .{&pool}, .{});
try scheduler.run();
```

4 changes: 4 additions & 0 deletions docs/vocs.config.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ export default defineConfig({
text: 'Context switches',
link: '/coro-context-switches',
},
{
text: 'Mixing blocking code',
link: '/coro-blocking-code',
},
],
},
],
Expand Down
40 changes: 33 additions & 7 deletions src/aio.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
const std = @import("std");

pub const InitError = error{
Overflow,
OutOfMemory,
PermissionDenied,
ProcessQuotaExceeded,
Expand All @@ -16,7 +15,7 @@ pub const InitError = error{
};

pub const QueueError = error{
Overflow,
OutOfMemory,
SubmissionQueueFull,
};

Expand Down Expand Up @@ -101,22 +100,38 @@ pub inline fn multi(operations: anytype) (ImmediateError || error{SomeOperationF
}

/// Completes a single operation immediately, blocks until complete
pub inline fn single(operation: anytype) (ImmediateError || OperationError)!void {
pub inline fn single(operation: anytype) (ImmediateError || @TypeOf(operation).Error)!void {
var op: @TypeOf(operation) = operation;
var err: @TypeOf(operation).Error = error.Success;
op.out_error = &err;
_ = try complete(.{op});
if (err != error.Success) return err;
}

pub const EventSource = struct {
native: IO.EventSource,

pub inline fn init() InitError!@This() {
return .{ .native = try IO.EventSource.init() };
}

pub inline fn deinit(self: *@This()) void {
self.native.deinit();
self.* = undefined;
}

pub inline fn notify(self: *@This()) void {
self.native.notify();
}
};

const IO = switch (@import("builtin").target.os.tag) {
.linux => @import("aio/linux.zig"),
else => @compileError("unsupported os"),
};

const ops = @import("aio/ops.zig");
pub const Id = ops.Id;
pub const OperationError = ops.Operation.Error;
pub const Fsync = ops.Fsync;
pub const Read = ops.Read;
pub const Write = ops.Write;
Expand All @@ -137,6 +152,9 @@ pub const SymlinkAt = ops.SymlinkAt;
pub const ChildExit = ops.ChildExit;
pub const Socket = ops.Socket;
pub const CloseSocket = ops.CloseSocket;
pub const NotifyEventSource = ops.NotifyEventSource;
pub const WaitEventSource = ops.WaitEventSource;
pub const CloseEventSource = ops.CloseEventSource;

test "shared outputs" {
var tmp = std.testing.tmpDir(.{});
Expand Down Expand Up @@ -252,12 +270,11 @@ test "Timeout" {
test "LinkTimeout" {
var err: Timeout.Error = undefined;
var expired: bool = undefined;
const res = try complete(.{
const num_errors = try complete(.{
Timeout{ .ns = 2 * std.time.ns_per_s, .out_error = &err, .link_next = true },
LinkTimeout{ .ns = 1 * std.time.ns_per_s, .out_expired = &expired },
});
try std.testing.expectEqual(2, res.num_completed);
try std.testing.expectEqual(1, res.num_errors);
try std.testing.expectEqual(1, num_errors);
try std.testing.expectEqual(error.OperationCanceled, err);
try std.testing.expectEqual(true, expired);
}
Expand Down Expand Up @@ -356,3 +373,12 @@ test "Socket" {
});
try single(CloseSocket{ .socket = socket });
}

test "EventSource" {
const source = try EventSource.init();
try multi(.{
NotifyEventSource{ .source = source },
WaitEventSource{ .source = source, .link_next = true },
CloseEventSource{ .source = source },
});
}
39 changes: 35 additions & 4 deletions src/aio/linux.zig
Original file line number Diff line number Diff line change
@@ -1,13 +1,36 @@
const std = @import("std");
const aio = @import("../aio.zig");
const Operation = @import("ops.zig").Operation;
const ErrorUnion = @import("ops.zig").ErrorUnion;

pub const EventSource = struct {
fd: std.posix.fd_t,

pub inline fn init() !@This() {
return .{
.fd = std.posix.eventfd(0, std.os.linux.EFD.CLOEXEC) catch |err| return switch (err) {
error.SystemResources => error.SystemResources,
error.ProcessFdQuotaExceeded => error.ProcessQuotaExceeded,
error.SystemFdQuotaExceeded => error.SystemQuotaExceeded,
error.Unexpected => error.Unexpected,
},
};
}

pub inline fn deinit(self: *@This()) void {
std.posix.close(self.fd);
self.* = undefined;
}

pub inline fn notify(self: *@This()) void {
_ = std.posix.write(self.fd, &std.mem.toBytes(@as(u64, 1))) catch unreachable;
}
};

io: std.os.linux.IoUring,
ops: Pool(Operation.Union, u16),

pub fn init(allocator: std.mem.Allocator, n: u16) aio.InitError!@This() {
const n2 = try std.math.ceilPowerOfTwo(u16, n);
const n2 = std.math.ceilPowerOfTwo(u16, n) catch return error.SystemQuotaExceeded;
var io = try uring_init(n2);
errdefer io.deinit();
const ops = try Pool(Operation.Union, u16).init(allocator, n2);
Expand All @@ -22,7 +45,7 @@ pub fn deinit(self: *@This(), allocator: std.mem.Allocator) void {
}

inline fn queueOperation(self: *@This(), op: anytype) aio.QueueError!u16 {
const n = self.ops.next() orelse return error.Overflow;
const n = self.ops.next() orelse return error.OutOfMemory;
try uring_queue(&self.io, op, n);
const tag = @tagName(comptime Operation.tagFromPayloadType(@TypeOf(op.*)));
return self.ops.add(@unionInit(Operation.Union, tag, op.*)) catch unreachable;
Expand Down Expand Up @@ -63,7 +86,7 @@ pub fn complete(self: *@This(), mode: aio.Dynamic.CompletionMode) aio.Completion
}

pub fn immediate(comptime len: u16, work: anytype) aio.ImmediateError!u16 {
var io = try uring_init(try std.math.ceilPowerOfTwo(u16, len));
var io = try uring_init(std.math.ceilPowerOfTwo(u16, len) catch return error.SystemQuotaExceeded);
defer io.deinit();
inline for (&work.ops, 0..) |*op, idx| try uring_queue(&io, op, idx);
var num = try uring_submit(&io);
Expand Down Expand Up @@ -126,6 +149,9 @@ fn convertOpenFlags(flags: std.fs.File.OpenFlags) std.posix.O {
}

inline fn uring_queue(io: *std.os.linux.IoUring, op: anytype, user_data: u64) aio.QueueError!void {
const Trash = struct {
var u_64: u64 align(1) = undefined;
};
const RENAME_NOREPLACE = 1 << 0;
var sqe = switch (comptime Operation.tagFromPayloadType(@TypeOf(op.*))) {
.fsync => try io.fsync(user_data, op.file.handle, 0),
Expand Down Expand Up @@ -161,6 +187,9 @@ inline fn uring_queue(io: *std.os.linux.IoUring, op: anytype, user_data: u64) ai
.child_exit => try io.waitid(user_data, .PID, op.child, @constCast(&op._), std.posix.W.EXITED, 0),
.socket => try io.socket(user_data, op.domain, op.flags, op.protocol, 0),
.close_socket => try io.close(user_data, op.socket),
.notify_event_source => try io.write(user_data, op.source.native.fd, &std.mem.toBytes(@as(u64, 1)), 0),
.wait_event_source => try io.read(user_data, op.source.native.fd, .{ .buffer = std.mem.asBytes(&Trash.u_64) }, 0),
.close_event_source => try io.close(user_data, op.source.native.fd),
};
if (op.link_next) sqe.flags |= std.os.linux.IOSQE_IO_LINK;
if (@hasField(@TypeOf(op.*), "out_id")) {
Expand Down Expand Up @@ -355,6 +384,7 @@ inline fn uring_handle_completion(op: anytype, cqe: *std.os.linux.io_uring_cqe)
else => std.posix.unexpectedErrno(err),
},
.close_file, .close_dir, .close_socket => unreachable,
.notify_event_source, .wait_event_source, .close_event_source => unreachable,
.timeout => switch (err) {
.SUCCESS, .INTR, .INVAL, .AGAIN => unreachable,
.TIME => error.Success,
Expand Down Expand Up @@ -493,6 +523,7 @@ inline fn uring_handle_completion(op: anytype, cqe: *std.os.linux.io_uring_cqe)
},
.open_at => op.out_file.handle = cqe.res,
.close_file, .close_dir, .close_socket => {},
.notify_event_source, .wait_event_source, .close_event_source => {},
.timeout, .link_timeout => {},
.cancel => {},
.rename_at, .unlink_at, .mkdir_at, .symlink_at => {},
Expand Down
37 changes: 31 additions & 6 deletions src/aio/ops.zig
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
const std = @import("std");
const builtin = @import("builtin");
const aio = @import("../aio.zig");

// Virtual linked actions are possible with `nop` under io_uring :thinking:

Expand Down Expand Up @@ -261,6 +262,30 @@ pub const CloseSocket = struct {
link_next: bool = false,
};

pub const NotifyEventSource = struct {
pub const Error = SharedError;
source: aio.EventSource,
out_error: ?*Error = null,
counter: Counter = .nop,
link_next: bool = false,
};

pub const WaitEventSource = struct {
pub const Error = SharedError;
source: aio.EventSource,
out_error: ?*Error = null,
counter: Counter = .nop,
link_next: bool = false,
};

pub const CloseEventSource = struct {
pub const Error = SharedError;
source: aio.EventSource,
out_error: ?*Error = null,
counter: Counter = .nop,
link_next: bool = false,
};

pub const Operation = enum {
fsync,
read,
Expand All @@ -282,6 +307,9 @@ pub const Operation = enum {
child_exit,
socket,
close_socket,
notify_event_source,
wait_event_source,
close_event_source,

pub const map = std.enums.EnumMap(@This(), type).init(.{
.fsync = Fsync,
Expand All @@ -304,6 +332,9 @@ pub const Operation = enum {
.child_exit = ChildExit,
.socket = Socket,
.close_socket = CloseSocket,
.notify_event_source = NotifyEventSource,
.wait_event_source = WaitEventSource,
.close_event_source = CloseEventSource,
});

pub fn tagFromPayloadType(comptime Op: type) @This() {
Expand All @@ -316,12 +347,6 @@ pub const Operation = enum {
unreachable;
}

pub const Error = blk: {
var set = error{};
for (Operation.map.values) |v| set = set || v.Error;
break :blk set;
};

pub const Union = blk: {
var fields: []const std.builtin.Type.UnionField = &.{};
for (Operation.map.values, 0..) |v, idx| fields = fields ++ .{.{
Expand Down
Loading

0 comments on commit 4cbbd38

Please sign in to comment.