Skip to content

Commit

Permalink
simplify function stack frame
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <[email protected]>
  • Loading branch information
achille-roussel committed Sep 26, 2023
1 parent 0d727c7 commit f79beaf
Show file tree
Hide file tree
Showing 6 changed files with 1,369 additions and 1,347 deletions.
124 changes: 50 additions & 74 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,6 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body

gen := new(ast.BlockStmt)
ctx := ast.NewIdent("_c")
frame := ast.NewIdent("_f")
fp := ast.NewIdent("_fp")

yieldTypeExpr := make([]ast.Expr, 2)
yieldTypeExpr[0] = typeExpr(p, color.Params().At(0).Type())
Expand All @@ -496,17 +494,6 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
},
})

// _f, _fp := _c.Push()
gen.List = append(gen.List, &ast.AssignStmt{
Lhs: []ast.Expr{frame, fp},
Tok: token.DEFINE,
Rhs: []ast.Expr{
&ast.CallExpr{
Fun: &ast.SelectorExpr{X: ctx, Sel: ast.NewIdent("Push")},
},
},
})

frameName := ast.NewIdent(fmt.Sprintf("_f%d", scope.frameIndex))
scope.frameIndex++

Expand All @@ -528,64 +515,68 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
decls, frameType, frameInit := extractDecls(p, typ, body, recv, defers, p.TypesInfo)
renameObjects(body, p.TypesInfo, decls, frameName, frameType, frameInit, scope)

// var _f{n} F = coroutine.Push[F](&_c.Stack)
gen.List = append(gen.List, &ast.DeclStmt{Decl: &ast.GenDecl{
Tok: token.VAR,
Specs: []ast.Spec{&ast.ValueSpec{
Names: []*ast.Ident{frameName},
Type: &ast.StarExpr{X: frameType},
Values: []ast.Expr{&ast.CallExpr{
Fun: &ast.IndexListExpr{
X: &ast.SelectorExpr{X: coroutineIdent, Sel: ast.NewIdent("Push")},
Indices: []ast.Expr{frameType},
},
Args: []ast.Expr{&ast.UnaryExpr{
Op: token.AND,
X: &ast.SelectorExpr{X: ctx, Sel: ast.NewIdent("Stack")},
}},
}},
}},
}})

for _, decl := range decls {
gen.List = append(gen.List, &ast.DeclStmt{Decl: decl})
}

gen.List = append(gen.List,
&ast.DeclStmt{
Decl: &ast.GenDecl{
Tok: token.VAR,
Specs: []ast.Spec{
&ast.ValueSpec{
Names: []*ast.Ident{frameName},
Type: &ast.StarExpr{X: frameType},
},
},
},
},
)

gen.List = append(gen.List, &ast.IfStmt{
Cond: &ast.BinaryExpr{
X: &ast.SelectorExpr{X: ast.NewIdent("_f"), Sel: ast.NewIdent("IP")},
X: &ast.SelectorExpr{X: frameName, Sel: ast.NewIdent("IP")},
Op: token.EQL, /* == */
Y: &ast.BasicLit{Kind: token.INT, Value: "0"}},
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.AssignStmt{
Tok: token.ASSIGN,
Lhs: []ast.Expr{frameName},
Rhs: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: frameInit}},
}}},
Else: &ast.BlockStmt{List: []ast.Stmt{&ast.AssignStmt{
Lhs: []ast.Expr{frameName},
Tok: token.ASSIGN,
Rhs: []ast.Expr{&ast.TypeAssertExpr{
X: &ast.CallExpr{
Fun: &ast.SelectorExpr{X: frame, Sel: ast.NewIdent("Get")},
Args: []ast.Expr{&ast.BasicLit{Kind: token.INT, Value: "0"}},
},
Type: &ast.StarExpr{X: frameType},
}},
Lhs: []ast.Expr{&ast.StarExpr{X: frameName}},
Rhs: []ast.Expr{frameInit},
}}},
})

