diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..a347f3d --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Launch Package", + "type": "go", + "request": "launch", + "mode": "auto", + "program": "${workspaceFolder}/cmd/authorized_keys/main.go", + "cwd": "${workspaceFolder}" + } + ] +} \ No newline at end of file diff --git a/cmd/authorized_keys/main.go b/cmd/authorized_keys/main.go index a936e6b..827c795 100644 --- a/cmd/authorized_keys/main.go +++ b/cmd/authorized_keys/main.go @@ -6,12 +6,15 @@ import ( "log/slog" "os" "os/signal" + "runtime" "strings" "syscall" "time" + "github.com/antoniomika/syncmap" "github.com/charmbracelet/ssh" "github.com/charmbracelet/wish" + "github.com/google/uuid" "github.com/picosh/pubsub" ) @@ -32,7 +35,6 @@ func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware { return } - ctx := sesh.Context() cmd := strings.TrimSpace(args[0]) channel := args[1] logger := cfg.Logger.With( @@ -45,37 +47,37 @@ func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware { if cmd == "help" { wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub) {channel}") } else if cmd == "sub" { - sub := &pubsub.Subscriber{ - Name: channel, + sub := &pubsub.Sub{ + ID: uuid.NewString(), Writer: sesh, - Chan: make(chan error), + Done: make(chan struct{}), + Data: make(chan []byte), } + go func() { - <-ctx.Done() - err := cfg.PubSub.UnSub(sub) - if err != nil { - wish.Errorln(sesh, err) - } + <-sesh.Context().Done() + sub.Cleanup() }() - err := cfg.PubSub.Sub(sub) + + err := cfg.PubSub.Sub(channel, sub) if err != nil { - wish.Errorln(sesh, err) + logger.Error("error from sub", slog.Any("error", err), slog.String("sub", sub.ID)) } } else if cmd == "pub" { - msg := &pubsub.Msg{ - Name: channel, + pub := &pubsub.Pub{ + ID: uuid.NewString(), + Done: make(chan struct{}), Reader: sesh, } + go func() { - <-ctx.Done() - err := cfg.PubSub.UnPub(msg) - if err != nil { - wish.Errorln(sesh, err) - } + <-sesh.Context().Done() + pub.Cleanup() }() - err := cfg.PubSub.Pub(msg) + + err := cfg.PubSub.Pub(channel, pub) if err != nil { - wish.Errorln(sesh, err) + logger.Error("error from pub", slog.Any("error", err), slog.String("pub", pub.ID)) } } else { wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub) {channel}") @@ -94,12 +96,13 @@ func main() { cfg := &pubsub.Cfg{ Logger: logger, PubSub: &pubsub.PubSubMulticast{ - Logger: logger, - Chan: make(chan *pubsub.Subscriber), + Logger: logger, + Channels: syncmap.New[string, *pubsub.Channel](), }, } s, err := wish.NewServer( + ssh.NoPty(), wish.WithAddress(fmt.Sprintf("%s:%s", host, port)), wish.WithHostKeyPath("ssh_data/term_info_ed25519"), wish.WithAuthorizedKeys(keyPath), @@ -125,6 +128,26 @@ func main() { } }() + go func() { + for { + slog.Info("Debug Info", slog.Int("goroutines", runtime.NumGoroutine())) + select { + case <-time.After(5 * time.Second): + for _, channel := range cfg.PubSub.GetChannels("") { + slog.Info("channel online", slog.Any("channel", channel.Name)) + for _, pub := range cfg.PubSub.GetPubs(channel.Name) { + slog.Info("pub online", slog.Any("channel", channel.Name), slog.Any("pub", pub.ID)) + } + for _, sub := range cfg.PubSub.GetSubs(channel.Name) { + slog.Info("sub online", slog.Any("channel", channel.Name), slog.Any("sub", sub.ID)) + } + } + case <-done: + return + } + } + }() + <-done logger.Info("stopping SSH server") ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) diff --git a/go.mod b/go.mod index 1323b00..133bdcc 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( require ( github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect + github.com/antoniomika/syncmap v1.0.0 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/charmbracelet/bubbletea v0.27.0 // indirect github.com/charmbracelet/keygen v0.5.1 // indirect diff --git a/go.sum b/go.sum index 39dbe2d..312459c 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= +github.com/antoniomika/syncmap v1.0.0 h1:iFSfbQFQOvHZILFZF+hqWosO0no+W9+uF4y2VEyMKWU= +github.com/antoniomika/syncmap v1.0.0/go.mod h1:fK2829foEYnO4riNfyUn0SHQZt4ue3DStYjGU+sJj38= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/charmbracelet/bubbletea v0.27.0 h1:Mznj+vvYuYagD9Pn2mY7fuelGvP0HAXtZYGgRBCbHvU= diff --git a/multicast.go b/multicast.go index c438909..fe391ff 100644 --- a/multicast.go +++ b/multicast.go @@ -1,123 +1,192 @@ package pubsub import ( - "fmt" + "errors" "io" "log/slog" - "time" + "strings" - "github.com/google/uuid" + "github.com/antoniomika/syncmap" ) -/* -multicast: +type PubSubMulticast struct { + Logger *slog.Logger + Channels *syncmap.Map[string, *Channel] +} - every pub event will be sent to all subs on a channel +func (b *PubSubMulticast) Cleanup() { + toRemove := []string{} + b.Channels.Range(func(I string, J *Channel) bool { + count := 0 + J.Pubs.Range(func(K string, V *Pub) bool { + count++ + return true + }) + + J.Subs.Range(func(K string, V *Sub) bool { + count++ + return true + }) + + if count == 0 { + J.Cleanup() + toRemove = append(toRemove, I) + } -bidirectional blocking: + return true + }) - both pub and sub will wait for at least one - message on a channel before completing -*/ -type PubSubMulticast struct { - Logger *slog.Logger - subs []*Subscriber - Chan chan *Subscriber + for _, channel := range toRemove { + b.Channels.Delete(channel) + } } -func (b *PubSubMulticast) GetSubs() []*Subscriber { - b.Logger.Info("getsubs") - return b.subs +func (b *PubSubMulticast) GetChannels(channelPrefix string) []*Channel { + var chans []*Channel + b.Channels.Range(func(I string, J *Channel) bool { + if strings.HasPrefix(I, channelPrefix) { + chans = append(chans, J) + } + + return true + }) + return chans } -func (b *PubSubMulticast) Sub(sub *Subscriber) error { - id := uuid.New() - sub.ID = id.String() - b.Logger.Info("sub", "channel", sub.Name, "id", id) - b.subs = append(b.subs, sub) - select { - case b.Chan <- sub: - // message sent - default: - // message dropped - } +func (b *PubSubMulticast) GetChannel(channel string) *Channel { + channelData, _ := b.Channels.Load(channel) + return channelData +} + +func (b *PubSubMulticast) GetPubs(channel string) []*Pub { + var pubs []*Pub + b.Channels.Range(func(I string, J *Channel) bool { + found := channel == I + if found || channel == "*" { + J.Pubs.Range(func(K string, V *Pub) bool { + pubs = append(pubs, V) + return true + }) + } - return sub.Wait() + return !found + }) + return pubs } -func (b *PubSubMulticast) UnSub(rm *Subscriber) error { - b.Logger.Info("unsub", "channel", rm.Name, "id", rm.ID) - next := []*Subscriber{} - for _, sub := range b.subs { - if sub.ID != rm.ID { - next = append(next, sub) +func (b *PubSubMulticast) GetSubs(channel string) []*Sub { + var subs []*Sub + b.Channels.Range(func(I string, J *Channel) bool { + found := channel == I + if found || channel == "*" { + J.Subs.Range(func(K string, V *Sub) bool { + subs = append(subs, V) + return true + }) } - } - b.subs = next - return nil + + return !found + }) + return subs } -func (b *PubSubMulticast) PubMatcher(msg *Msg, sub *Subscriber) bool { - return msg.Name == sub.Name +func (b *PubSubMulticast) ensure(channel string) *Channel { + dataChannel, _ := b.Channels.LoadOrStore(channel, &Channel{ + Name: channel, + Done: make(chan struct{}), + Data: make(chan []byte), + Subs: syncmap.New[string, *Sub](), + Pubs: syncmap.New[string, *Pub](), + }) + dataChannel.Handle() + + return dataChannel } -func (b *PubSubMulticast) Pub(msg *Msg) error { - log := b.Logger.With("channel", msg.Name) - log.Info("pub") - - matches := []*Subscriber{} - writers := []io.Writer{} - for _, sub := range b.subs { - if b.PubMatcher(msg, sub) { - log.Info("found match", "sub", sub.ID) - matches = append(matches, sub) - writers = append(writers, sub.Writer) - } - } +func (b *PubSubMulticast) Sub(channel string, sub *Sub) error { + dataChannel := b.ensure(channel) + dataChannel.Subs.Store(sub.ID, sub) + defer func() { + sub.Cleanup() + dataChannel.Subs.Delete(sub.ID) + b.Cleanup() + }() + +mainLoop: + for { + select { + case <-sub.Done: + break mainLoop + case <-dataChannel.Done: + break mainLoop + case data, ok := <-sub.Data: + _, err := sub.Writer.Write(data) + if err != nil { + slog.Error("error writing to sub", slog.Any("sub", sub.ID), slog.Any("channel", channel), slog.Any("error", err)) + return err + } - if len(matches) == 0 { - var sub *Subscriber - for { - log.Info("no subs found, waiting for sub") - sub = <-b.Chan - if b.PubMatcher(msg, sub) { - // empty subscriber is a signal to force a pub to stop - // waiting for a sub - if sub.Writer == nil { - return fmt.Errorf("pub closed") - } - return b.Pub(msg) + if !ok { + break mainLoop } } } - log.Info("copying data") - del := time.Now() - msg.SentAt = &del + return nil +} - writer := io.MultiWriter(writers...) - _, err := io.Copy(writer, msg.Reader) - if err != nil { - log.Error("pub", "err", err) - } - for _, sub := range matches { - sub.Chan <- err - log.Info("sub unsub") - err = b.UnSub(sub) - if err != nil { - log.Error("unsub err", "err", err) +func (b *PubSubMulticast) Pub(channel string, pub *Pub) error { + dataChannel := b.ensure(channel) + dataChannel.Pubs.Store(pub.ID, pub) + defer func() { + pub.Cleanup() + dataChannel.Pubs.Delete(pub.ID) + + count := 0 + dataChannel.Pubs.Range(func(I string, J *Pub) bool { + count++ + return true + }) + + if count == 0 { + dataChannel.onceData.Do(func() { + close(dataChannel.Data) + }) } - } - return err -} + b.Cleanup() + }() + +mainLoop: + for { + select { + case <-pub.Done: + break mainLoop + case <-dataChannel.Done: + break mainLoop + default: + data := make([]byte, 32*1024) + n, err := pub.Reader.Read(data) + data = data[:n] + + select { + case dataChannel.Data <- data: + case <-pub.Done: + break mainLoop + case <-dataChannel.Done: + break mainLoop + } + + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } -func (b *PubSubMulticast) UnPub(msg *Msg) error { - b.Logger.Info("unpub", "channel", msg.Name) - // if the message hasn't been delivered then send a cancel sub to - // the multicast channel - if msg.SentAt == nil { - b.Chan <- &Subscriber{Name: msg.Name} + slog.Error("error reading from pub", slog.Any("pub", pub.ID), slog.Any("channel", channel), slog.Any("error", err)) + return err + } + } } + return nil } diff --git a/multicast_test.go b/multicast_test.go index e407c1d..d2e4109 100644 --- a/multicast_test.go +++ b/multicast_test.go @@ -5,44 +5,76 @@ import ( "fmt" "log/slog" "strings" + "sync" "testing" + + "github.com/antoniomika/syncmap" ) +type Buffer struct { + b bytes.Buffer + m sync.Mutex +} + +func (b *Buffer) Read(p []byte) (n int, err error) { + b.m.Lock() + defer b.m.Unlock() + return b.b.Read(p) +} +func (b *Buffer) Write(p []byte) (n int, err error) { + b.m.Lock() + defer b.m.Unlock() + return b.b.Write(p) +} +func (b *Buffer) String() string { + b.m.Lock() + defer b.m.Unlock() + return b.b.String() +} + func TestMulticastSubBlock(t *testing.T) { orderActual := "" orderExpected := "sub-pub-" - actual := new(bytes.Buffer) + actual := new(Buffer) expected := "some test data" name := "test-channel" syncer := make(chan int) cast := &PubSubMulticast{ - Logger: slog.Default(), - Chan: make(chan *Subscriber), + Logger: slog.Default(), + Channels: syncmap.New[string, *Channel](), } + var wg sync.WaitGroup + wg.Add(2) + go func() { - sub := &Subscriber{ + sub := &Sub{ ID: "1", - Name: name, - Chan: make(chan error), Writer: actual, + Done: make(chan struct{}), + Data: make(chan []byte), } orderActual += "sub-" syncer <- 0 - fmt.Println(cast.Sub(sub)) + fmt.Println(cast.Sub(name, sub)) + wg.Done() }() <-syncer go func() { - msg := &Msg{Name: name, Reader: strings.NewReader(expected)} + pub := &Pub{ + ID: "1", + Done: make(chan struct{}), + Reader: strings.NewReader(expected), + } orderActual += "pub-" - syncer <- 0 - fmt.Println(cast.Pub(msg)) + fmt.Println(cast.Pub(name, pub)) + wg.Done() }() - <-syncer + wg.Wait() if orderActual != orderExpected { t.Fatalf("\norderActual:(%s)\norderExpected:(%s)", orderActual, orderExpected) @@ -55,38 +87,46 @@ func TestMulticastSubBlock(t *testing.T) { func TestMulticastPubBlock(t *testing.T) { orderActual := "" orderExpected := "pub-sub-" - actual := new(bytes.Buffer) + actual := new(Buffer) expected := "some test data" name := "test-channel" syncer := make(chan int) cast := &PubSubMulticast{ - Logger: slog.Default(), - Chan: make(chan *Subscriber), + Logger: slog.Default(), + Channels: syncmap.New[string, *Channel](), } + var wg sync.WaitGroup + wg.Add(2) + go func() { - msg := &Msg{Name: name, Reader: strings.NewReader(expected)} + pub := &Pub{ + ID: "1", + Done: make(chan struct{}), + Reader: strings.NewReader(expected), + } orderActual += "pub-" syncer <- 0 - fmt.Println(cast.Pub(msg)) + fmt.Println(cast.Pub(name, pub)) + wg.Done() }() <-syncer go func() { - sub := &Subscriber{ + sub := &Sub{ ID: "1", - Name: name, - Chan: make(chan error), Writer: actual, + Done: make(chan struct{}), + Data: make(chan []byte), } orderActual += "sub-" - syncer <- 0 - fmt.Println(cast.Sub(sub)) + wg.Done() + fmt.Println(cast.Sub(name, sub)) }() - <-syncer + wg.Wait() if orderActual != orderExpected { t.Fatalf("\norderActual:(%s)\norderExpected:(%s)", orderActual, orderExpected) @@ -99,53 +139,62 @@ func TestMulticastPubBlock(t *testing.T) { func TestMulticastMultSubs(t *testing.T) { orderActual := "" orderExpected := "sub-sub-pub-" - actual := new(bytes.Buffer) - actualOther := new(bytes.Buffer) + actual := new(Buffer) + actualOther := new(Buffer) expected := "some test data" name := "test-channel" syncer := make(chan int) cast := &PubSubMulticast{ - Logger: slog.Default(), - Chan: make(chan *Subscriber), + Logger: slog.Default(), + Channels: syncmap.New[string, *Channel](), } + var wg sync.WaitGroup + wg.Add(3) + go func() { - sub := &Subscriber{ + sub := &Sub{ ID: "1", - Name: name, - Chan: make(chan error), Writer: actual, + Done: make(chan struct{}), + Data: make(chan []byte), } orderActual += "sub-" syncer <- 0 - fmt.Println(cast.Sub(sub)) + fmt.Println(cast.Sub(name, sub)) + wg.Done() }() <-syncer go func() { - sub := &Subscriber{ + sub := &Sub{ ID: "2", - Name: name, - Chan: make(chan error), Writer: actualOther, + Done: make(chan struct{}), + Data: make(chan []byte), } orderActual += "sub-" syncer <- 0 - fmt.Println(cast.Sub(sub)) + fmt.Println(cast.Sub(name, sub)) + wg.Done() }() <-syncer go func() { - msg := &Msg{Name: name, Reader: strings.NewReader(expected)} + pub := &Pub{ + ID: "1", + Done: make(chan struct{}), + Reader: strings.NewReader(expected), + } orderActual += "pub-" - syncer <- 0 - fmt.Println(cast.Pub(msg)) + fmt.Println(cast.Pub(name, pub)) + wg.Done() }() - <-syncer + wg.Wait() if orderActual != orderExpected { t.Fatalf("\norderActual:(%s)\norderExpected:(%s)", orderActual, orderExpected) diff --git a/pubsub.go b/pubsub.go index a6c200d..6cac450 100644 --- a/pubsub.go +++ b/pubsub.go @@ -3,35 +3,132 @@ package pubsub import ( "io" "log/slog" + "sync" "time" + + "github.com/antoniomika/syncmap" ) -type Subscriber struct { - ID string - Name string - Chan chan error - Writer io.Writer +type Channel struct { + Name string + Done chan struct{} + Data chan []byte + Subs *syncmap.Map[string, *Sub] + Pubs *syncmap.Map[string, *Pub] + once sync.Once + cleanupOnce sync.Once + onceData sync.Once +} + +func (c *Channel) Cleanup() { + c.cleanupOnce.Do(func() { + close(c.Done) + c.onceData.Do(func() { + close(c.Data) + }) + }) } -func (s *Subscriber) Wait() error { - err := <-s.Chan - return err +func (c *Channel) Handle() { + c.once.Do(func() { + go func() { + defer func() { + c.Subs.Range(func(I string, J *Sub) bool { + J.Cleanup() + return true + }) + + c.Pubs.Range(func(I string, J *Pub) bool { + J.Cleanup() + return true + }) + }() + + mainLoop: + for { + select { + case <-c.Done: + return + case data, ok := <-c.Data: + count := 0 + for count == 0 { + c.Subs.Range(func(I string, J *Sub) bool { + count++ + return true + }) + if count == 0 { + select { + case <-time.After(1 * time.Millisecond): + case <-c.Done: + break mainLoop + } + } + } + + c.Subs.Range(func(I string, J *Sub) bool { + if !ok { + J.onceData.Do(func() { + close(J.Data) + }) + return true + } + + select { + case J.Data <- data: + return true + case <-J.Done: + return true + case <-c.Done: + return true + case <-time.After(1 * time.Second): + slog.Error("timeout writing to sub", slog.Any("sub", I), slog.Any("channel", c.Name)) + return true + } + }) + } + } + }() + }) } -type Msg struct { - Name string +type Sub struct { + ID string + Done chan struct{} + Data chan []byte + Writer io.Writer + once sync.Once + onceData sync.Once +} + +func (sub *Sub) Cleanup() { + sub.once.Do(func() { + close(sub.Done) + sub.onceData.Do(func() { + close(sub.Data) + }) + }) +} + +type Pub struct { + ID string + Done chan struct{} Reader io.Reader - SentAt *time.Time + once sync.Once +} + +func (pub *Pub) Cleanup() { + pub.once.Do(func() { + close(pub.Done) + }) } type PubSub interface { - GetSubs() []*Subscriber - Sub(sub *Subscriber) error - UnSub(sub *Subscriber) error - Pub(msg *Msg) error - UnPub(msg *Msg) error - // return true if message should be sent to this subscriber - PubMatcher(msg *Msg, sub *Subscriber) bool + GetSubs(channel string) []*Sub + GetPubs(channel string) []*Pub + GetChannels(channelPrefix string) []*Channel + GetChannel(channel string) *Channel + Sub(channel string, sub *Sub) error + Pub(channel string, pub *Pub) error } type Cfg struct {