diff --git a/sse.go b/sse.go index 8b0178a..a842ce3 100644 --- a/sse.go +++ b/sse.go @@ -9,6 +9,7 @@ package sse import ( "bytes" + "context" "encoding/json" "net/http" "strconv" @@ -27,22 +28,36 @@ type Streamer struct { connecting chan client disconnecting chan client bufSize uint + ctx context.Context } // New returns a new initialized SSE Streamer -func New() *Streamer { +func New(options ...func(*Streamer)) *Streamer { s := &Streamer{ event: make(chan []byte, 1), clients: make(map[client]bool), connecting: make(chan client), disconnecting: make(chan client), bufSize: 2, + ctx: context.Background(), + } + + for _, apply := range options { + apply(s) } s.run() return s } +// WithContext will use the provided context for the streamer. The streamer and +// any active connections will immediately close and return upon context.Done(). +func WithContext(ctx context.Context) func(s *Streamer) { + return func(s *Streamer) { + s.ctx = ctx + } +} + // run starts a goroutine to handle client connects and broadcast events. func (s *Streamer) run() { go func() { @@ -64,6 +79,12 @@ func (s *Streamer) run() { //} cl <- event } + + case <-s.ctx.Done(): + close(s.event) + close(s.connecting) + close(s.disconnecting) + return } } }() @@ -258,6 +279,9 @@ func (s *Streamer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Write events w.Write(event) // TODO: error handling fl.Flush() + + case <-s.ctx.Done(): + return } } } diff --git a/sse_test.go b/sse_test.go index 0efcc49..3fe1b52 100644 --- a/sse_test.go +++ b/sse_test.go @@ -283,3 +283,17 @@ func TestJSONErr(t *testing.T) { t.Fatal("wrong body, got:\n", w.written, "\nexpected:\n", expected) } } + +func TestWithContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + streamer := New(WithContext(ctx)) + w := NewMockResponseWriteFlusher() + + go func() { + time.Sleep(500 * time.Millisecond) + cancel() + }() + + streamer.ServeHTTP(w, NewMockRequestWithTimeout(5000*time.Millisecond)) +}