Skip to content

Commit

Permalink
refactor: create lib
Browse files Browse the repository at this point in the history
  • Loading branch information
neurosnap committed Aug 30, 2024
1 parent 38fc90b commit 6da6679
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 146 deletions.
150 changes: 4 additions & 146 deletions cmd/authorized_keys/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@ package main
import (
"context"
"fmt"
"io"
"log/slog"
"os"
"os/signal"
"strings"
"syscall"
"time"

"github.com/charmbracelet/ssh"
"github.com/charmbracelet/wish"
"github.com/google/uuid"
"github.com/picosh/pubsub"
)

func GetEnv(key string, defaultVal string) string {
Expand All @@ -23,160 +20,21 @@ func GetEnv(key string, defaultVal string) string {
return defaultVal
}

type Subscriber struct {
ID string
Name string
Session ssh.Session
Chan chan error
}

func (s *Subscriber) Wait() error {
err := <-s.Chan
return err
}

type Msg struct {
Name string
Reader io.Reader
}

type PubSub interface {
GetSubs() []*Subscriber
Sub(l *Subscriber) error
UnSub(l *Subscriber) error
Pub(msg *Msg) error
}

type PubSubMulticast struct {
logger *slog.Logger
subs []*Subscriber
}

func (b *PubSubMulticast) GetSubs() []*Subscriber {
b.logger.Info("getsubs")
return b.subs
}

func (b *PubSubMulticast) Sub(sub *Subscriber) error {
id := uuid.New()
sub.ID = id.String()
b.logger.Info("sub", "channel", sub.Name, "id", id)
b.subs = append(b.subs, sub)
return sub.Wait()
}

func (b *PubSubMulticast) UnSub(rm *Subscriber) error {
b.logger.Info("unsub", "channel", rm.Name, "id", rm.ID)
next := []*Subscriber{}
for _, sub := range b.subs {
if sub.ID != rm.ID {
next = append(next, sub)
}
}
b.subs = next
return nil
}

func (b *PubSubMulticast) Pub(msg *Msg) error {
log := b.logger.With("channel", msg.Name)
log.Info("pub")

matches := []*Subscriber{}
writers := []io.Writer{}
for _, sub := range b.subs {
if sub.Name == msg.Name {
matches = append(matches, sub)
writers = append(writers, sub.Session)
}
}

log.Info("copying data")
writer := io.MultiWriter(writers...)
_, err := io.Copy(writer, msg.Reader)
if err != nil {
log.Error("pub", "err", err)
}
for _, sub := range matches {
sub.Chan <- err
b.UnSub(sub)
}

return err
}

type Cfg struct {
Logger *slog.Logger
PubSub PubSub
}

func PubSubMiddleware(cfg *Cfg) wish.Middleware {
return func(next ssh.Handler) ssh.Handler {
return func(sesh ssh.Session) {
args := sesh.Command()
if len(args) < 2 {
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub) {channel}")
next(sesh)
return
}

cmd := strings.TrimSpace(args[0])
channel := args[1]
logger := cfg.Logger.With(
"cmd", cmd,
"channel", channel,
)

logger.Info("running cli")

if cmd == "help" {
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub) {channel}")
} else if cmd == "sub" {
sub := &Subscriber{
Name: channel,
Session: sesh,
Chan: make(chan error),
}
err := cfg.PubSub.Sub(sub)
if err != nil {
wish.Errorln(sesh, err)
}
/* defer func() {
err = cfg.PubSub.UnSub(listener)
if err != nil {
wish.Errorln(sesh, err)
}
}() */
} else if cmd == "pub" {
msg := &Msg{
Name: channel,
Reader: sesh,
}
err := cfg.PubSub.Pub(msg)
wish.Errorln(sesh, err)
} else {
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub) {channel}")
}

next(sesh)
}
}
}

