diff --git a/compiler/desugar.go b/compiler/desugar.go index fe1c863..90a7f3f 100644 --- a/compiler/desugar.go +++ b/compiler/desugar.go @@ -41,10 +41,16 @@ import ( // done automatically by the type checker. func desugar(p *types.Package, stmt ast.Stmt, info *types.Info) ast.Stmt { d := desugarer{pkg: p, info: info} + + // First pass finds function calls and marks nodes in the tree + // that may yield + d.nodesThatMayYield = findCalls(stmt, info) + + // Second pass desugars statements that may yield. stmt = d.desugar(stmt, nil, nil, nil) // Unused labels cause a compile error (label X defined and not used) - // so we need a second pass over the tree to delete unused labels. + // so we need a third pass over the tree to delete unused labels. astutil.Apply(stmt, func(cursor *astutil.Cursor) bool { if ls, ok := cursor.Node().(*ast.LabeledStmt); ok && d.isUnusedLabel(ls.Label) { cursor.Replace(ls.Stmt) @@ -56,12 +62,13 @@ func desugar(p *types.Package, stmt ast.Stmt, info *types.Info) ast.Stmt { } type desugarer struct { - pkg *types.Package - info *types.Info - vars int - labels int - unusedLabels map[*ast.Ident]struct{} - userLabels map[types.Object]*ast.Ident + pkg *types.Package + info *types.Info + vars int + labels int + nodesThatMayYield map[ast.Node]struct{} + unusedLabels map[*ast.Ident]struct{} + userLabels map[types.Object]*ast.Ident } func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.Ident) ast.Stmt { @@ -132,8 +139,12 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I } body := &ast.BlockStmt{List: s.Body.List} if s.Cond != nil { + cond := &ast.UnaryExpr{Op: token.NOT, X: s.Cond} + if d.mayYield(s.Cond) { + d.nodesThatMayYield[cond] = struct{}{} + } body.List = append([]ast.Stmt{&ast.IfStmt{ - Cond: &ast.UnaryExpr{Op: token.NOT, X: s.Cond}, + Cond: cond, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.BranchStmt{Tok: token.BREAK}}}, }}, body.List...) } @@ -454,10 +465,14 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I } list := make([]ast.Expr, len(c.List)) for i := range list { + value := c.List[i] if tag != nil { - list[i] = &ast.BinaryExpr{X: tag, Op: token.EQL, Y: c.List[i]} + list[i] = &ast.BinaryExpr{X: tag, Op: token.EQL, Y: value} + if d.mayYield(value) { + d.nodesThatMayYield[list[i]] = struct{}{} + } } else { - list[i] = c.List[i] + list[i] = value } } tmp := d.newVar(types.Typ[types.Bool]) @@ -465,7 +480,11 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I list = list[1:] for len(list) > 0 { // TODO: balance the tree - orExpr = &ast.BinaryExpr{X: orExpr, Op: token.OR, Y: list[0]} + x, y := orExpr, list[0] + orExpr = &ast.BinaryExpr{X: x, Op: token.OR, Y: y} + if d.mayYield(x) || d.mayYield(y) { + d.nodesThatMayYield[orExpr] = struct{}{} + } list = list[1:] } ifStmt := &ast.IfStmt{ @@ -633,7 +652,15 @@ func (d *desugarer) flatMap(stmt ast.Stmt) (result []ast.Stmt) { return } -func (d *desugarer) mayYield(n ast.Node) (mayYield bool) { +type exprFlags int + +const ( + // multiExprStmt is set if the expression is part of a statement + // that has more than one nested expression of type ast.Expr. + multiExprStmt exprFlags = 1 << iota +) + +func (d *desugarer) mayYield(n ast.Node) bool { switch n.(type) { case nil: return false @@ -642,39 +669,15 @@ func (d *desugarer) mayYield(n ast.Node) (mayYield bool) { case *ast.ArrayType, *ast.ChanType, *ast.FuncType, *ast.InterfaceType, *ast.MapType, *ast.StructType: return false } - // TODO: use information from the callgraph to determine which of those ast.CallExpr may yield - ast.Inspect(n, func(node ast.Node) bool { - if c, ok := node.(*ast.CallExpr); ok { - switch fn := c.Fun.(type) { - case *ast.Ident: - if obj := d.info.ObjectOf(fn); obj != nil { - if obj == types.Universe.Lookup(fn.Name) { - return true // skip builtin function calls - } else if _, ok := obj.(*types.TypeName); ok { - return true // skip type casts - } - } - } - mayYield = true - return false - } - return true - }) - return + _, ok := d.nodesThatMayYield[n] + return ok } -type exprFlags int - -const ( - // multiExprStmt is set if the expression is part of a statement - // that has more than one nested expression of type ast.Expr. - multiExprStmt exprFlags = 1 << iota -) - func (d *desugarer) decomposeExpression(expr ast.Expr, flags exprFlags) (ast.Expr, []ast.Stmt) { if !d.mayYield(expr) { return expr, nil } + queue := []ast.Expr{expr} var tmps []*ast.Ident @@ -835,3 +838,48 @@ func isUnderscore(e ast.Expr) bool { i, ok := e.(*ast.Ident) return ok && i.Name == "_" } + +// findCalls marks nodes in a tree that are an *ast.CallExpr, or lead to +// an *ast.CallExpr. +func findCalls(tree ast.Node, info *types.Info) map[ast.Node]struct{} { + mayYield := map[ast.Node]struct{}{} + var stack []ast.Node + ast.Inspect(tree, func(node ast.Node) bool { + if node != nil { + stack = append(stack, node) + + if c, ok := node.(*ast.CallExpr); ok { + // Exclude some call expressions. + switch fn := c.Fun.(type) { + case *ast.Ident: + if obj := info.ObjectOf(fn); obj != nil { + if obj == types.Universe.Lookup(fn.Name) { + return true // skip builtin function calls + } else if _, ok := obj.(*types.TypeName); ok { + return true // skip type casts + } + } + } + + // Mark this node, and all nodes that lead to it. + addNodes: + for i := len(stack) - 1; i >= 0; i-- { + n := stack[i] + switch n.(type) { + case *ast.FuncDecl, *ast.FuncLit: + break addNodes + } + if _, ok := mayYield[n]; ok { + break + } + mayYield[n] = struct{}{} + } + } + } else { + stack = stack[:len(stack)-1] + } + return true + }) + + return mayYield +}