Skip to content

Commit

Permalink
Merge branch 'main' into decompose-expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Sep 20, 2023
2 parents 1335bb3 + 67642d0 commit 69b4e1f
Show file tree
Hide file tree
Showing 10 changed files with 1,116 additions and 76 deletions.
103 changes: 78 additions & 25 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"strconv"
"strings"

"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/callgraph/cha"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/ssa"
Expand Down Expand Up @@ -232,24 +233,26 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr
}
colorsByDecl[decl] = color
}

for _, f := range p.Syntax {
for _, anydecl := range f.Decls {
decl, ok := anydecl.(*ast.FuncDecl)
if !ok {
continue
}

color, ok := colorsByDecl[decl]
if !ok {
continue
}

// Reject certain language features for now.
if err := unsupported(decl, p.TypesInfo); err != nil {
return err
}

gen.Decls = append(gen.Decls, c.compileFunction(p, decl, color))
scope := &scope{
colors: colorsByDecl,
objectIdent: 0,
}
gen.Decls = append(gen.Decls, scope.compileFuncDecl(p, decl, color))
}
}

Expand All @@ -273,18 +276,44 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr
return nil
}

func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color *types.Signature) *ast.FuncDecl {
log.Printf("compiling function %s %s", p.Name, fn.Name)
type scope struct {
colors map[ast.Node]*types.Signature
// Index used to generate unique object identifiers within the scope of a
// function.
//
// The field is reset to zero after compiling function declarations because
// we don't need globally unique identifiers for local variables.
objectIdent int
}

func (scope *scope) newObjectIdent() *ast.Ident {
ident := scope.objectIdent
scope.objectIdent++
return ast.NewIdent(fmt.Sprintf("_o%d", ident))
}

func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color *types.Signature) *ast.FuncDecl {
log.Printf("compiling function %s %s", p.Name, fn.Name)
// Generate the coroutine function. At this stage, use the same name
// as the source function (and require that the caller use build tags
// to disambiguate function calls).
gen := &ast.FuncDecl{
return &ast.FuncDecl{
Name: fn.Name,
Type: fn.Type,
Body: &ast.BlockStmt{},
Type: funcTypeWithNamedResults(fn.Type),
Body: scope.compileFuncBody(p, fn.Type, fn.Body, color),
}
}

func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *types.Signature) *ast.FuncLit {
log.Printf("compiling function literal %s", p.Name)
return &ast.FuncLit{
Type: funcTypeWithNamedResults(fn.Type),
Body: scope.compileFuncBody(p, fn.Type, fn.Body, color),
}
}