func main() {
logger := slog.Default()
host := GetEnv("SSH_HOST", "0.0.0.0")
port := GetEnv("SSH_PORT", "2222")
keyPath := GetEnv("SSH_AUTHORIZED_KEYS", "./ssh_data/authorized_keys")
cfg := &Cfg{
cfg := &pubsub.Cfg{
Logger: logger,
PubSub: &PubSubMulticast{logger: logger},
PubSub: &pubsub.PubSubMulticast{Logger: logger},
}

s, err := wish.NewServer(
wish.WithAddress(fmt.Sprintf("%s:%s", host, port)),
wish.WithHostKeyPath("ssh_data/term_info_ed25519"),
wish.WithAuthorizedKeys(keyPath),
wish.WithMiddleware(PubSubMiddleware(cfg)),
wish.WithMiddleware(pubsub.PubSubMiddleware(cfg)),
)
if err != nil {
logger.Error(err.Error())
Expand Down
154 changes: 154 additions & 0 deletions pubsub.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package pubsub

import (
"io"
"log/slog"
"strings"

"github.com/charmbracelet/ssh"
"github.com/charmbracelet/wish"
"github.com/google/uuid"
)

type Subscriber struct {
ID string
Name string
Session ssh.Session
Chan chan error
}

func (s *Subscriber) Wait() error {
err := <-s.Chan
return err
}

type Msg struct {
Name string
Reader io.Reader
}

type PubSub interface {
GetSubs() []*Subscriber
Sub(l *Subscriber) error
UnSub(l *Subscriber) error
Pub(msg *Msg) error
}

type PubSubMulticast struct {
Logger *slog.Logger
subs []*Subscriber
}

func (b *PubSubMulticast) GetSubs() []*Subscriber {
b.Logger.Info("getsubs")
return b.subs
}

func (b *PubSubMulticast) Sub(sub *Subscriber) error {
id := uuid.New()
sub.ID = id.String()
b.Logger.Info("sub", "channel", sub.Name, "id", id)
b.subs = append(b.subs, sub)
return sub.Wait()
}

func (b *PubSubMulticast) UnSub(rm *Subscriber) error {
b.Logger.Info("unsub", "channel", rm.Name, "id", rm.ID)
next := []*Subscriber{}
for _, sub := range b.subs {
if sub.ID != rm.ID {
next = append(next, sub)
}
}
b.subs = next
return nil
}

func (b *PubSubMulticast) Pub(msg *Msg) error {
log := b.Logger.With("channel", msg.Name)
log.Info("pub")

matches := []*Subscriber{}
writers := []io.Writer{}
for _, sub := range b.subs {
if sub.Name == msg.Name {
matches = append(matches, sub)
writers = append(writers, sub.Session)
}
}

if len(matches) == 0 {
log.Info("no subs found")
}

log.Info("copying data")
writer := io.MultiWriter(writers...)
_, err := io.Copy(writer, msg.Reader)
if err != nil {
log.Error("pub", "err", err)
}
for _, sub := range matches {
sub.Chan <- err
b.UnSub(sub)
}

return err
}

type Cfg struct {
Logger *slog.Logger
PubSub PubSub
}

func PubSubMiddleware(cfg *Cfg) wish.Middleware {
return func(next ssh.Handler) ssh.Handler {
return func(sesh ssh.Session) {
args := sesh.Command()
if len(args) < 2 {
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub) {channel}")
next(sesh)
return
}

cmd := strings.TrimSpace(args[0])
channel := args[1]
logger := cfg.Logger.With(
"cmd", cmd,
"channel", channel,
)

logger.Info("running cli")

if cmd == "help" {
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub) {channel}")
} else if cmd == "sub" {
sub := &Subscriber{
Name: channel,
Session: sesh,
Chan: make(chan error),
}
err := cfg.PubSub.Sub(sub)
if err != nil {
wish.Errorln(sesh, err)
}
/* defer func() {
err = cfg.PubSub.UnSub(listener)
if err != nil {
wish.Errorln(sesh, err)
}
}() */
} else if cmd == "pub" {
msg := &Msg{
Name: channel,
Reader: sesh,
}
err := cfg.PubSub.Pub(msg)
wish.Errorln(sesh, err)
} else {
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub) {channel}")
}

next(sesh)
}
}
}

0 comments on commit 6da6679

Please sign in to comment.