diff --git a/compiler/compile.go b/compiler/compile.go index de405e1..abde5a1 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -12,6 +12,7 @@ import ( "strconv" "strings" + "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/callgraph/cha" "golang.org/x/tools/go/packages" "golang.org/x/tools/go/ssa" @@ -232,24 +233,26 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr } colorsByDecl[decl] = color } + for _, f := range p.Syntax { for _, anydecl := range f.Decls { decl, ok := anydecl.(*ast.FuncDecl) if !ok { continue } - color, ok := colorsByDecl[decl] if !ok { continue } - // Reject certain language features for now. if err := unsupported(decl, p.TypesInfo); err != nil { return err } - - gen.Decls = append(gen.Decls, c.compileFunction(p, decl, color)) + scope := &scope{ + colors: colorsByDecl, + objectIdent: 0, + } + gen.Decls = append(gen.Decls, scope.compileFuncDecl(p, decl, color)) } } @@ -273,18 +276,44 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr return nil } -func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color *types.Signature) *ast.FuncDecl { - log.Printf("compiling function %s %s", p.Name, fn.Name) +type scope struct { + colors map[ast.Node]*types.Signature + // Index used to generate unique object identifiers within the scope of a + // function. + // + // The field is reset to zero after compiling function declarations because + // we don't need globally unique identifiers for local variables. + objectIdent int +} + +func (scope *scope) newObjectIdent() *ast.Ident { + ident := scope.objectIdent + scope.objectIdent++ + return ast.NewIdent(fmt.Sprintf("_o%d", ident)) +} +func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color *types.Signature) *ast.FuncDecl { + log.Printf("compiling function %s %s", p.Name, fn.Name) // Generate the coroutine function. At this stage, use the same name // as the source function (and require that the caller use build tags // to disambiguate function calls). - gen := &ast.FuncDecl{ + return &ast.FuncDecl{ Name: fn.Name, - Type: fn.Type, - Body: &ast.BlockStmt{}, + Type: funcTypeWithNamedResults(fn.Type), + Body: scope.compileFuncBody(p, fn.Type, fn.Body, color), + } +} + +func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *types.Signature) *ast.FuncLit { + log.Printf("compiling function literal %s", p.Name) + return &ast.FuncLit{ + Type: funcTypeWithNamedResults(fn.Type), + Body: scope.compileFuncBody(p, fn.Type, fn.Body, color), } +} +func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, color *types.Signature) *ast.BlockStmt { + gen := new(ast.BlockStmt) ctx := ast.NewIdent("_c") frame := ast.NewIdent("_f") fp := ast.NewIdent("_fp") @@ -294,7 +323,7 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color yieldTypeExpr[1] = typeExpr(color.Results().At(0).Type()) // _c := coroutine.LoadContext[R, S]() - gen.Body.List = append(gen.Body.List, &ast.AssignStmt{ + gen.List = append(gen.List, &ast.AssignStmt{ Lhs: []ast.Expr{ctx}, Tok: token.DEFINE, Rhs: []ast.Expr{ @@ -311,7 +340,7 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color }) // _f, _fp := _c.Push() - gen.Body.List = append(gen.Body.List, &ast.AssignStmt{ + gen.List = append(gen.List, &ast.AssignStmt{ Lhs: []ast.Expr{frame, fp}, Tok: token.DEFINE, Rhs: []ast.Expr{ @@ -321,8 +350,22 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color }, }) + body = astutil.Apply(body, + func(cursor *astutil.Cursor) bool { + switch n := cursor.Node().(type) { + case *ast.FuncLit: + color, ok := scope.colors[n] + if ok { + cursor.Replace(scope.compileFuncLit(p, n, color)) + } + } + return true + }, + nil, + ).(*ast.BlockStmt) + // Desugar statements in the tree. - fn.Body = desugar(fn.Body, p.TypesInfo).(*ast.BlockStmt) + body = desugar(body, p.TypesInfo).(*ast.BlockStmt) // Handle declarations. // @@ -337,17 +380,17 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color // declarations to the function prologue. We downgrade inline var decls and // assignments that use := to assignments that use =. Constant decls are // hoisted and also have their value assigned in the function prologue. - decls := extractDecls(fn, p.TypesInfo) - renameObjects(fn, p.TypesInfo, decls) + decls := extractDecls(body, p.TypesInfo) + renameObjects(body, p.TypesInfo, decls, scope) for _, decl := range decls { - gen.Body.List = append(gen.Body.List, &ast.DeclStmt{Decl: decl}) + gen.List = append(gen.List, &ast.DeclStmt{Decl: decl}) } - removeDecls(fn) + removeDecls(body) // Collect params/results/variables that need to be saved/restored. var saveAndRestoreNames []*ast.Ident var saveAndRestoreTypes []types.Type - scanFuncTypeIdentifiers(fn.Type, func(name *ast.Ident) { + scanFuncTypeIdentifiers(typ, func(name *ast.Ident) { saveAndRestoreNames = append(saveAndRestoreNames, name) saveAndRestoreTypes = append(saveAndRestoreTypes, p.TypesInfo.TypeOf(name)) }) @@ -398,7 +441,7 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color }, ) } - gen.Body.List = append(gen.Body.List, &ast.IfStmt{ + gen.List = append(gen.List, &ast.IfStmt{ Cond: &ast.BinaryExpr{ X: &ast.SelectorExpr{X: ast.NewIdent("_f"), Sel: ast.NewIdent("IP")}, Op: token.GTR, /* > */ @@ -419,7 +462,7 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color }, }) } - gen.Body.List = append(gen.Body.List, &ast.DeferStmt{ + gen.List = append(gen.List, &ast.DeferStmt{ Call: &ast.CallExpr{ Fun: &ast.FuncLit{ Type: &ast.FuncType{}, @@ -447,11 +490,21 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color }, }) - spans := trackDispatchSpans(fn.Body) - - compiledBody := compileDispatch(fn.Body, spans).(*ast.BlockStmt) - - gen.Body.List = append(gen.Body.List, compiledBody.List...) - + spans := trackDispatchSpans(body) + compiledBody := compileDispatch(body, spans).(*ast.BlockStmt) + gen.List = append(gen.List, compiledBody.List...) + + // If the function returns one or more values, it must end with a return statement; + // we inject it if the function body does not already has one. + if typ.Results != nil && len(typ.Results.List) > 0 { + needsReturn := len(gen.List) == 0 + if !needsReturn { + _, endsWithReturn := gen.List[len(gen.List)-1].(*ast.ReturnStmt) + needsReturn = !endsWithReturn + } + if needsReturn { + gen.List = append(gen.List, &ast.ReturnStmt{}) + } + } return gen } diff --git a/compiler/coroutine_test.go b/compiler/coroutine_test.go index 9956222..614f1bf 100644 --- a/compiler/coroutine_test.go +++ b/compiler/coroutine_test.go @@ -121,8 +121,26 @@ func TestCoroutineYield(t *testing.T) { }, { - name: "range over closure", - coro: func() { Range10Closure() }, + name: "range over closure capturing values", + coro: Range10ClosureCapturingValues, + yields: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + + { + name: "range over closure capturing pointers", + coro: Range10ClosureCapturingPointers, + yields: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + + { + name: "range over closure capturing heterogenous values", + coro: Range10ClosureHeterogenousCapture, + yields: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + + { + name: "range with heterogenous values", + coro: Range10Heterogenous, yields: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, }, diff --git a/compiler/decls.go b/compiler/decls.go index 8449436..a1ef552 100644 --- a/compiler/decls.go +++ b/compiler/decls.go @@ -4,7 +4,6 @@ import ( "go/ast" "go/token" "go/types" - "strconv" "golang.org/x/tools/go/ast/astutil" ) @@ -21,9 +20,13 @@ import ( // Note that declarations are extracted from all nested scopes within the // function body, so there may be duplicate identifiers. Identifiers can be // disambiguated using (*types.Info).ObjectOf(ident). -func extractDecls(fn *ast.FuncDecl, info *types.Info) (decls []*ast.GenDecl) { - ast.Inspect(fn.Body, func(node ast.Node) bool { +func extractDecls(tree ast.Node, info *types.Info) (decls []*ast.GenDecl) { + ast.Inspect(tree, func(node ast.Node) bool { switch n := node.(type) { + case *ast.FuncLit: + // Stop when we encounter a function listeral so we don't hoist its + // local variables into the scope of its parent function. + return false case *ast.GenDecl: // const, var, type if n.Tok == token.TYPE || n.Tok == token.CONST { decls = append(decls, n) @@ -91,9 +94,8 @@ func extractDecls(fn *ast.FuncDecl, info *types.Info) (decls []*ast.GenDecl) { // renameObjects renames types, constants and variables declared within // a function. Each is given a unique name, so that declarations are safe // to hoist into the function prologue. -func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl) { +func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl, scope *scope) { // Scan decls to find objects, giving each new object a unique name. - var id int newNames := map[types.Object]*ast.Ident{} for _, decl := range decls { for _, spec := range decl.Specs { @@ -102,15 +104,13 @@ func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl) { if s.Name.Name == "_" { continue } - newNames[info.ObjectOf(s.Name)] = ast.NewIdent("_o" + strconv.Itoa(id)) - id++ + newNames[info.ObjectOf(s.Name)] = scope.newObjectIdent() case *ast.ValueSpec: // const/var for _, name := range s.Names { if name.Name == "_" { continue } - newNames[info.ObjectOf(name)] = ast.NewIdent("_o" + strconv.Itoa(id)) - id++ + newNames[info.ObjectOf(name)] = scope.newObjectIdent() } } } @@ -135,6 +135,8 @@ func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl) { func removeDecls(tree ast.Node) { astutil.Apply(tree, func(cursor *astutil.Cursor) bool { switch n := cursor.Node().(type) { + case *ast.FuncLit: + return false case *ast.AssignStmt: if n.Tok == token.DEFINE { if _, ok := cursor.Parent().(*ast.TypeSwitchStmt); ok { diff --git a/compiler/function.go b/compiler/function.go index 75104c8..48c778f 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -114,16 +114,19 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt return cmp.Compare(f1.Name(), f2.Name()) }) - for index, anonFunc := range anonFuncs { - _, colored := colors[anonFunc] - if colored { - // Colored functions (those rewritten into coroutines) have a - // deferred anonymous function injected at the beginning to perform - // stack unwinding, which takes the ".func1" name. - index++ - } - name = anonFuncLinkName(name, index) - generateFunctypesInit(pkg, anonFunc, init, name, colors) + index := 0 + // Colored functions (those rewritten into coroutines) have a + // deferred anonymous function injected at the beginning to perform + // stack unwinding, which takes the ".func1" name. + _, colored := colors[fn] + if colored { + index++ + } + + for _, anonFunc := range anonFuncs { + index++ + anonFuncName := anonFuncLinkName(name, index) + generateFunctypesInit(pkg, anonFunc, init, anonFuncName, colors) } } @@ -133,5 +136,5 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt // The function works with multiple levels of nesting as each level adds another // ".func" suffix, with the index being local to the parent scope. func anonFuncLinkName(base string, index int) string { - return fmt.Sprintf("%s.func%d", base, index+1) + return fmt.Sprintf("%s.func%d", base, index) } diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index 5ade359..613c721 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -279,7 +279,7 @@ func RangeTripleFuncValue(n int) { Range(n, f) } -func Range10Closure() { +func Range10ClosureCapturingValues() { i := 0 n := 10 f := func() bool { @@ -295,6 +295,112 @@ func Range10Closure() { } } +func Range10ClosureCapturingPointers() { + i, n := 0, 10 + p := &i + q := &n + f := func() bool { + if *p < *q { + coroutine.Yield[int, any](*p) + (*p)++ + return true + } + return false + } + + for f() { + } +} + +func Range10ClosureHeterogenousCapture() { + var ( + a int8 = 0 + b int16 = 1 + c int32 = 2 + d int64 = 3 + e uint8 = 4 + f uint16 = 5 + g uint32 = 6 + h uint64 = 7 + i uintptr = 8 + j = func() int { return int(i) + 1 } + ) + + n := 0 + x := func() bool { + var v int + switch n { + case 0: + v = int(a) + case 1: + v = int(b) + case 2: + v = int(c) + case 3: + v = int(d) + case 4: + v = int(e) + case 5: + v = int(f) + case 6: + v = int(g) + case 7: + v = int(h) + case 8: + v = int(i) + case 9: + v = j() + } + coroutine.Yield[int, any](v) + n++ + return n < 10 + } + + for x() { + } +} + +func Range10Heterogenous() { + var ( + a int8 = 0 + b int16 = 1 + c int32 = 2 + d int64 = 3 + e uint8 = 4 + f uint16 = 5 + g uint32 = 6 + h uint64 = 7 + i uintptr = 8 + ) + + for n := 0; n < 10; n++ { + var v int + switch n { + case 0: + v = int(a) + case 1: + v = int(b) + case 2: + v = int(c) + case 3: + v = int(d) + case 4: + v = int(e) + case 5: + v = int(f) + case 6: + v = int(g) + case 7: + v = int(h) + case 8: + v = int(i) + case 9: + v = int(n) + } + coroutine.Yield[int, any](v) + } +} + func Select(n int) { select { default: diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index cd7c25f..c91e794 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -2244,47 +2244,63 @@ func RangeTripleFuncValue(n int) { switch { case _f.IP < 2: _o0 = func(i int) { + _c := coroutine.LoadContext[int, any]() + _f, _fp := _c.Push() + if _f.IP > 0 { + if _v := _f.Get(0); _v != nil { + i = _v.(int) + } + } + defer func() { + if _c.Unwinding() { + _f.Set(0, i) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() coroutine.Yield[int, any](3 * i) } _f.IP = 2 fallthrough case _f.IP < 3: + Range(n, _o0) } } -func Range10Closure() { +func Range10ClosureCapturingValues() { _c := coroutine.LoadContext[int, any]() _f, _fp := _c.Push() - var _o0 int var _o1 int - var _o2 func() bool - var _o3 bool + var _o2 int + var _o3 func() bool var _o4 bool + var _o5 bool if _f.IP > 0 { if _v := _f.Get(0); _v != nil { - _o0 = _v.(int) + _o1 = _v.(int) } if _v := _f.Get(1); _v != nil { - _o1 = _v.(int) + _o2 = _v.(int) } if _v := _f.Get(2); _v != nil { - _o2 = _v.(func() bool) + _o3 = _v.(func() bool) } if _v := _f.Get(3); _v != nil { - _o3 = _v.(bool) + _o4 = _v.(bool) } if _v := _f.Get(4); _v != nil { - _o4 = _v.(bool) + _o5 = _v.(bool) } } defer func() { if _c.Unwinding() { - _f.Set(0, _o0) - _f.Set(1, _o1) - _f.Set(2, _o2) - _f.Set(3, _o3) - _f.Set(4, _o4) + _f.Set(0, _o1) + _f.Set(1, _o2) + _f.Set(2, _o3) + _f.Set(3, _o4) + _f.Set(4, _o5) _c.Store(_fp, _f) } else { _c.Pop() @@ -2292,21 +2308,61 @@ func Range10Closure() { }() switch { case _f.IP < 2: - _o0 = 0 + _o1 = 0 _f.IP = 2 fallthrough case _f.IP < 3: - _o1 = 10 + _o2 = 10 _f.IP = 3 fallthrough case _f.IP < 4: - _o2 = func() bool { - if _o0 < _o1 { - coroutine.Yield[int, any](_o0) - _o0++ - return true + _o3 = func() (_ bool) { + _c := coroutine.LoadContext[int, any]() + _f, _fp := _c.Push() + var _o0 bool + if _f.IP > 0 { + if _v := _f.Get(0); _v != nil { + _o0 = _v.(bool) + } + } + defer func() { + if _c.Unwinding() { + _f.Set(0, _o0) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 5: + switch { + case _f.IP < 2: + _o0 = _o1 < _o2 + _f.IP = 2 + fallthrough + case _f.IP < 5: + if _o0 { + switch { + case _f.IP < 3: + coroutine.Yield[int, any](_o1) + _f.IP = 3 + fallthrough + case _f.IP < 4: + _o1++ + _f.IP = 4 + fallthrough + case _f.IP < 5: + return true + } + } + } + _f.IP = 5 + fallthrough + case _f.IP < 6: + + return false } - return false + return } _f.IP = 4 fallthrough @@ -2315,15 +2371,471 @@ func Range10Closure() { for ; ; _f.IP = 4 { switch { case _f.IP < 5: - _o3 = _o2() + _o4 = _o3() + _f.IP = 5 + fallthrough + case _f.IP < 6: + _o5 = !_o4 + _f.IP = 6 + fallthrough + case _f.IP < 7: + if _o5 { + break _l0 + } + } + } + } +} + +func Range10ClosureCapturingPointers() { + _c := coroutine.LoadContext[int, any]() + _f, _fp := _c.Push() + var _o1 int + var _o2 int + var _o3 *int + var _o4 *int + var _o5 func() bool + var _o6 bool + var _o7 bool + if _f.IP > 0 { + if _v := _f.Get(0); _v != nil { + _o1 = _v.(int) + } + if _v := _f.Get(1); _v != nil { + _o2 = _v.(int) + } + if _v := _f.Get(2); _v != nil { + _o3 = _v.(*int) + } + if _v := _f.Get(3); _v != nil { + _o4 = _v.(*int) + } + if _v := _f.Get(4); _v != nil { + _o5 = _v.(func() bool) + } + if _v := _f.Get(5); _v != nil { + _o6 = _v.(bool) + } + if _v := _f.Get(6); _v != nil { + _o7 = _v.(bool) + } + } + defer func() { + if _c.Unwinding() { + _f.Set(0, _o1) + _f.Set(1, _o2) + _f.Set(2, _o3) + _f.Set(3, _o4) + _f.Set(4, _o5) + _f.Set(5, _o6) + _f.Set(6, _o7) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 2: + _o1, _o2 = 0, 10 + _f.IP = 2 + fallthrough + case _f.IP < 3: + _o3 = &_o1 + _f.IP = 3 + fallthrough + case _f.IP < 4: + _o4 = &_o2 + _f.IP = 4 + fallthrough + case _f.IP < 5: + _o5 = func() (_ bool) { + _c := coroutine.LoadContext[int, any]() + _f, _fp := _c.Push() + var _o0 bool + if _f.IP > 0 { + if _v := _f.Get(0); _v != nil { + _o0 = _v.(bool) + } + } + defer func() { + if _c.Unwinding() { + _f.Set(0, _o0) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 5: + switch { + case _f.IP < 2: + _o0 = *_o3 < *_o4 + _f.IP = 2 + fallthrough + case _f.IP < 5: + if _o0 { + switch { + case _f.IP < 3: + coroutine.Yield[int, any](*_o3) + _f.IP = 3 + fallthrough + case _f.IP < 4: + (*_o3)++ + _f.IP = 4 + fallthrough + case _f.IP < 5: + return true + } + } + } _f.IP = 5 fallthrough case _f.IP < 6: - _o4 = !_o3 + + return false + } + return + } + _f.IP = 5 + fallthrough + case _f.IP < 8: + _l0: + for ; ; _f.IP = 5 { + switch { + case _f.IP < 6: + _o6 = _o5() _f.IP = 6 fallthrough case _f.IP < 7: - if _o4 { + _o7 = !_o6 + _f.IP = 7 + fallthrough + case _f.IP < 8: + if _o7 { + break _l0 + } + } + } + } +} + +func Range10ClosureHeterogenousCapture() { + _c := coroutine.LoadContext[int, any]() + _f, _fp := _c.Push() + var _o12 int8 + var _o13 int16 + var _o14 int32 + var _o15 int64 + var _o16 uint8 + var _o17 uint16 + var _o18 uint32 + var _o19 uint64 + var _o20 uintptr + var _o21 func() int + var _o22 int + var _o23 func() bool + var _o24 bool + var _o25 bool + if _f.IP > 0 { + if _v := _f.Get(0); _v != nil { + _o12 = _v.(int8) + } + if _v := _f.Get(1); _v != nil { + _o13 = _v.(int16) + } + if _v := _f.Get(2); _v != nil { + _o14 = _v.(int32) + } + if _v := _f.Get(3); _v != nil { + _o15 = _v.(int64) + } + if _v := _f.Get(4); _v != nil { + _o16 = _v.(uint8) + } + if _v := _f.Get(5); _v != nil { + _o17 = _v.(uint16) + } + if _v := _f.Get(6); _v != nil { + _o18 = _v.(uint32) + } + if _v := _f.Get(7); _v != nil { + _o19 = _v.(uint64) + } + if _v := _f.Get(8); _v != nil { + _o20 = _v.(uintptr) + } + if _v := _f.Get(9); _v != nil { + _o21 = _v.(func() int) + } + if _v := _f.Get(10); _v != nil { + + _o22 = _v.(int) + } + if _v := _f.Get(11); _v != nil { + _o23 = _v.(func() bool) + } + if _v := _f.Get(12); _v != nil { + _o24 = _v.(bool) + } + if _v := _f.Get(13); _v != nil { + _o25 = _v.(bool) + } + } + defer func() { + if _c.Unwinding() { + _f.Set(0, _o12) + _f.Set(1, _o13) + _f.Set(2, _o14) + _f.Set(3, _o15) + _f.Set(4, _o16) + _f.Set(5, _o17) + _f.Set(6, _o18) + _f.Set(7, _o19) + _f.Set(8, _o20) + _f.Set(9, _o21) + _f.Set(10, _o22) + _f.Set(11, _o23) + _f.Set(12, _o24) + _f.Set(13, _o25) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 11: + switch { + case _f.IP < 2: + _o12 = 0 + _f.IP = 2 + fallthrough + case _f.IP < 3: + _o13 = 1 + _f.IP = 3 + fallthrough + case _f.IP < 4: + _o14 = 2 + _f.IP = 4 + fallthrough + case _f.IP < 5: + _o15 = 3 + _f.IP = 5 + fallthrough + case _f.IP < 6: + _o16 = 4 + _f.IP = 6 + fallthrough + case _f.IP < 7: + _o17 = 5 + _f.IP = 7 + fallthrough + case _f.IP < 8: + _o18 = 6 + _f.IP = 8 + fallthrough + case _f.IP < 9: + _o19 = 7 + _f.IP = 9 + fallthrough + case _f.IP < 10: + _o20 = 8 + _f.IP = 10 + fallthrough + case _f.IP < 11: + _o21 = func() int { return int(_o20) + 1 } + } + _f.IP = 11 + fallthrough + case _f.IP < 12: + + _o22 = 0 + _f.IP = 12 + fallthrough + case _f.IP < 13: + _o23 = func() (_ bool) { + _c := coroutine.LoadContext[int, any]() + _f, _fp := _c.Push() + var _o0 int + var _o1 int + var _o2 bool + var _o3 bool + var _o4 bool + var _o5 bool + var _o6 bool + var _o7 bool + var _o8 bool + var _o9 bool + var _o10 bool + var _o11 bool + if _f.IP > 0 { + if _v := _f.Get(0); _v != nil { + _o0 = _v.(int) + } + if _v := _f.Get(1); _v != nil { + _o1 = _v.(int) + } + if _v := _f.Get(2); _v != nil { + _o2 = _v.(bool) + } + if _v := _f.Get(3); _v != nil { + _o3 = _v.(bool) + } + if _v := _f.Get(4); _v != nil { + _o4 = _v.(bool) + } + if _v := _f.Get(5); _v != nil { + _o5 = _v.(bool) + } + if _v := _f.Get(6); _v != nil { + _o6 = _v.(bool) + } + if _v := _f.Get(7); _v != nil { + _o7 = _v.(bool) + } + if _v := _f.Get(8); _v != nil { + _o8 = _v.(bool) + } + if _v := _f.Get(9); _v != nil { + _o9 = _v.(bool) + } + if _v := _f.Get(10); _v != nil { + _o10 = _v.(bool) + } + if _v := _f.Get(11); _v != nil { + _o11 = _v.(bool) + } + } + defer func() { + if _c.Unwinding() { + _f.Set(0, _o0) + _f.Set(1, _o1) + _f.Set(2, _o2) + _f.Set(3, _o3) + _f.Set(4, _o4) + _f.Set(5, _o5) + _f.Set(6, _o6) + _f.Set(7, _o7) + _f.Set(8, _o8) + _f.Set(9, _o9) + _f.Set(10, _o10) + _f.Set(11, _o11) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 2: + _f.IP = 2 + fallthrough + case _f.IP < 23: + switch { + case _f.IP < 3: + _o1 = _o22 + _f.IP = 3 + fallthrough + case _f.IP < 23: + switch { + default: + switch { + case _f.IP < 4: + _o2 = _o1 == + 0 + _f.IP = 4 + fallthrough + case _f.IP < 23: + if _o2 { + _o0 = int(_o12) + } else { + _o3 = _o1 == + 1 + if _o3 { + _o0 = int(_o13) + } else { + _o4 = _o1 == + 2 + if _o4 { + _o0 = int(_o14) + } else { + _o5 = _o1 == + 3 + if _o5 { + _o0 = int(_o15) + } else { + _o6 = _o1 == + 4 + if _o6 { + _o0 = int(_o16) + } else { + _o7 = _o1 == + 5 + if _o7 { + _o0 = int(_o17) + } else { + _o8 = _o1 == + 6 + if _o8 { + _o0 = int(_o18) + } else { + _o9 = _o1 == + 7 + if _o9 { + _o0 = int(_o19) + } else { + _o10 = _o1 == + 8 + if _o10 { + _o0 = int(_o20) + } else { + _o11 = _o1 == + 9 + if _o11 { + _o0 = _o21() + } + } + } + } + } + } + } + } + } + } + } + } + } + _f.IP = 23 + fallthrough + case _f.IP < 24: + + coroutine.Yield[int, any](_o0) + _f.IP = 24 + fallthrough + case _f.IP < 25: + _o22++ + _f.IP = 25 + fallthrough + case _f.IP < 26: + return _o22 < 10 + } + return + } + _f.IP = 13 + fallthrough + case _f.IP < 16: + _l0: + for ; ; _f.IP = 13 { + switch { + case _f.IP < 14: + _o24 = _o23() + _f.IP = 14 + fallthrough + case _f.IP < 15: + _o25 = !_o24 + _f.IP = 15 + fallthrough + case _f.IP < 16: + if _o25 { break _l0 } } @@ -2331,6 +2843,288 @@ func Range10Closure() { } } +func Range10Heterogenous() { + _c := coroutine.LoadContext[int, any]() + _f, _fp := _c.Push() + var _o0 int8 + var _o1 int16 + var _o2 int32 + var _o3 int64 + var _o4 uint8 + var _o5 uint16 + var _o6 uint32 + var _o7 uint64 + var _o8 uintptr + var _o9 int + var _o10 bool + var _o11 int + var _o12 int + var _o13 bool + var _o14 bool + var _o15 bool + var _o16 bool + var _o17 bool + var _o18 bool + var _o19 bool + var _o20 bool + var _o21 bool + var _o22 bool + if _f.IP > 0 { + if _v := _f.Get(0); _v != nil { + _o0 = _v.(int8) + } + if _v := _f.Get(1); _v != nil { + _o1 = _v.(int16) + } + if _v := _f.Get(2); _v != nil { + _o2 = _v.(int32) + } + if _v := _f.Get(3); _v != nil { + _o3 = _v.(int64) + } + if _v := _f.Get(4); _v != nil { + _o4 = _v.(uint8) + } + if _v := _f.Get(5); _v != nil { + _o5 = _v.(uint16) + } + if _v := _f.Get(6); _v != nil { + _o6 = _v.(uint32) + } + if _v := _f.Get(7); _v != nil { + _o7 = _v.(uint64) + } + if _v := _f.Get(8); _v != nil { + _o8 = _v.(uintptr) + } + if _v := _f.Get(9); _v != nil { + + _o9 = _v.(int) + } + if _v := _f.Get(10); _v != nil { + _o10 = _v.(bool) + } + if _v := _f.Get(11); _v != nil { + _o11 = _v.(int) + } + if _v := _f.Get(12); _v != nil { + _o12 = _v.(int) + } + if _v := _f.Get(13); _v != nil { + _o13 = _v.(bool) + } + if _v := _f.Get(14); _v != nil { + _o14 = _v.(bool) + } + if _v := _f.Get(15); _v != nil { + _o15 = _v.(bool) + } + if _v := _f.Get(16); _v != nil { + _o16 = _v.(bool) + } + if _v := _f.Get(17); _v != nil { + _o17 = _v.(bool) + } + if _v := _f.Get(18); _v != nil { + _o18 = _v.(bool) + } + if _v := _f.Get(19); _v != nil { + _o19 = _v.(bool) + } + if _v := _f.Get(20); _v != nil { + _o20 = _v.(bool) + } + if _v := _f.Get(21); _v != nil { + _o21 = _v.(bool) + } + if _v := _f.Get(22); _v != nil { + _o22 = _v.(bool) + } + } + defer func() { + if _c.Unwinding() { + _f.Set(0, _o0) + _f.Set(1, _o1) + _f.Set(2, _o2) + _f.Set(3, _o3) + _f.Set(4, _o4) + _f.Set(5, _o5) + _f.Set(6, _o6) + _f.Set(7, _o7) + _f.Set(8, _o8) + _f.Set(9, _o9) + _f.Set(10, _o10) + _f.Set(11, _o11) + _f.Set(12, _o12) + _f.Set(13, _o13) + _f.Set(14, _o14) + _f.Set(15, _o15) + _f.Set(16, _o16) + _f.Set(17, _o17) + _f.Set(18, _o18) + _f.Set(19, _o19) + _f.Set(20, _o20) + _f.Set(21, _o21) + _f.Set(22, _o22) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 10: + switch { + case _f.IP < 2: + _o0 = 0 + _f.IP = 2 + fallthrough + case _f.IP < 3: + _o1 = 1 + _f.IP = 3 + fallthrough + case _f.IP < 4: + _o2 = 2 + _f.IP = 4 + fallthrough + case _f.IP < 5: + _o3 = 3 + _f.IP = 5 + fallthrough + case _f.IP < 6: + _o4 = 4 + _f.IP = 6 + fallthrough + case _f.IP < 7: + _o5 = 5 + _f.IP = 7 + fallthrough + case _f.IP < 8: + _o6 = 6 + _f.IP = 8 + fallthrough + case _f.IP < 9: + _o7 = 7 + _f.IP = 9 + fallthrough + case _f.IP < 10: + _o8 = 8 + } + _f.IP = 10 + fallthrough + case _f.IP < 36: + switch { + case _f.IP < 11: + + _o9 = 0 + _f.IP = 11 + fallthrough + case _f.IP < 36: + _l0: + for ; ; _o9, _f.IP = _o9+1, 11 { + switch { + case _f.IP < 13: + switch { + case _f.IP < 12: + _o10 = !(_o9 < 10) + _f.IP = 12 + fallthrough + case _f.IP < 13: + if _o10 { + break _l0 + } + } + _f.IP = 13 + fallthrough + case _f.IP < 14: + _f.IP = 14 + fallthrough + case _f.IP < 35: + switch { + case _f.IP < 15: + _o12 = _o9 + _f.IP = 15 + fallthrough + case _f.IP < 35: + switch { + default: + switch { + case _f.IP < 16: + _o13 = _o12 == + 0 + _f.IP = 16 + fallthrough + case _f.IP < 35: + if _o13 { + _o11 = int(_o0) + } else { + _o14 = _o12 == + 1 + if _o14 { + _o11 = int(_o1) + } else { + _o15 = _o12 == + 2 + if _o15 { + _o11 = int(_o2) + } else { + _o16 = _o12 == + 3 + if _o16 { + _o11 = int(_o3) + } else { + _o17 = _o12 == + 4 + if _o17 { + _o11 = int(_o4) + } else { + _o18 = _o12 == + 5 + if _o18 { + _o11 = int(_o5) + } else { + _o19 = _o12 == + 6 + if _o19 { + _o11 = int(_o6) + } else { + _o20 = _o12 == + 7 + if _o20 { + _o11 = int(_o7) + } else { + _o21 = _o12 == + 8 + if _o21 { + _o11 = int(_o8) + } else { + _o22 = _o12 == + 9 + if _o22 { + _o11 = int(_o9) + } + } + } + } + } + } + } + } + } + } + } + } + } + _f.IP = 35 + fallthrough + case _f.IP < 36: + + coroutine.Yield[int, any](_o11) + } + } + } + } +} + func Select(n int) { _c := coroutine.LoadContext[int, any]() _f, _fp := _c.Push() @@ -2755,6 +3549,7 @@ func init() { serde.RegisterType[*[]uint64]() serde.RegisterType[*bool]() serde.RegisterType[*byte]() + serde.RegisterType[*int]() serde.RegisterType[*int32]() serde.RegisterType[*int64]() serde.RegisterType[*string]() diff --git a/compiler/testdata/coroutine_functypes.go b/compiler/testdata/coroutine_functypes.go index fb76211..61b868a 100644 --- a/compiler/testdata/coroutine_functypes.go +++ b/compiler/testdata/coroutine_functypes.go @@ -15,12 +15,38 @@ func init() { _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.LoopBreakAndContinue") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.NestedLoops") _types.RegisterFunc[func(int, func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.Range") - _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.Range10Closure") + _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.Range10ClosureCapturingPointers") + _types.RegisterClosure[func() bool, struct { + _ uintptr + p **int + q **int + }]("github.com/stealthrocket/coroutine/compiler/testdata.Range10ClosureCapturingPointers.func2") + _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.Range10ClosureCapturingValues") _types.RegisterClosure[func() bool, struct { _ uintptr i *int n *int - }]("github.com/stealthrocket/coroutine/compiler/testdata.Range10Closure.func2") + }]("github.com/stealthrocket/coroutine/compiler/testdata.Range10ClosureCapturingValues.func2") + _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.Range10ClosureHeterogenousCapture") + _types.RegisterClosure[func() int, struct { + _ uintptr + i *uintptr + }]("github.com/stealthrocket/coroutine/compiler/testdata.Range10ClosureHeterogenousCapture.func2") + _types.RegisterClosure[func() bool, struct { + _ uintptr + n *int + a *int8 + b *int16 + c *int32 + d *int64 + e *uint8 + f *uint16 + g *uint32 + h *uint64 + i *uintptr + j *func() int + }]("github.com/stealthrocket/coroutine/compiler/testdata.Range10ClosureHeterogenousCapture.func3") + _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.Range10Heterogenous") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeArrayIndexValueGenerator") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeOverMaps") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeSliceIndexGenerator") diff --git a/compiler/types.go b/compiler/types.go index a5910e8..d2e5c1a 100644 --- a/compiler/types.go +++ b/compiler/types.go @@ -5,6 +5,7 @@ import ( "go/ast" "go/token" "go/types" + "slices" "strconv" ) @@ -96,3 +97,23 @@ func newFields(tuple *types.Tuple) []*ast.Field { } return fields } + +func funcTypeWithNamedResults(t *ast.FuncType) *ast.FuncType { + if t.Results == nil { + return t + } + funcType := *t + funcType.Results = &ast.FieldList{ + List: slices.Clone(t.Results.List), + } + for i, f := range t.Results.List { + if len(f.Names) == 0 { + field := *f + field.Names = []*ast.Ident{ + ast.NewIdent("_"), + } + funcType.Results.List[i] = &field + } + } + return &funcType +} diff --git a/coroutine_durable.go b/coroutine_durable.go index 1f6cf37..1310979 100644 --- a/coroutine_durable.go +++ b/coroutine_durable.go @@ -2,7 +2,9 @@ package coroutine -import "github.com/stealthrocket/coroutine/internal/serde" +import ( + "github.com/stealthrocket/coroutine/internal/serde" +) type serializedCoroutine struct { entry func() diff --git a/internal/serde/reflect.go b/internal/serde/reflect.go index 390e5b5..d1b399a 100644 --- a/internal/serde/reflect.go +++ b/internal/serde/reflect.go @@ -117,6 +117,8 @@ func SerializeAny(s *Serializer, t reflect.Type, p unsafe.Pointer) { SerializeUint16(s, *(*uint16)(p)) case reflect.Uint8: SerializeUint8(s, *(*uint8)(p)) + case reflect.Uintptr: + SerializeUintptr(s, *(*uintptr)(p)) case reflect.Float64: SerializeFloat64(s, *(*float64)(p)) case reflect.Float32: @@ -179,6 +181,8 @@ func DeserializeAny(d *Deserializer, t reflect.Type, p unsafe.Pointer) { DeserializeUint16(d, (*uint16)(p)) case reflect.Uint8: DeserializeUint8(d, (*uint8)(p)) + case reflect.Uintptr: + DeserializeUintptr(d, (*uintptr)(p)) case reflect.Float64: DeserializeFloat64(d, (*float64)(p)) case reflect.Float32: @@ -670,6 +674,16 @@ func DeserializeUint8(d *Deserializer, x *uint8) { d.b = d.b[1:] } +func SerializeUintptr(s *Serializer, x uintptr) { + SerializeUint64(s, uint64(x)) +} + +func DeserializeUintptr(d *Deserializer, x *uintptr) { + u := uint64(0) + DeserializeUint64(d, &u) + *x = uintptr(u) +} + func SerializeFloat32(s *Serializer, x float32) { SerializeUint32(s, math.Float32bits(x)) }