Skip to content

Commit

Permalink
Move gls code to internal package (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso authored Sep 26, 2023
2 parents c5b70a9 + e75c210 commit 2563fa8
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 25 deletions.
4 changes: 3 additions & 1 deletion coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package coroutine

import (
"errors"

"github.com/stealthrocket/coroutine/internal/gls"
)

// Coroutine instances expose APIs allowing the program to drive the execution
Expand Down Expand Up @@ -94,7 +96,7 @@ func Yield[R, S any](v R) S {
// The function panics when called on a stack where no active coroutine exists,
// or if the type parameters do not match those of the coroutine.
func LoadContext[R, S any]() *Context[R, S] {
switch c := loadContext(getg()).(type) {
switch c := gls.Context().Load().(type) {
case *Context[R, S]:
return c
case nil:
Expand Down
8 changes: 5 additions & 3 deletions coroutine_durable.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"slices"
"strconv"

"github.com/stealthrocket/coroutine/internal/gls"
"github.com/stealthrocket/coroutine/internal/serde"
)

Expand All @@ -27,11 +28,12 @@ func (c Coroutine[R, S]) Next() (hasNext bool) {
return false
}

g := getg()
storeContext(g, c.ctx)
g := gls.Context()

g.Store(c.ctx)

defer func() {
clearContext(g)
g.Clear()

switch err := recover(); err {
case nil:
Expand Down
12 changes: 8 additions & 4 deletions coroutine_volatile.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

package coroutine

import "runtime"
import (
"runtime"

"github.com/stealthrocket/coroutine/internal/gls"
)

// New creates a new coroutine which executes f as entry point.
func New[R, S any](f func()) Coroutine[R, S] {
Expand All @@ -13,13 +17,13 @@ func New[R, S any](f func()) Coroutine[R, S] {
}

go func() {
g := getg()
storeContext(g, c)
g := gls.Context()
g.Store(c)

defer func() {
c.done = true
close(c.next)
clearContext(g)
g.Clear()
}()

<-c.next
Expand Down
2 changes: 1 addition & 1 deletion getg.go → internal/gls/getg.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package coroutine
package gls

// getg is like the compiler intrisinc runtime.getg which retrieves the current
// goroutine object.
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
24 changes: 18 additions & 6 deletions gls.go → internal/gls/gls.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package coroutine
package gls

import "sync"

Expand All @@ -23,26 +23,38 @@ import "sync"
// simple memory loads.
var (
gmutex sync.RWMutex
gstate map[uintptr]any
gstate map[G]any
)

func loadContext(g uintptr) any {
// G is a reference to a goroutine, and provides a way
// to load, store and clear a goroutine local context.
type G uintptr

// Context retrieves the goroutine local storage for contexts.
func Context() G {
return G(getg())
}

// Load loads the goroutine local context.
func (g G) Load() any {
gmutex.RLock()
v := gstate[g]
gmutex.RUnlock()
return v
}

func storeContext(g uintptr, c any) {
// Store stores the goroutine local context.
func (g G) Store(c any) {
gmutex.Lock()
if gstate == nil {
gstate = make(map[uintptr]any)
gstate = make(map[G]any)
}
gstate[g] = c
gmutex.Unlock()
}

func clearContext(g uintptr) {
// Clear clears the goroutine local context.
func (g G) Clear() {
gmutex.Lock()
delete(gstate, g)
gmutex.Unlock()
Expand Down
17 changes: 7 additions & 10 deletions gls_test.go → internal/gls/gls_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package coroutine
package gls

import "testing"

Expand All @@ -7,15 +7,15 @@ func TestGLS(t *testing.T) {

f := func(n int) {
defer close(c)
storeContext(getg(), n)
Context().Store(n)

load := func() int {
v, _ := loadContext(getg()).(int)
v, _ := Context().Load().(int)
return v
}

c <- load()
clearContext(getg())
Context().Clear()
c <- load()
}

Expand Down Expand Up @@ -43,27 +43,24 @@ func BenchmarkGLS(b *testing.B) {

b.Run("loadContext", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
g := getg()
for pb.Next() {
_ = loadContext(g)
_ = Context().Load()
}
})
})

b.Run("storeContext", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
g := getg()
for pb.Next() {
storeContext(g, 42)
Context().Store(42)
}
})
})

b.Run("clearContext", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
g := getg()
for pb.Next() {
clearContext(g)
Context().Clear()
}
})
})
Expand Down

0 comments on commit 2563fa8

Please sign in to comment.