popFrame := []ast.Stmt{
&ast.ExprStmt{X: &ast.CallExpr{Fun: &ast.SelectorExpr{X: ctx, Sel: ast.NewIdent("Pop")}}},
popExpr := &ast.CallExpr{
Fun: &ast.SelectorExpr{X: coroutineIdent, Sel: ast.NewIdent("Pop")},
Args: []ast.Expr{&ast.UnaryExpr{
Op: token.AND,
X: &ast.SelectorExpr{X: ctx, Sel: ast.NewIdent("Stack")},
}},
}

if defers != nil {
popFrame = append(popFrame, &ast.RangeStmt{
Key: ast.NewIdent("_"),
Value: ast.NewIdent("f"),
Tok: token.DEFINE,
X: &ast.SelectorExpr{
X: frameName,
Sel: frameType.Fields.List[len(frameType.Fields.List)-1].Names[0],
var popFrame []ast.Stmt
if defers == nil {
popFrame = []ast.Stmt{&ast.ExprStmt{X: popExpr}}
} else {
popFrame = []ast.Stmt{
&ast.DeferStmt{Call: popExpr},
&ast.RangeStmt{
Key: ast.NewIdent("_"),
Value: ast.NewIdent("f"),
Tok: token.DEFINE,
X: &ast.SelectorExpr{
X: frameName,
Sel: frameType.Fields.List[len(frameType.Fields.List)-1].Names[0],
},
Body: &ast.BlockStmt{List: []ast.Stmt{
&ast.DeferStmt{Call: &ast.CallExpr{Fun: ast.NewIdent("f")}},
}},
},
Body: &ast.BlockStmt{List: []ast.Stmt{
&ast.DeferStmt{Call: &ast.CallExpr{Fun: ast.NewIdent("f")}},
}},
})
}
}

gen.List = append(gen.List, &ast.DeferStmt{
Expand All @@ -595,25 +586,10 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.IfStmt{
Cond: &ast.CallExpr{
Cond: &ast.UnaryExpr{Op: token.NOT, X: &ast.CallExpr{
Fun: &ast.SelectorExpr{X: ctx, Sel: ast.NewIdent("Unwinding")},
},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{X: frame, Sel: ast.NewIdent("Set")},
Args: []ast.Expr{
&ast.BasicLit{Kind: token.INT, Value: "0"},
frameName,
},
}},
&ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{X: ctx, Sel: ast.NewIdent("Store")},
Args: []ast.Expr{fp, frame},
}},
},
},
Else: &ast.BlockStmt{List: popFrame},
}},
Body: &ast.BlockStmt{List: popFrame},
},
},
},
Expand All @@ -623,7 +599,7 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body

spans := trackDispatchSpans(body)
mayYield = findCalls(body, p.TypesInfo)
compiledBody := compileDispatch(body, spans, mayYield).(*ast.BlockStmt)
compiledBody := compileDispatch(body, frameName, spans, mayYield).(*ast.BlockStmt)
gen.List = append(gen.List, compiledBody.List...)

