diff --git a/compiler/compile.go b/compiler/compile.go index 2512c8b..1f581d0 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -221,11 +221,14 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr }, }) - colorsByDecl := map[*ast.FuncDecl]*types.Signature{} + colorsByDecl := map[ast.Node]*types.Signature{} for fn, color := range colors { - decl, ok := fn.Syntax().(*ast.FuncDecl) - if !ok { - return fmt.Errorf("unsupported yield function %s (Syntax is %T, not *ast.FuncDecl)", fn, fn.Syntax()) + decl := fn.Syntax() + switch decl.(type) { + case *ast.FuncDecl: + case *ast.FuncLit: + default: + return fmt.Errorf("unsupported yield function %s (Syntax is %T, not *ast.FuncDecl or *ast.FuncLit)", fn, decl) } colorsByDecl[decl] = color } @@ -235,6 +238,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr if !ok { continue } + color, ok := colorsByDecl[decl] if !ok { continue @@ -260,7 +264,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr return err } - functypesFile := generateFunctypes(prog.Package(p.Types)) + functypesFile := generateFunctypes(prog.Package(p.Types), colors) functypesPath := filepath.Join(packageDir, "coroutine_functypes.go") if err := c.writeFile(functypesPath, functypesFile); err != nil { return err diff --git a/compiler/coroutine_test.go b/compiler/coroutine_test.go index b5acfcd..914900e 100644 --- a/compiler/coroutine_test.go +++ b/compiler/coroutine_test.go @@ -105,6 +105,20 @@ func TestCoroutineYield(t *testing.T) { yields: []int{0, 2, 4, 6, 8, 10, 12, 14, 16, 18}, }, + // TODO: desugar function call expressions to enable this test. + // + // { + // name: "range over anonymous function", + // coro: func() { RangeTriple(4) }, + // yields: []int{0, 3, 6, 9}, + // }, + + { + name: "range over anonymous function value", + coro: func() { RangeTripleFuncValue(4) }, + yields: []int{0, 3, 6, 9}, + }, + { name: "select", coro: func() { Select(8) }, diff --git a/compiler/function.go b/compiler/function.go index 5340dbd..497c54c 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -2,6 +2,7 @@ package compiler import ( "cmp" + "fmt" "go/ast" "go/token" "slices" @@ -10,7 +11,7 @@ import ( "golang.org/x/tools/go/ssa" ) -func generateFunctypes(pkg *ssa.Package) *ast.File { +func generateFunctypes(pkg *ssa.Package, colors functionColors) *ast.File { var names = make([]string, 0, len(pkg.Members)) for name := range pkg.Members { names = append(names, name) @@ -18,10 +19,10 @@ func generateFunctypes(pkg *ssa.Package) *ast.File { slices.Sort(names) var init ast.BlockStmt - var path = pkg.Pkg.Path() for _, name := range names { if fn, ok := pkg.Members[name].(*ssa.Function); ok { - generateFunctypesInit(path, &init, fn) + name := pkg.Pkg.Path() + "." + fn.Name() + generateFunctypesInit(pkg, fn, &init, name, colors) } } @@ -49,7 +50,7 @@ func generateFunctypes(pkg *ssa.Package) *ast.File { } } -func generateFunctypesInit(path string, init *ast.BlockStmt, fn *ssa.Function) { +func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockStmt, name string, colors functionColors) { if fn.TypeParams() != nil { return // ignore non-instantiated generic functions } @@ -68,7 +69,7 @@ func generateFunctypesInit(path string, init *ast.BlockStmt, fn *ssa.Function) { Args: []ast.Expr{ &ast.BasicLit{ Kind: token.STRING, - Value: strconv.Quote(path + "." + fn.Name()), + Value: strconv.Quote(name), }, }, }, @@ -79,7 +80,24 @@ func generateFunctypesInit(path string, init *ast.BlockStmt, fn *ssa.Function) { return cmp.Compare(f1.Name(), f2.Name()) }) - for _, anonFunc := range anonFuncs { - generateFunctypesInit(path, init, anonFunc) + for index, anonFunc := range anonFuncs { + _, colored := colors[anonFunc] + if colored { + // Colored functions (those rewritten into coroutines) have a + // deferred anonymous function injected at the beginning to perform + // stack unwinding, which takes the ".func1" name. + index++ + } + name = anonFuncLinkName(name, index) + generateFunctypesInit(pkg, anonFunc, init, name, colors) } } + +// This function computes the name that the linker gives to anonymous functions, +// using the base name of their parent function and appending ".func". +// +// The function works with multiple levels of nesting as each level adds another +// ".func" suffix, with the index being local to the parent scope. +func anonFuncLinkName(base string, index int) string { + return fmt.Sprintf("%s.func%d", base, index+1) +} diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index 2cae21e..71817cb 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -266,6 +266,19 @@ func Double(n int) { coroutine.Yield[int, any](2 * n) } +// func RangeTriple(n int) { +// Range(n, func(i int) { +// coroutine.Yield[int, any](3 * i) +// }) +// } + +func RangeTripleFuncValue(n int) { + f := func(i int) { + coroutine.Yield[int, any](3 * i) + } + Range(n, f) +} + func Select(n int) { select { default: diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index fb45381..29d3dc6 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -1365,6 +1365,35 @@ func Double(n int) { coroutine.Yield[int, any](2 * n) } +func RangeTripleFuncValue(n int) { + _c := coroutine.LoadContext[int, any]() + _f, _fp := _c.Push() + var _o0 func(int) + if _f.IP > 0 { + n = _f.Get(0).(int) + _o0 = _f.Get(1).(func(int)) + } + defer func() { + if _c.Unwinding() { + _f.Set(0, n) + _f.Set(1, _o0) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 2: + _o0 = func(i int) { + coroutine.Yield[int, any](3 * i) + } + _f.IP = 2 + fallthrough + case _f.IP < 3: + Range(n, _o0) + } +} + func Select(n int) { _c := coroutine.LoadContext[int, any]() _f, _fp := _c.Push() diff --git a/compiler/testdata/coroutine_functypes.go b/compiler/testdata/coroutine_functypes.go index ed2760b..439a281 100644 --- a/compiler/testdata/coroutine_functypes.go +++ b/compiler/testdata/coroutine_functypes.go @@ -18,6 +18,8 @@ func init() { _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeArrayIndexValueGenerator") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeOverMaps") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeSliceIndexGenerator") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeTripleFuncValue") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeTripleFuncValue.func2") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.Select") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.Shadowing") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGenerator") diff --git a/compiler/unsupported.go b/compiler/unsupported.go index 739c25a..67ec035 100644 --- a/compiler/unsupported.go +++ b/compiler/unsupported.go @@ -8,18 +8,17 @@ import ( ) // unsupported checks a function for unsupported language features. -func unsupported(decl *ast.FuncDecl, info *types.Info) (err error) { +func unsupported(decl ast.Node, info *types.Info) (err error) { ast.Inspect(decl, func(node ast.Node) bool { switch nn := node.(type) { case ast.Expr: switch nn.(type) { case *ast.FuncLit: - err = fmt.Errorf("not implemented: func literals") - } - if countFunctionCalls(nn, info) > 1 { - err = fmt.Errorf("not implemented: multiple function calls in an expression") + default: + if countFunctionCalls(nn, info) > 1 { + err = fmt.Errorf("not implemented: multiple function calls in an expression") + } } - case ast.Stmt: switch n := nn.(type) { // Not yet supported: