Skip to content

Commit

Permalink
Extract pass to track function calls
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Sep 24, 2023
1 parent b9c58ed commit 30cc303
Showing 1 changed file with 87 additions and 39 deletions.
126 changes: 87 additions & 39 deletions compiler/desugar.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,16 @@ import (
// done automatically by the type checker.
func desugar(p *types.Package, stmt ast.Stmt, info *types.Info) ast.Stmt {
d := desugarer{pkg: p, info: info}

// First pass finds function calls and marks nodes in the tree
// that may yield
d.nodesThatMayYield = findCalls(stmt, info)

// Second pass desugars statements that may yield.
stmt = d.desugar(stmt, nil, nil, nil)

// Unused labels cause a compile error (label X defined and not used)
// so we need a second pass over the tree to delete unused labels.
// so we need a third pass over the tree to delete unused labels.
astutil.Apply(stmt, func(cursor *astutil.Cursor) bool {
if ls, ok := cursor.Node().(*ast.LabeledStmt); ok && d.isUnusedLabel(ls.Label) {
cursor.Replace(ls.Stmt)
Expand All @@ -56,12 +62,13 @@ func desugar(p *types.Package, stmt ast.Stmt, info *types.Info) ast.Stmt {
}

type desugarer struct {
pkg *types.Package
info *types.Info
vars int
labels int
unusedLabels map[*ast.Ident]struct{}
userLabels map[types.Object]*ast.Ident
pkg *types.Package
info *types.Info
vars int
labels int
nodesThatMayYield map[ast.Node]struct{}
unusedLabels map[*ast.Ident]struct{}
userLabels map[types.Object]*ast.Ident
}

func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.Ident) ast.Stmt {
Expand Down Expand Up @@ -132,8 +139,12 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I
}
body := &ast.BlockStmt{List: s.Body.List}
if s.Cond != nil {
cond := &ast.UnaryExpr{Op: token.NOT, X: s.Cond}
if d.mayYield(s.Cond) {
d.nodesThatMayYield[cond] = struct{}{}
}
body.List = append([]ast.Stmt{&ast.IfStmt{
Cond: &ast.UnaryExpr{Op: token.NOT, X: s.Cond},
Cond: cond,
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.BranchStmt{Tok: token.BREAK}}},
}}, body.List...)
}
Expand Down Expand Up @@ -454,18 +465,26 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I
}
list := make([]ast.Expr, len(c.List))
for i := range list {
value := c.List[i]
if tag != nil {
list[i] = &ast.BinaryExpr{X: tag, Op: token.EQL, Y: c.List[i]}
list[i] = &ast.BinaryExpr{X: tag, Op: token.EQL, Y: value}
if d.mayYield(value) {
d.nodesThatMayYield[list[i]] = struct{}{}
}
} else {
list[i] = c.List[i]
list[i] = value
}
}
tmp := d.newVar(types.Typ[types.Bool])
orExpr := list[0]
list = list[1:]
for len(list) > 0 {
// TODO: balance the tree
orExpr = &ast.BinaryExpr{X: orExpr, Op: token.OR, Y: list[0]}
x, y := orExpr, list[0]
orExpr = &ast.BinaryExpr{X: x, Op: token.OR, Y: y}
if d.mayYield(x) || d.mayYield(y) {
d.nodesThatMayYield[orExpr] = struct{}{}
}
list = list[1:]
}
ifStmt := &ast.IfStmt{
Expand Down Expand Up @@ -633,7 +652,15 @@ func (d *desugarer) flatMap(stmt ast.Stmt) (result []ast.Stmt) {
return
}

func (d *desugarer) mayYield(n ast.Node) (mayYield bool) {
type exprFlags int

const (
// multiExprStmt is set if the expression is part of a statement
// that has more than one nested expression of type ast.Expr.
multiExprStmt exprFlags = 1 << iota
)

func (d *desugarer) mayYield(n ast.Node) bool {
switch n.(type) {
case nil:
return false
Expand All @@ -642,39 +669,15 @@ func (d *desugarer) mayYield(n ast.Node) (mayYield bool) {
case *ast.ArrayType, *ast.ChanType, *ast.FuncType, *ast.InterfaceType, *ast.MapType, *ast.StructType:
return false
}
// TODO: use information from the callgraph to determine which of those ast.CallExpr may yield
ast.Inspect(n, func(node ast.Node) bool {
if c, ok := node.(*ast.CallExpr); ok {
switch fn := c.Fun.(type) {
case *ast.Ident:
if obj := d.info.ObjectOf(fn); obj != nil {
if obj == types.Universe.Lookup(fn.Name) {
return true // skip builtin function calls
} else if _, ok := obj.(*types.TypeName); ok {
return true // skip type casts
}
}
}
mayYield = true
return false
}
return true
})
return
_, ok := d.nodesThatMayYield[n]
return ok
}

type exprFlags int

const (
// multiExprStmt is set if the expression is part of a statement
// that has more than one nested expression of type ast.Expr.
multiExprStmt exprFlags = 1 << iota
)

func (d *desugarer) decomposeExpression(expr ast.Expr, flags exprFlags) (ast.Expr, []ast.Stmt) {
if !d.mayYield(expr) {
return expr, nil
}

queue := []ast.Expr{expr}
var tmps []*ast.Ident

Expand Down Expand Up @@ -835,3 +838,48 @@ func isUnderscore(e ast.Expr) bool {
i, ok := e.(*ast.Ident)
return ok && i.Name == "_"
}

// findCalls marks nodes in a tree that are an *ast.CallExpr, or lead to
// an *ast.CallExpr.
func findCalls(tree ast.Node, info *types.Info) map[ast.Node]struct{} {
mayYield := map[ast.Node]struct{}{}
var stack []ast.Node
ast.Inspect(tree, func(node ast.Node) bool {
if node != nil {
stack = append(stack, node)

if c, ok := node.(*ast.CallExpr); ok {
// Exclude some call expressions.
switch fn := c.Fun.(type) {
case *ast.Ident:
if obj := info.ObjectOf(fn); obj != nil {
if obj == types.Universe.Lookup(fn.Name) {
return true // skip builtin function calls
} else if _, ok := obj.(*types.TypeName); ok {
return true // skip type casts
}
}
}

// Mark this node, and all nodes that lead to it.
addNodes:
for i := len(stack) - 1; i >= 0; i-- {
n := stack[i]
switch n.(type) {
case *ast.FuncDecl, *ast.FuncLit:
break addNodes
}
if _, ok := mayYield[n]; ok {
break
}
mayYield[n] = struct{}{}
}
}
} else {
stack = stack[:len(stack)-1]
}
return true
})

return mayYield
}

0 comments on commit 30cc303

Please sign in to comment.