// If the function returns one or more values, it must end with a return statement;
Expand Down
11 changes: 8 additions & 3 deletions compiler/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ import (
// function body, so there may be duplicate identifiers. Identifiers can be
// disambiguated using (*types.Info).ObjectOf(ident).
func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, recv *ast.FieldList, defers *ast.Ident, info *types.Info) (decls []*ast.GenDecl, frameType *ast.StructType, frameInit *ast.CompositeLit) {
frameType = &ast.StructType{Fields: &ast.FieldList{}}
IP := &ast.Field{
Names: []*ast.Ident{ast.NewIdent("IP")},
Type: ast.NewIdent("int"),
}

frameType = &ast.StructType{Fields: &ast.FieldList{List: []*ast.Field{IP}}}
frameInit = &ast.CompositeLit{Type: frameType}

if recv != nil {
Expand Down Expand Up @@ -184,7 +189,7 @@ func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl, frameN
}

index := 0
for i, field := range frameType.Fields.List {
for i, field := range frameType.Fields.List[1:] {
fieldNames := make([]*ast.Ident, len(field.Names))

for j, ident := range field.Names {
Expand All @@ -206,7 +211,7 @@ func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl, frameN
}
}

frameType.Fields.List[i] = &ast.Field{
frameType.Fields.List[i+1] = &ast.Field{
Names: fieldNames,
Type: field.Type,
}
Expand Down
40 changes: 21 additions & 19 deletions compiler/dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func trackDispatchSpans0(stmt ast.Stmt, dispatchSpans map[ast.Stmt]dispatchSpan,
// to the correct location in the code, even when there are arbitrary
// levels of branches and loops. To do this, we generate a switch inside
// each block, using the information from trackDispatchSpans.
func compileDispatch(stmt ast.Stmt, dispatchSpans map[ast.Stmt]dispatchSpan, mayYield map[ast.Node]struct{}) ast.Stmt {
func compileDispatch(stmt ast.Stmt, frame *ast.Ident, dispatchSpans map[ast.Stmt]dispatchSpan, mayYield map[ast.Node]struct{}) ast.Stmt {
if _, ok := mayYield[stmt]; !ok {
return stmt
}
Expand All @@ -77,19 +77,21 @@ func compileDispatch(stmt ast.Stmt, dispatchSpans map[ast.Stmt]dispatchSpan, may
case *ast.BlockStmt:
switch {
case len(s.List) == 1:
child := compileDispatch(s.List[0], dispatchSpans, mayYield)
child := compileDispatch(s.List[0], frame, dispatchSpans, mayYield)
s.List[0] = unnestBlocks(child)
case len(s.List) > 1:
stmt = &ast.BlockStmt{List: []ast.Stmt{compileDispatch0(s.List, dispatchSpans, mayYield)}}
stmt = &ast.BlockStmt{
List: []ast.Stmt{compileDispatch0(s.List, frame, dispatchSpans, mayYield)},
}
}
case *ast.IfStmt:
s.Body = compileDispatch(s.Body, dispatchSpans, mayYield).(*ast.BlockStmt)
s.Body = compileDispatch(s.Body, frame, dispatchSpans, mayYield).(*ast.BlockStmt)
if s.Else != nil {
s.Else = compileDispatch(s.Else, dispatchSpans, mayYield)
s.Else = compileDispatch(s.Else, frame, dispatchSpans, mayYield)
}
case *ast.ForStmt:
forSpan := dispatchSpans[s]
s.Body = compileDispatch(s.Body, dispatchSpans, mayYield).(*ast.BlockStmt)
s.Body = compileDispatch(s.Body, frame, dispatchSpans, mayYield).(*ast.BlockStmt)

// Hijack the loop's post iteration statement to inject an IP reset.
if s.Post == nil {
Expand Down Expand Up @@ -147,54 +149,54 @@ func compileDispatch(stmt ast.Stmt, dispatchSpans map[ast.Stmt]dispatchSpan, may
}
assign.Tok = token.ASSIGN
}
assign.Lhs = append(assign.Lhs, &ast.SelectorExpr{X: ast.NewIdent("_f"), Sel: ast.NewIdent("IP")})
assign.Lhs = append(assign.Lhs, &ast.SelectorExpr{X: frame, Sel: ast.NewIdent("IP")})
assign.Rhs = append(assign.Rhs, &ast.BasicLit{Kind: token.INT, Value: strconv.Itoa(forSpan.start)})

case *ast.SwitchStmt:
for i, child := range s.Body.List {
s.Body.List[i] = compileDispatch(child, dispatchSpans, mayYield)
s.Body.List[i] = compileDispatch(child, frame, dispatchSpans, mayYield)
}
case *ast.TypeSwitchStmt:
for i, child := range s.Body.List {
s.Body.List[i] = compileDispatch(child, dispatchSpans, mayYield)
s.Body.List[i] = compileDispatch(child, frame, dispatchSpans, mayYield)
}
case *ast.SelectStmt:
for i, child := range s.Body.List {
s.Body.List[i] = compileDispatch(child, dispatchSpans, mayYield)
s.Body.List[i] = compileDispatch(child, frame, dispatchSpans, mayYield)
}
case *ast.CaseClause:
switch {
case len(s.Body) == 1:
child := compileDispatch(s.Body[0], dispatchSpans, mayYield)
child := compileDispatch(s.Body[0], frame, dispatchSpans, mayYield)
s.Body[0] = unnestBlocks(child)
case len(s.Body) > 1:
s.Body = []ast.Stmt{compileDispatch0(s.Body, dispatchSpans, mayYield)}
s.Body = []ast.Stmt{compileDispatch0(s.Body, frame, dispatchSpans, mayYield)}
}
case *ast.CommClause:
switch {
case len(s.Body) == 1:
child := compileDispatch(s.Body[0], dispatchSpans, mayYield)
child := compileDispatch(s.Body[0], frame, dispatchSpans, mayYield)
s.Body[0] = unnestBlocks(child)
case len(s.Body) > 1:
s.Body = []ast.Stmt{compileDispatch0(s.Body, dispatchSpans, mayYield)}
s.Body = []ast.Stmt{compileDispatch0(s.Body, frame, dispatchSpans, mayYield)}
}
case *ast.LabeledStmt:
s.Stmt = compileDispatch(s.Stmt, dispatchSpans, mayYield)
s.Stmt = compileDispatch(s.Stmt, frame, dispatchSpans, mayYield)
}
return stmt
}

func compileDispatch0(stmts []ast.Stmt, dispatchSpans map[ast.Stmt]dispatchSpan, mayYield map[ast.Node]struct{}) ast.Stmt {
func compileDispatch0(stmts []ast.Stmt, frame *ast.Ident, dispatchSpans map[ast.Stmt]dispatchSpan, mayYield map[ast.Node]struct{}) ast.Stmt {
var cases []ast.Stmt
for i, child := range stmts {
childSpan := dispatchSpans[child]
compiledChild := compileDispatch(child, dispatchSpans, mayYield)
compiledChild := compileDispatch(child, frame, dispatchSpans, mayYield)
compiledChild = unnestBlocks(compiledChild)
caseBody := []ast.Stmt{compiledChild}
if i < len(stmts)-1 {
caseBody = append(caseBody,
&ast.AssignStmt{
Lhs: []ast.Expr{&ast.SelectorExpr{X: ast.NewIdent("_f"), Sel: ast.NewIdent("IP")}},
Lhs: []ast.Expr{&ast.SelectorExpr{X: frame, Sel: ast.NewIdent("IP")}},
Tok: token.ASSIGN,
Rhs: []ast.Expr{&ast.BasicLit{Kind: token.INT, Value: strconv.Itoa(childSpan.end)}},
},
Expand All @@ -203,7 +205,7 @@ func compileDispatch0(stmts []ast.Stmt, dispatchSpans map[ast.Stmt]dispatchSpan,
cases = append(cases, &ast.CaseClause{
List: []ast.Expr{
&ast.BinaryExpr{
X: &ast.SelectorExpr{X: ast.NewIdent("_f"), Sel: ast.NewIdent("IP")},
X: &ast.SelectorExpr{X: frame, Sel: ast.NewIdent("IP")},
Op: token.LSS, /* < */
Y: &ast.BasicLit{Kind: token.INT, Value: strconv.Itoa(childSpan.end)}},
},
Expand Down
31 changes: 20 additions & 11 deletions compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,16 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func
}
}

pre := func(cursor *astutil.Cursor) bool {
switch n := cursor.Node().(type) {
var inspect func(ast.Node) bool
inspect = func(node ast.Node) bool {
switch n := node.(type) {
case *ast.Ident:
observeIdent(n)

case *ast.SelectorExpr:
ast.Inspect(n.X, inspect)
return false

case *ast.GenDecl:
if n.Tok == token.VAR {
for _, spec := range n.Specs {
Expand All @@ -110,6 +115,7 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func
}
}
}
return false
}

case *ast.FuncLit:
Expand All @@ -125,15 +131,18 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func
return true
}

post := func(cursor *astutil.Cursor) bool {
switch cursor.Node().(type) {
case *ast.BlockStmt:
scope = scope.outer
}
return true
}

astutil.Apply(functionBodyOf(fn), pre, post)
astutil.Apply(functionBodyOf(fn),
func(cursor *astutil.Cursor) bool {
return inspect(cursor.Node())
},
func(cursor *astutil.Cursor) bool {
switch cursor.Node().(type) {
case *ast.BlockStmt:
scope = scope.outer
}
return true
},
)

functype := functype{
signature: signature,
Expand Down
Loading

0 comments on commit f79beaf

Please sign in to comment.