diff --git a/compiler/Makefile b/compiler/Makefile index 9fb85e8..d557f01 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -1,6 +1,8 @@ GO ?= go -TARGET = testdata/coroutine_durable.go testdata/coroutine_functypes.go +TARGET = \ + testdata/coroutine_generated.go \ + testdata/testdata_generated.go test: clean generate $(GO) test ./... diff --git a/compiler/cmd/coroc/main.go b/compiler/cmd/coroc/main.go index eb4fd65..2f9bf40 100644 --- a/compiler/cmd/coroc/main.go +++ b/compiler/cmd/coroc/main.go @@ -17,8 +17,6 @@ USAGE: OPTIONS: --output Name of the Go file to generate in each package - --tags Build tags to set on generated files - -h, --help Show this help information ` @@ -30,8 +28,6 @@ func main() { } func run() error { - buildTags := flag.String("tags", "", "") - flag.Usage = func() { println(usage[1:]) } flag.Parse() @@ -49,10 +45,5 @@ func run() error { } } - var options []compiler.Option - if *buildTags != "" { - options = append(options, compiler.WithBuildTags(*buildTags)) - } - - return compiler.Compile(path, options...) + return compiler.Compile(path) } diff --git a/compiler/comment.go b/compiler/comment.go index 9806b92..6d88053 100644 --- a/compiler/comment.go +++ b/compiler/comment.go @@ -1,6 +1,8 @@ package compiler -import "go/ast" +import ( + "go/ast" +) func appendCommentGroup(comments []*ast.Comment, group *ast.CommentGroup) []*ast.Comment { if group != nil && len(group.List) > 0 { @@ -19,3 +21,12 @@ func appendComment(comments []*ast.Comment, text string) []*ast.Comment { Text: text, }) } + +func commentGroupsOf(file *ast.File) []*ast.CommentGroup { + groups := make([]*ast.CommentGroup, 0, 1+len(file.Comments)) + groups = append(groups, file.Comments...) + if file.Doc != nil { + groups = append(groups, file.Doc) + } + return groups +} diff --git a/compiler/compile.go b/compiler/compile.go index dae8783..4ecfcb9 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -3,6 +3,7 @@ package compiler import ( "fmt" "go/ast" + "go/build/constraint" "go/format" "go/token" "go/types" @@ -44,15 +45,7 @@ func Compile(path string, options ...Option) error { // Option configures the compiler. type Option func(*compiler) -// WithBuildTags instructs the compiler to attach the specified build -// tags to generated files. -func WithBuildTags(buildTags string) Option { - return func(c *compiler) { c.buildTags = buildTags } -} - type compiler struct { - buildTags string - fset *token.FileSet } @@ -192,22 +185,29 @@ func (c *compiler) compile(path string) error { return nil } -func (c *compiler) writeFile(path string, file *ast.File) error { - f, err := os.Create(path) +func (c *compiler) writeFile(path string, file *ast.File, changeBuildTags func(constraint.Expr) constraint.Expr) error { + buildTags, err := parseBuildTags(file) if err != nil { return err } - defer f.Close() + buildTags = changeBuildTags(buildTags) + stripBuildTagsOf(file, path) + // Comments are awkward to attach to the tree (they rely on token.Pos, which // is coupled to a token.FileSet). Instead, just write out the raw strings. var b strings.Builder - b.WriteString(`// Code generated by coroc. DO NOT EDIT`) - b.WriteString("\n\n") - if c.buildTags != "" { + if buildTags != nil { b.WriteString(`//go:build `) - b.WriteString(c.buildTags) + b.WriteString(buildTags.String()) b.WriteString("\n\n") } + + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + if _, err := f.WriteString(b.String()); err != nil { return err } @@ -233,53 +233,60 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors) er colorsByFunc[decl] = color } + buildTag := &constraint.TagExpr{ + Tag: "durable", + } + for i, f := range p.Syntax { + if err := c.writeFile(p.GoFiles[i], f, func(expr constraint.Expr) constraint.Expr { + return withoutBuildTag(expr, buildTag) + }); err != nil { + return err + } + // Generate the coroutine AST. gen := &ast.File{ Name: ast.NewIdent(p.Name), } for _, anydecl := range f.Decls { - decl, ok := anydecl.(*ast.FuncDecl) - if !ok { - continue - } - color, ok := colorsByFunc[decl] - if !ok { - gen.Decls = append(gen.Decls, decl) - continue - } - // Reject certain language features for now. - if err := unsupported(decl, p.TypesInfo); err != nil { - return err - } + switch decl := anydecl.(type) { + case *ast.GenDecl: + // Imports get re-added by addImports below, so no need to carry + // them from declarations in the input file. + if decl.Tok != token.IMPORT { + gen.Decls = append(gen.Decls, decl) + continue + } - scope := &scope{colors: colorsByFunc} - // If the function has a single expression it does not contain a - // deferred closure; it won't be added to the list of colored - // functions so generateFunctypes does not mistakenly increment the - // local symbol counter when generating closure names. - gen.Decls = append(gen.Decls, scope.compileFuncDecl(p, decl, color)) - } + case *ast.FuncDecl: + color, ok := colorsByFunc[decl] + if !ok { + gen.Decls = append(gen.Decls, decl) + continue + } + // Reject certain language features for now. + if err := unsupported(decl, p.TypesInfo); err != nil { + return err + } - if len(gen.Decls) == 0 { - continue + scope := &scope{colors: colorsByFunc} + gen.Decls = append(gen.Decls, scope.compileFuncDecl(p, decl, color)) + } } clearPos(gen) - generateFunctypes(p, gen, colorsByFunc) // Find all the required imports for this file. gen = addImports(p, gen) - outputPath, _ := strings.CutSuffix(p.GoFiles[i], ".go") - if c.buildTags != "" { - outputPath += "_" + strings.ReplaceAll(c.buildTags, ",", "_") + ".go" - } else { - outputPath += "_generated.go" - } - if err := c.writeFile(outputPath, gen); err != nil { + outputPath := strings.TrimSuffix(p.GoFiles[i], ".go") + outputPath += "_generated.go" + + if err := c.writeFile(outputPath, gen, func(expr constraint.Expr) constraint.Expr { + return withBuildTag(expr, buildTag) + }); err != nil { return err } } diff --git a/compiler/constraint.go b/compiler/constraint.go new file mode 100644 index 0000000..38551e3 --- /dev/null +++ b/compiler/constraint.go @@ -0,0 +1,79 @@ +package compiler + +import ( + "go/ast" + "go/build/constraint" + "reflect" + "slices" +) + +func containsExpr(expr, contains constraint.Expr) bool { + switch x := expr.(type) { + case *constraint.AndExpr: + return containsExpr(x.X, contains) || containsExpr(x.Y, contains) + case *constraint.OrExpr: + return containsExpr(x.X, contains) && containsExpr(x.Y, contains) + default: + return reflect.DeepEqual(expr, contains) + } +} + +func withBuildTag(expr constraint.Expr, buildTag *constraint.TagExpr) constraint.Expr { + if buildTag == nil || containsExpr(expr, buildTag) { + return expr + } else if expr == nil { + return buildTag + } else { + return &constraint.AndExpr{X: expr, Y: buildTag} + } +} + +func withoutBuildTag(expr constraint.Expr, buildTag *constraint.TagExpr) constraint.Expr { + notBuildTag := &constraint.NotExpr{X: buildTag} + if buildTag == nil || containsExpr(expr, notBuildTag) { + return expr + } else if expr == nil { + return notBuildTag + } else { + return &constraint.AndExpr{X: expr, Y: notBuildTag} + } +} + +func parseBuildTags(file *ast.File) (constraint.Expr, error) { + groups := commentGroupsOf(file) + + for _, group := range groups { + for _, c := range group.List { + if constraint.IsGoBuild(c.Text) { + return constraint.Parse(c.Text) + } + } + } + + var plusBuildLines constraint.Expr + for _, group := range groups { + for _, c := range group.List { + if constraint.IsPlusBuild(c.Text) { + x, err := constraint.Parse(c.Text) + if err != nil { + return nil, err + } + if plusBuildLines == nil { + plusBuildLines = x + } else { + plusBuildLines = &constraint.AndExpr{X: plusBuildLines, Y: x} + } + } + } + } + + return plusBuildLines, nil +} + +func stripBuildTagsOf(file *ast.File, path string) { + for _, group := range commentGroupsOf(file) { + group.List = slices.DeleteFunc(group.List, func(c *ast.Comment) bool { + return constraint.IsGoBuild(c.Text) || constraint.IsPlusBuild(c.Text) + }) + } +} diff --git a/compiler/function.go b/compiler/function.go index bb99e4b..a566224 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -240,14 +240,16 @@ func generateFunctypes(p *packages.Package, f *ast.File, colors map[ast.Node]*ty }) } - astutil.AddNamedImport(nil, f, "_types", "github.com/stealthrocket/coroutine/types") - - f.Decls = append(f.Decls, - &ast.FuncDecl{ - Name: ast.NewIdent("init"), - Type: &ast.FuncType{Params: new(ast.FieldList)}, - Body: init, - }) + if len(init.List) > 0 { + astutil.AddNamedImport(nil, f, "_types", "github.com/stealthrocket/coroutine/types") + + f.Decls = append(f.Decls, + &ast.FuncDecl{ + Name: ast.NewIdent("init"), + Type: &ast.FuncType{Params: new(ast.FieldList)}, + Body: init, + }) + } } // This function computes the name that the linker gives to anonymous functions, diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index e5337fc..dfbbded 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -9,7 +9,7 @@ import ( "github.com/stealthrocket/coroutine" ) -//go:generate coroc --tags durable +//go:generate coroc func SomeFunctionThatShouldExistInTheCompiledFile() { } diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_generated.go similarity index 99% rename from compiler/testdata/coroutine_durable.go rename to compiler/testdata/coroutine_generated.go index b2f8496..570d09d 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_generated.go @@ -1,22 +1,22 @@ -// Code generated by coroc. DO NOT EDIT - //go:build durable package testdata import ( coroutine "github.com/stealthrocket/coroutine" - unsafe "unsafe" time "time" + unsafe "unsafe" ) import _types "github.com/stealthrocket/coroutine/types" func SomeFunctionThatShouldExistInTheCompiledFile() { } + //go:noinline func Identity(n int) { coroutine.Yield[int, any](n) } + //go:noinline func SquareGenerator(n int) { _c := coroutine.LoadContext[int, any]() @@ -55,6 +55,7 @@ func SquareGenerator(n int) { } } } + //go:noinline func SquareGeneratorTwice(n int) { _c := coroutine.LoadContext[int, any]() @@ -88,6 +89,7 @@ func SquareGeneratorTwice(n int) { SquareGenerator(_f0.X0) } } + //go:noinline func SquareGeneratorTwiceLoop(n int) { _c := coroutine.LoadContext[int, any]() @@ -126,6 +128,7 @@ func SquareGeneratorTwiceLoop(n int) { } } } + //go:noinline func EvenSquareGenerator(n int) { _c := coroutine.LoadContext[int, any]() @@ -176,6 +179,7 @@ func EvenSquareGenerator(n int) { } } } + //go:noinline func NestedLoops(n int) { _c := coroutine.LoadContext[int, any]() @@ -238,6 +242,7 @@ func NestedLoops(n int) { } } } + //go:noinline func FizzBuzzIfGenerator(n int) { _c := coroutine.LoadContext[int, any]() @@ -298,6 +303,7 @@ func FizzBuzzIfGenerator(n int) { } } } + //go:noinline func FizzBuzzSwitchGenerator(n int) { _c := coroutine.LoadContext[int, any]() @@ -381,6 +387,7 @@ func FizzBuzzSwitchGenerator(n int) { } } } + //go:noinline func Shadowing(_ int) { _c := coroutine.LoadContext[int, any]() @@ -708,6 +715,7 @@ func Shadowing(_ int) { coroutine.Yield[int, any](_f0.X22) } } + //go:noinline func RangeSliceIndexGenerator(_ int) { _c := coroutine.LoadContext[int, any]() @@ -753,6 +761,7 @@ func RangeSliceIndexGenerator(_ int) { } } } + //go:noinline func RangeArrayIndexValueGenerator(_ int) { _c := coroutine.LoadContext[int, any]() @@ -812,6 +821,7 @@ func RangeArrayIndexValueGenerator(_ int) { } } } + //go:noinline func TypeSwitchingGenerator(_ int) { _c := coroutine.LoadContext[int, any]() @@ -889,6 +899,7 @@ func TypeSwitchingGenerator(_ int) { } } } + //go:noinline func LoopBreakAndContinue(_ int) { _c := coroutine.LoadContext[int, any]() @@ -1036,6 +1047,7 @@ func LoopBreakAndContinue(_ int) { } } } + //go:noinline func RangeOverMaps(n int) { _c := coroutine.LoadContext[int, any]() @@ -1371,6 +1383,7 @@ func RangeOverMaps(n int) { } } } + //go:noinline func Range(n int, do func(int)) { _c := coroutine.LoadContext[int, any]() @@ -1412,16 +1425,19 @@ func Range(n int, do func(int)) { } } } + //go:noinline func Double(n int) { coroutine.Yield[int, any](2 * n) } + //go:noinline func RangeTriple(n int) { Range(n, func(i int) { coroutine.Yield[int, any](3 * i) }) } + //go:noinline func RangeTripleFuncValue(n int) { _c := coroutine.LoadContext[int, any]() @@ -1460,6 +1476,7 @@ func RangeTripleFuncValue(n int) { Range(_f0.X0, _f0.X1) } } + //go:noinline func RangeReverseClosureCaptureByValue(n int) { _c := coroutine.LoadContext[int, any]() @@ -1514,6 +1531,7 @@ func RangeReverseClosureCaptureByValue(n int) { } } } + //go:noinline func Range10ClosureCapturingValues() { _c := coroutine.LoadContext[int, any]() @@ -1625,6 +1643,7 @@ func Range10ClosureCapturingValues() { } } } + //go:noinline func Range10ClosureCapturingPointers() { _c := coroutine.LoadContext[int, any]() @@ -1746,6 +1765,7 @@ func Range10ClosureCapturingPointers() { } } } + //go:noinline func Range10ClosureHeterogenousCapture() { _c := coroutine.LoadContext[int, any]() @@ -1961,6 +1981,7 @@ func Range10ClosureHeterogenousCapture() { } } } + //go:noinline func Range10Heterogenous() { _c := coroutine.LoadContext[int, any]() @@ -2074,6 +2095,7 @@ func Range10Heterogenous() { } } } + //go:noinline func Select(n int) { _c := coroutine.LoadContext[int, any]() @@ -2365,6 +2387,7 @@ func Select(n int) { } } } + //go:noinline func YieldingExpressionDesugaring() { _c := coroutine.LoadContext[int, any]() @@ -2790,6 +2813,7 @@ func YieldingExpressionDesugaring() { } } } + //go:noinline func a(v int) (_ int) { _c := coroutine.LoadContext[int, any]() @@ -2824,6 +2848,7 @@ func a(v int) (_ int) { } return } + //go:noinline func b(v int) (_ int) { _c := coroutine.LoadContext[int, any]() @@ -2858,6 +2883,7 @@ func b(v int) (_ int) { } return } + //go:noinline func YieldingDurations() { _c := coroutine.LoadContext[int, any]() @@ -2970,6 +2996,7 @@ func YieldingDurations() { } } } + //go:noinline func YieldAndDeferAssign(assign *int, yield, value int) { _c := coroutine.LoadContext[int, any]() @@ -3017,6 +3044,7 @@ func YieldAndDeferAssign(assign *int, yield, value int) { coroutine.Yield[int, any](_f0.X1) } } + //go:noinline func RangeYieldAndDeferAssign(n int) { _c := coroutine.LoadContext[int, any]() diff --git a/compiler/testdata/http/main.go b/compiler/testdata/http/main.go index 2987b88..b6996a9 100644 --- a/compiler/testdata/http/main.go +++ b/compiler/testdata/http/main.go @@ -1,3 +1,5 @@ +//go:build !durable + package main import ( diff --git a/compiler/testdata/http/main_generated.go b/compiler/testdata/http/main_generated.go new file mode 100644 index 0000000..2d5b1c7 --- /dev/null +++ b/compiler/testdata/http/main_generated.go @@ -0,0 +1,106 @@ +//go:build durable + +package main + +import ( + http "net/http" + coroutine "github.com/stealthrocket/coroutine" + fmt "fmt" +) +import _types "github.com/stealthrocket/coroutine/types" + +type yieldingRoundTripper struct { +} +//go:noinline +func RoundTrip(req *http.Request) (_ *http.Response, _ error) { + _c := coroutine.LoadContext[*http.Request, *http.Response]() + _f, _fp := _c.Push() + var _f0 *struct { + X0 *http.Request + X1 *http.Response + } + if _f.IP == 0 { + _f0 = &struct { + X0 *http.Request + X1 *http.Response + }{X0: req} + } else { + _f0 = _f.Get(0).(*struct { + X0 *http.Request + X1 *http.Response + }) + } + defer func() { + if _c.Unwinding() { + _f.Set(0, _f0) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 2: + _f0.X1 = coroutine.Yield[*http.Request, *http.Response](_f0.X0) + _f.IP = 2 + fallthrough + case _f.IP < 3: + return _f0.X1, nil + } + return +} +//go:noinline +func work() { + _c := coroutine.LoadContext[*http.Request, *http.Response]() + _f, _fp := _c.Push() + var _f0 *struct { + X0 *http.Response + X1 error + } + if _f.IP == 0 { + _f0 = &struct { + X0 *http.Response + X1 error + }{} + } else { + _f0 = _f.Get(0).(*struct { + X0 *http.Response + X1 error + }) + } + defer func() { + if _c.Unwinding() { + _f.Set(0, _f0) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 2: + _f0.X0, _f0.X1 = http.Get("http://example.com") + _f.IP = 2 + fallthrough + case _f.IP < 3: + if _f0.X1 != nil { + panic(_f0.X1) + } + _f.IP = 3 + fallthrough + case _f.IP < 4: + fmt.Println(_f0.X0.StatusCode) + } +} +func main() { + http.DefaultTransport = &yieldingRoundTripper{} + c := coroutine.New[*http.Request, *http.Response](work) + for c.Next() { + req := c.Recv() + fmt.Println("Requesting", req.URL.String()) + c.Send(&http.Response{StatusCode: 200}) + } +} +func init() { + _types.RegisterFunc[func(req *http.Request) (_ *http.Response, _ error)]("github.com/stealthrocket/coroutine/compiler/testdata/http.RoundTrip") + _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata/http.main") + _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata/http.work") +} diff --git a/compiler/testdata/testdata.go b/compiler/testdata/testdata.go index f718312..63dd7c4 100644 --- a/compiler/testdata/testdata.go +++ b/compiler/testdata/testdata.go @@ -1,3 +1,5 @@ +//go:build !durable + package testdata const ( diff --git a/compiler/testdata/testdata_generated.go b/compiler/testdata/testdata_generated.go new file mode 100644 index 0000000..0877781 --- /dev/null +++ b/compiler/testdata/testdata_generated.go @@ -0,0 +1,9 @@ +//go:build durable + +package testdata + +const ( + Fizz = -1 + Buzz = -2 + FizzBuzz = -3 +)