diff --git a/compiler/dispatch.go b/compiler/dispatch.go index c108ff2..04fa507 100644 --- a/compiler/dispatch.go +++ b/compiler/dispatch.go @@ -80,29 +80,66 @@ func compileDispatch(stmt ast.Stmt, dispatchSpans map[ast.Stmt]dispatchSpan) ast case *ast.ForStmt: forSpan := dispatchSpans[s] s.Body = compileDispatch(s.Body, dispatchSpans).(*ast.BlockStmt) - // Reset IP after each loop iteration. - ipVar := &ast.SelectorExpr{X: ast.NewIdent("_f"), Sel: ast.NewIdent("IP")} - ipVal := &ast.BasicLit{Kind: token.INT, Value: strconv.Itoa(forSpan.start)} - switch post := s.Post.(type) { - case nil: - s.Post = &ast.AssignStmt{Lhs: []ast.Expr{ipVar}, Tok: token.ASSIGN, Rhs: []ast.Expr{ipVal}} - case *ast.IncDecStmt: + + // Hijack the loop's post iteration statement to inject an IP reset. + if s.Post == nil { + s.Post = &ast.AssignStmt{Lhs: []ast.Expr{}, Tok: token.ASSIGN, Rhs: []ast.Expr{}} + } else if incDec, ok := s.Post.(*ast.IncDecStmt); ok { var op token.Token - switch post.Tok { + switch incDec.Tok { case token.INC: op = token.ADD case token.DEC: op = token.SUB } s.Post = &ast.AssignStmt{ - Lhs: []ast.Expr{post.X, ipVar}, + Lhs: []ast.Expr{incDec.X}, Tok: token.ASSIGN, - Rhs: []ast.Expr{ - &ast.BinaryExpr{X: post.X, Op: op, Y: &ast.BasicLit{Kind: token.INT, Value: "1"}}, - ipVal, - }, + Rhs: []ast.Expr{&ast.BinaryExpr{X: incDec.X, Op: op, Y: &ast.BasicLit{Kind: token.INT, Value: "1"}}}, + } + } + assign, ok := s.Post.(*ast.AssignStmt) + if !ok { + panic("not implemented") + } + if assign.Tok != token.ASSIGN { + for i := range assign.Lhs { + var op token.Token + switch assign.Tok { + case token.ADD_ASSIGN: + op = token.ADD + case token.SUB_ASSIGN: + op = token.SUB + case token.MUL_ASSIGN: + op = token.MUL + case token.QUO_ASSIGN: + op = token.QUO + case token.REM_ASSIGN: + op = token.REM + case token.AND_ASSIGN: + op = token.AND + case token.OR_ASSIGN: + op = token.OR + case token.XOR_ASSIGN: + op = token.XOR + case token.SHL_ASSIGN: + op = token.SHL + case token.SHR_ASSIGN: + op = token.SHR + case token.AND_NOT_ASSIGN: + op = token.AND_NOT + } + // From the Go language spec: + // > An assignment operation x op= y where op is a binary arithmetic operator is equivalent to x = x op (y) but evaluates x only once. + // Thus, this transformation is only valid if the LHS doesn't + // contain side effects. This is checked elsewhere. + assign.Rhs[i] = &ast.BinaryExpr{X: assign.Lhs[i], Op: op, Y: assign.Rhs[i]} } + assign.Tok = token.ASSIGN } + assign.Lhs = append(assign.Lhs, &ast.SelectorExpr{X: ast.NewIdent("_f"), Sel: ast.NewIdent("IP")}) + assign.Rhs = append(assign.Rhs, &ast.BasicLit{Kind: token.INT, Value: strconv.Itoa(forSpan.start)}) + case *ast.SwitchStmt: for i, child := range s.Body.List { s.Body.List[i] = compileDispatch(child, dispatchSpans) diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index fb45381..a2debc0 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -1543,11 +1543,14 @@ func Select(n int) { func init() { serde.RegisterType[**byte]() serde.RegisterType[*[100000]uintptr]() + serde.RegisterType[*[1125899906842623]byte]() serde.RegisterType[*[131072]uint16]() serde.RegisterType[*[140737488355327]byte]() serde.RegisterType[*[16]byte]() serde.RegisterType[*[171]uint8]() serde.RegisterType[*[1]uintptr]() + serde.RegisterType[*[268435456]uintptr]() + serde.RegisterType[*[281474976710655]uint32]() serde.RegisterType[*[2]byte]() serde.RegisterType[*[2]float32]() serde.RegisterType[*[2]float64]() @@ -1557,11 +1560,12 @@ func init() { serde.RegisterType[*[32]rune]() serde.RegisterType[*[32]uintptr]() serde.RegisterType[*[4]byte]() - serde.RegisterType[*[512]uintptr]() + serde.RegisterType[*[562949953421311]uint16]() serde.RegisterType[*[65536]uintptr]() serde.RegisterType[*[70368744177663]uint16]() serde.RegisterType[*[8]byte]() serde.RegisterType[*[8]uint8]() + serde.RegisterType[*[]byte]() serde.RegisterType[*[]uint64]() serde.RegisterType[*bool]() serde.RegisterType[*byte]() @@ -1574,77 +1578,80 @@ func init() { serde.RegisterType[*uint64]() serde.RegisterType[*uint8]() serde.RegisterType[*uintptr]() + serde.RegisterType[[0]byte]() + serde.RegisterType[[0]uint8]() serde.RegisterType[[0]uintptr]() serde.RegisterType[[1000]uintptr]() serde.RegisterType[[100]byte]() serde.RegisterType[[1024]bool]() serde.RegisterType[[1024]byte]() - serde.RegisterType[[1024]int8]() serde.RegisterType[[1024]uint8]() serde.RegisterType[[1048576]uint8]() serde.RegisterType[[104]byte]() - serde.RegisterType[[104]int8]() - serde.RegisterType[[107]string]() serde.RegisterType[[108]byte]() + serde.RegisterType[[108]int8]() serde.RegisterType[[10]byte]() serde.RegisterType[[10]string]() serde.RegisterType[[128]byte]() serde.RegisterType[[128]uint64]() serde.RegisterType[[128]uintptr]() serde.RegisterType[[129]uint8]() - serde.RegisterType[[12]int8]() serde.RegisterType[[131072]uintptr]() + serde.RegisterType[[133]string]() serde.RegisterType[[13]int32]() serde.RegisterType[[14]byte]() serde.RegisterType[[14]int8]() + serde.RegisterType[[15]uint64]() serde.RegisterType[[16384]byte]() serde.RegisterType[[16384]uint8]() serde.RegisterType[[16]byte]() - serde.RegisterType[[16]int8]() + serde.RegisterType[[16]int64]() serde.RegisterType[[16]uint64]() - serde.RegisterType[[16]uintptr]() serde.RegisterType[[17]string]() serde.RegisterType[[1]byte]() - serde.RegisterType[[1]uint32]() serde.RegisterType[[1]uint64]() serde.RegisterType[[1]uint8]() serde.RegisterType[[1]uintptr]() serde.RegisterType[[20]byte]() - serde.RegisterType[[20]uint8]() serde.RegisterType[[21]byte]() + serde.RegisterType[[23]uint64]() serde.RegisterType[[249]uint8]() serde.RegisterType[[24]byte]() + serde.RegisterType[[24]uint32]() serde.RegisterType[[252]uintptr]() serde.RegisterType[[253]uintptr]() + serde.RegisterType[[256]int8]() serde.RegisterType[[256]uint64]() - serde.RegisterType[[29]uint64]() serde.RegisterType[[2]byte]() serde.RegisterType[[2]int]() serde.RegisterType[[2]int32]() - serde.RegisterType[[2]int64]() - serde.RegisterType[[2]uint32]() serde.RegisterType[[2]uint64]() serde.RegisterType[[2]uintptr]() serde.RegisterType[[32]byte]() - serde.RegisterType[[32]int32]() serde.RegisterType[[32]string]() - serde.RegisterType[[32]uint32]() + serde.RegisterType[[32]uint8]() serde.RegisterType[[32]uintptr]() serde.RegisterType[[33]float64]() serde.RegisterType[[3]byte]() serde.RegisterType[[3]int]() + serde.RegisterType[[3]int64]() + serde.RegisterType[[3]uint16]() + serde.RegisterType[[3]uint32]() + serde.RegisterType[[3]uint64]() serde.RegisterType[[4096]byte]() - serde.RegisterType[[40]int8]() + serde.RegisterType[[40]byte]() + serde.RegisterType[[44]byte]() serde.RegisterType[[4]byte]() serde.RegisterType[[4]float64]() + serde.RegisterType[[4]int64]() serde.RegisterType[[4]string]() + serde.RegisterType[[4]uint16]() serde.RegisterType[[4]uint32]() serde.RegisterType[[4]uint64]() serde.RegisterType[[4]uintptr]() serde.RegisterType[[50]uintptr]() serde.RegisterType[[512]byte]() serde.RegisterType[[512]uintptr]() - serde.RegisterType[[56]int8]() serde.RegisterType[[5]byte]() serde.RegisterType[[5]uint]() serde.RegisterType[[61]struct { @@ -1654,9 +1661,11 @@ func init() { }]() serde.RegisterType[[64488]byte]() serde.RegisterType[[64]byte]() - serde.RegisterType[[64]uint64]() serde.RegisterType[[64]uintptr]() serde.RegisterType[[65528]byte]() + serde.RegisterType[[65]int8]() + serde.RegisterType[[65]uint32]() + serde.RegisterType[[65]uintptr]() serde.RegisterType[[68]struct { Size uint32 Mallocs uint64 @@ -1666,18 +1675,20 @@ func init() { serde.RegisterType[[68]uint32]() serde.RegisterType[[68]uint64]() serde.RegisterType[[68]uint8]() + serde.RegisterType[[6]byte]() serde.RegisterType[[6]int]() + serde.RegisterType[[6]int8]() serde.RegisterType[[6]uintptr]() - serde.RegisterType[[7]uint64]() - serde.RegisterType[[88]byte]() + serde.RegisterType[[8192]byte]() serde.RegisterType[[8]byte]() - serde.RegisterType[[8]int8]() serde.RegisterType[[8]string]() serde.RegisterType[[8]uint32]() + serde.RegisterType[[8]uint64]() serde.RegisterType[[8]uint8]() - serde.RegisterType[[92]int8]() serde.RegisterType[[96]byte]() + serde.RegisterType[[96]int8]() serde.RegisterType[[9]string]() + serde.RegisterType[[9]uintptr]() serde.RegisterType[[]*byte]() serde.RegisterType[[][]int32]() serde.RegisterType[[]byte]() @@ -1743,53 +1754,10 @@ func init() { needed bool alignme uint64 }]() - serde.RegisterType[struct { - fd int32 - cmd int32 - arg int32 - ret int32 - errno int32 - }]() serde.RegisterType[struct { fill uint64 capacity uint64 }]() - serde.RegisterType[struct { - fn uintptr - a1 uintptr - a2 uintptr - a3 uintptr - a4 uintptr - a5 uintptr - a6 uintptr - r1 uintptr - r2 uintptr - err uintptr - }]() - serde.RegisterType[struct { - fn uintptr - a1 uintptr - a2 uintptr - a3 uintptr - a4 uintptr - a5 uintptr - f1 float64 - r1 uintptr - }]() - serde.RegisterType[struct { - fn uintptr - a1 uintptr - a2 uintptr - a3 uintptr - r1 uintptr - r2 uintptr - err uintptr - }]() - serde.RegisterType[struct { - t int64 - numer uint32 - denom uint32 - }]() serde.RegisterType[struct { tick uint64 i int @@ -1802,65 +1770,73 @@ func init() { serde.RegisterType[sync.Pool]() serde.RegisterType[sync.RWMutex]() serde.RegisterType[sync.WaitGroup]() - serde.RegisterType[syscall.BpfHdr]() - serde.RegisterType[syscall.BpfInsn]() - serde.RegisterType[syscall.BpfProgram]() - serde.RegisterType[syscall.BpfStat]() - serde.RegisterType[syscall.BpfVersion]() serde.RegisterType[syscall.Cmsghdr]() serde.RegisterType[syscall.Credential]() serde.RegisterType[syscall.Dirent]() + serde.RegisterType[syscall.EpollEvent]() serde.RegisterType[syscall.Errno]() - serde.RegisterType[syscall.Fbootstraptransfer_t]() serde.RegisterType[syscall.FdSet]() serde.RegisterType[syscall.Flock_t]() serde.RegisterType[syscall.Fsid]() - serde.RegisterType[syscall.Fstore_t]() serde.RegisterType[syscall.ICMPv6Filter]() serde.RegisterType[syscall.IPMreq]() + serde.RegisterType[syscall.IPMreqn]() serde.RegisterType[syscall.IPv6MTUInfo]() serde.RegisterType[syscall.IPv6Mreq]() - serde.RegisterType[syscall.IfData]() - serde.RegisterType[syscall.IfMsghdr]() - serde.RegisterType[syscall.IfaMsghdr]() - serde.RegisterType[syscall.IfmaMsghdr]() - serde.RegisterType[syscall.IfmaMsghdr2]() + serde.RegisterType[syscall.IfAddrmsg]() + serde.RegisterType[syscall.IfInfomsg]() serde.RegisterType[syscall.Inet4Pktinfo]() serde.RegisterType[syscall.Inet6Pktinfo]() - serde.RegisterType[syscall.InterfaceAddrMessage]() - serde.RegisterType[syscall.InterfaceMessage]() - serde.RegisterType[syscall.InterfaceMulticastAddrMessage]() + serde.RegisterType[syscall.InotifyEvent]() serde.RegisterType[syscall.Iovec]() - serde.RegisterType[syscall.Kevent_t]() serde.RegisterType[syscall.Linger]() - serde.RegisterType[syscall.Log2phys_t]() serde.RegisterType[syscall.Msghdr]() + serde.RegisterType[syscall.NetlinkMessage]() + serde.RegisterType[syscall.NetlinkRouteAttr]() + serde.RegisterType[syscall.NetlinkRouteRequest]() + serde.RegisterType[syscall.NlAttr]() + serde.RegisterType[syscall.NlMsgerr]() + serde.RegisterType[syscall.NlMsghdr]() serde.RegisterType[syscall.ProcAttr]() - serde.RegisterType[syscall.Radvisory_t]() + serde.RegisterType[syscall.PtraceRegs]() serde.RegisterType[syscall.RawSockaddr]() serde.RegisterType[syscall.RawSockaddrAny]() - serde.RegisterType[syscall.RawSockaddrDatalink]() serde.RegisterType[syscall.RawSockaddrInet4]() serde.RegisterType[syscall.RawSockaddrInet6]() + serde.RegisterType[syscall.RawSockaddrLinklayer]() + serde.RegisterType[syscall.RawSockaddrNetlink]() serde.RegisterType[syscall.RawSockaddrUnix]() serde.RegisterType[syscall.Rlimit]() - serde.RegisterType[syscall.RouteMessage]() - serde.RegisterType[syscall.RtMetrics]() - serde.RegisterType[syscall.RtMsghdr]() + serde.RegisterType[syscall.RtAttr]() + serde.RegisterType[syscall.RtGenmsg]() + serde.RegisterType[syscall.RtMsg]() + serde.RegisterType[syscall.RtNexthop]() serde.RegisterType[syscall.Rusage]() serde.RegisterType[syscall.Signal]() - serde.RegisterType[syscall.SockaddrDatalink]() + serde.RegisterType[syscall.SockFilter]() + serde.RegisterType[syscall.SockFprog]() serde.RegisterType[syscall.SockaddrInet4]() serde.RegisterType[syscall.SockaddrInet6]() + serde.RegisterType[syscall.SockaddrLinklayer]() + serde.RegisterType[syscall.SockaddrNetlink]() serde.RegisterType[syscall.SockaddrUnix]() serde.RegisterType[syscall.SocketControlMessage]() serde.RegisterType[syscall.Stat_t]() serde.RegisterType[syscall.Statfs_t]() serde.RegisterType[syscall.SysProcAttr]() + serde.RegisterType[syscall.SysProcIDMap]() + serde.RegisterType[syscall.Sysinfo_t]() + serde.RegisterType[syscall.TCPInfo]() serde.RegisterType[syscall.Termios]() + serde.RegisterType[syscall.Time_t]() serde.RegisterType[syscall.Timespec]() serde.RegisterType[syscall.Timeval]() - serde.RegisterType[syscall.Timeval32]() + serde.RegisterType[syscall.Timex]() + serde.RegisterType[syscall.Tms]() + serde.RegisterType[syscall.Ucred]() + serde.RegisterType[syscall.Ustat_t]() + serde.RegisterType[syscall.Utimbuf]() + serde.RegisterType[syscall.Utsname]() serde.RegisterType[syscall.WaitStatus]() serde.RegisterType[time.Duration]() serde.RegisterType[time.Location]() diff --git a/compiler/unsupported.go b/compiler/unsupported.go index 739c25a..b7d2b97 100644 --- a/compiler/unsupported.go +++ b/compiler/unsupported.go @@ -45,16 +45,25 @@ func unsupported(decl *ast.FuncDecl, info *types.Info) (err error) { err = fmt.Errorf("not implemented: labels not attached to for/switch/select") } case *ast.ForStmt: - // Only very simple for loop post iteration statements - // are supported. + // Only simple post iteration statements are supported. + var exprs []ast.Expr switch p := n.Post.(type) { case nil: case *ast.IncDecStmt: - if _, ok := p.X.(*ast.Ident); !ok { - err = fmt.Errorf("not implemented: for post inc/dec %T", p.X) + exprs = append(exprs, p.X) + case *ast.AssignStmt: + if len(p.Lhs) != len(p.Rhs) { + err = fmt.Errorf("not implemented: for loop post iteration assignment with unbalanced sides") } + exprs = append(exprs, p.Lhs...) + exprs = append(exprs, p.Rhs...) default: - err = fmt.Errorf("not implemented: for post %T", p) + err = fmt.Errorf("not implemented: for loop post iteration statement %T", p) + } + for _, e := range exprs { + if countFunctionCalls(e, info) > 0 { + err = fmt.Errorf("not implemented: for loop post iteration statement with function call") + } } // Fully supported: