Skip to content

Commit

Permalink
Automatically import the coroutine package when used
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Sep 25, 2023
1 parent f4ad33e commit dbb0561
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ func Compile(path string, options ...Option) error {
type Option func(*compiler)

type compiler struct {
coroutinePkg *packages.Package

fset *token.FileSet
}

Expand Down Expand Up @@ -116,18 +118,17 @@ func (c *compiler) compile(path string) error {
cg := vta.CallGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog))

log.Printf("finding generic yield instantiations")
var coroutinePkg *packages.Package
packages.Visit(pkgs, func(p *packages.Package) bool {
if p.PkgPath == coroutinePackage {
coroutinePkg = p
c.coroutinePkg = p
}
return coroutinePkg == nil
return c.coroutinePkg == nil
}, nil)
if coroutinePkg == nil {
if c.coroutinePkg == nil {
log.Printf("%s not imported by the module. Nothing to do", coroutinePackage)
return nil
}
yieldFunc := prog.FuncValue(coroutinePkg.Types.Scope().Lookup("Yield").(*types.Func))
yieldFunc := prog.FuncValue(c.coroutinePkg.Types.Scope().Lookup("Yield").(*types.Func))
yieldInstances := functionColors{}
for fn := range ssautil.AllFunctions(prog) {
if fn.Origin() == yieldFunc {
Expand Down Expand Up @@ -270,7 +271,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors) er
return err
}

scope := &scope{colors: colorsByFunc}
scope := &scope{compiler: c, colors: colorsByFunc}
gen.Decls = append(gen.Decls, scope.compileFuncDecl(p, decl, color))
}
}
Expand Down Expand Up @@ -348,6 +349,8 @@ func addImports(p *packages.Package, gen *ast.File) *ast.File {
}

type scope struct {
compiler *compiler

colors map[ast.Node]*types.Signature
// Index used to generate unique object identifiers within the scope of a
// function.
Expand Down Expand Up @@ -473,6 +476,9 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
yieldTypeExpr[0] = typeExpr(p, color.Params().At(0).Type())
yieldTypeExpr[1] = typeExpr(p, color.Results().At(0).Type())

coroutineIdent := ast.NewIdent("coroutine")
p.TypesInfo.Uses[coroutineIdent] = types.NewPkgName(token.NoPos, p.Types, "coroutine", scope.compiler.coroutinePkg.Types)

// _c := coroutine.LoadContext[R, S]()
gen.List = append(gen.List, &ast.AssignStmt{
Lhs: []ast.Expr{ctx},
Expand All @@ -481,7 +487,7 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
&ast.CallExpr{
Fun: &ast.IndexListExpr{
X: &ast.SelectorExpr{
X: ast.NewIdent("coroutine"),
X: coroutineIdent,
Sel: ast.NewIdent("LoadContext"),
},
Indices: yieldTypeExpr,
Expand Down

0 comments on commit dbb0561

Please sign in to comment.