func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, color *types.Signature) *ast.BlockStmt {
gen := new(ast.BlockStmt)
ctx := ast.NewIdent("_c")
frame := ast.NewIdent("_f")
fp := ast.NewIdent("_fp")
Expand All @@ -294,7 +323,7 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color
yieldTypeExpr[1] = typeExpr(color.Results().At(0).Type())

// _c := coroutine.LoadContext[R, S]()
gen.Body.List = append(gen.Body.List, &ast.AssignStmt{
gen.List = append(gen.List, &ast.AssignStmt{
Lhs: []ast.Expr{ctx},
Tok: token.DEFINE,
Rhs: []ast.Expr{
Expand All @@ -311,7 +340,7 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color
})

// _f, _fp := _c.Push()
gen.Body.List = append(gen.Body.List, &ast.AssignStmt{
gen.List = append(gen.List, &ast.AssignStmt{
Lhs: []ast.Expr{frame, fp},
Tok: token.DEFINE,
Rhs: []ast.Expr{
Expand All @@ -321,8 +350,22 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color
},
})

body = astutil.Apply(body,
func(cursor *astutil.Cursor) bool {
switch n := cursor.Node().(type) {
case *ast.FuncLit:
color, ok := scope.colors[n]
if ok {
cursor.Replace(scope.compileFuncLit(p, n, color))
}
}
return true
},
nil,
).(*ast.BlockStmt)

// Desugar statements in the tree.
fn.Body = desugar(fn.Body, p.TypesInfo).(*ast.BlockStmt)
body = desugar(body, p.TypesInfo).(*ast.BlockStmt)

// Handle declarations.
//
Expand All @@ -337,17 +380,17 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color
// declarations to the function prologue. We downgrade inline var decls and
// assignments that use := to assignments that use =. Constant decls are
// hoisted and also have their value assigned in the function prologue.
decls := extractDecls(fn, p.TypesInfo)
renameObjects(fn, p.TypesInfo, decls)
decls := extractDecls(body, p.TypesInfo)
renameObjects(body, p.TypesInfo, decls, scope)
for _, decl := range decls {
gen.Body.List = append(gen.Body.List, &ast.DeclStmt{Decl: decl})
gen.List = append(gen.List, &ast.DeclStmt{Decl: decl})
}
removeDecls(fn)
removeDecls(body)

// Collect params/results/variables that need to be saved/restored.
var saveAndRestoreNames []*ast.Ident
var saveAndRestoreTypes []types.Type
scanFuncTypeIdentifiers(fn.Type, func(name *ast.Ident) {
scanFuncTypeIdentifiers(typ, func(name *ast.Ident) {
saveAndRestoreNames = append(saveAndRestoreNames, name)
saveAndRestoreTypes = append(saveAndRestoreTypes, p.TypesInfo.TypeOf(name))
})
Expand Down Expand Up @@ -398,7 +441,7 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color
},
)
}
gen.Body.List = append(gen.Body.List, &ast.IfStmt{
gen.List = append(gen.List, &ast.IfStmt{
Cond: &ast.BinaryExpr{
X: &ast.SelectorExpr{X: ast.NewIdent("_f"), Sel: ast.NewIdent("IP")},
Op: token.GTR, /* > */
Expand All @@ -419,7 +462,7 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color
},
})
}
gen.Body.List = append(gen.Body.List, &ast.DeferStmt{
gen.List = append(gen.List, &ast.DeferStmt{
Call: &ast.CallExpr{
Fun: &ast.FuncLit{
Type: &ast.FuncType{},
Expand Down Expand Up @@ -447,11 +490,21 @@ func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color
},
})

spans := trackDispatchSpans(fn.Body)

compiledBody := compileDispatch(fn.Body, spans).(*ast.BlockStmt)

gen.Body.List = append(gen.Body.List, compiledBody.List...)

spans := trackDispatchSpans(body)
compiledBody := compileDispatch(body, spans).(*ast.BlockStmt)
gen.List = append(gen.List, compiledBody.List...)

// If the function returns one or more values, it must end with a return statement;
// we inject it if the function body does not already has one.
if typ.Results != nil && len(typ.Results.List) > 0 {
needsReturn := len(gen.List) == 0
if !needsReturn {
_, endsWithReturn := gen.List[len(gen.List)-1].(*ast.ReturnStmt)
needsReturn = !endsWithReturn
}
if needsReturn {
gen.List = append(gen.List, &ast.ReturnStmt{})
}
}
return gen
}
22 changes: 20 additions & 2 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,26 @@ func TestCoroutineYield(t *testing.T) {
},

{
name: "range over closure",
coro: func() { Range10Closure() },
name: "range over closure capturing values",
coro: Range10ClosureCapturingValues,
yields: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
},

{
name: "range over closure capturing pointers",
coro: Range10ClosureCapturingPointers,
yields: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
},

{
name: "range over closure capturing heterogenous values",
coro: Range10ClosureHeterogenousCapture,
yields: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
},

{
name: "range with heterogenous values",
coro: Range10Heterogenous,
yields: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
},

Expand Down
20 changes: 11 additions & 9 deletions compiler/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"go/ast"
"go/token"
"go/types"
"strconv"

"golang.org/x/tools/go/ast/astutil"
)
Expand All @@ -21,9 +20,13 @@ import (
// Note that declarations are extracted from all nested scopes within the
// function body, so there may be duplicate identifiers. Identifiers can be
// disambiguated using (*types.Info).ObjectOf(ident).
func extractDecls(fn *ast.FuncDecl, info *types.Info) (decls []*ast.GenDecl) {
ast.Inspect(fn.Body, func(node ast.Node) bool {
func extractDecls(tree ast.Node, info *types.Info) (decls []*ast.GenDecl) {
ast.Inspect(tree, func(node ast.Node) bool {
switch n := node.(type) {
case *ast.FuncLit:
// Stop when we encounter a function listeral so we don't hoist its
// local variables into the scope of its parent function.
return false
case *ast.GenDecl: // const, var, type
if n.Tok == token.TYPE || n.Tok == token.CONST {
decls = append(decls, n)
Expand Down Expand Up @@ -91,9 +94,8 @@ func extractDecls(fn *ast.FuncDecl, info *types.Info) (decls []*ast.GenDecl) {
// renameObjects renames types, constants and variables declared within
// a function. Each is given a unique name, so that declarations are safe
// to hoist into the function prologue.
func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl) {
func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl, scope *scope) {
// Scan decls to find objects, giving each new object a unique name.
var id int
newNames := map[types.Object]*ast.Ident{}
for _, decl := range decls {
for _, spec := range decl.Specs {
Expand All @@ -102,15 +104,13 @@ func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl) {
if s.Name.Name == "_" {
continue
}
newNames[info.ObjectOf(s.Name)] = ast.NewIdent("_o" + strconv.Itoa(id))
id++
newNames[info.ObjectOf(s.Name)] = scope.newObjectIdent()
case *ast.ValueSpec: // const/var
for _, name := range s.Names {
if name.Name == "_" {
continue
}
newNames[info.ObjectOf(name)] = ast.NewIdent("_o" + strconv.Itoa(id))
id++
newNames[info.ObjectOf(name)] = scope.newObjectIdent()
}
}
}
Expand All @@ -135,6 +135,8 @@ func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl) {
func removeDecls(tree ast.Node) {
astutil.Apply(tree, func(cursor *astutil.Cursor) bool {
switch n := cursor.Node().(type) {
case *ast.FuncLit:
return false
case *ast.AssignStmt:
if n.Tok == token.DEFINE {
if _, ok := cursor.Parent().(*ast.TypeSwitchStmt); ok {
Expand Down
25 changes: 14 additions & 11 deletions compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,19 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt
return cmp.Compare(f1.Name(), f2.Name())
})

for index, anonFunc := range anonFuncs {
_, colored := colors[anonFunc]
if colored {
// Colored functions (those rewritten into coroutines) have a
// deferred anonymous function injected at the beginning to perform
// stack unwinding, which takes the ".func1" name.
index++
}
name = anonFuncLinkName(name, index)
generateFunctypesInit(pkg, anonFunc, init, name, colors)
index := 0
// Colored functions (those rewritten into coroutines) have a
// deferred anonymous function injected at the beginning to perform
// stack unwinding, which takes the ".func1" name.
_, colored := colors[fn]
if colored {
index++
}

for _, anonFunc := range anonFuncs {
index++
anonFuncName := anonFuncLinkName(name, index)
generateFunctypesInit(pkg, anonFunc, init, anonFuncName, colors)
}
}

Expand All @@ -133,5 +136,5 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt
// The function works with multiple levels of nesting as each level adds another
// ".func<index>" suffix, with the index being local to the parent scope.
func anonFuncLinkName(base string, index int) string {
return fmt.Sprintf("%s.func%d", base, index+1)
return fmt.Sprintf("%s.func%d", base, index)
}
Loading

0 comments on commit 69b4e1f

Please sign in to comment.