diff --git a/internal/websocket/ws.go b/internal/websocket/ws.go index 1ad5dc3..de4af8d 100644 --- a/internal/websocket/ws.go +++ b/internal/websocket/ws.go @@ -228,7 +228,6 @@ func (ws *Websocket) run(ctx context.Context) { mode := ws.nextMode(ipv4) policy := ws.retryPolicyFactory.NewPolicy(ctx) - inactivityTimeout := time.After(ws.inactivityTimeout) for { var next time.Duration @@ -258,6 +257,7 @@ func (ws *Websocket) run(ctx context.Context) { // Store the connection so writing can take place. ws.m.Lock() ws.conn = conn + activity := make(chan struct{}) ws.conn.SetPingListener((func(ctx context.Context, b []byte) { if ctx.Err() != nil { return @@ -270,7 +270,9 @@ func (ws *Websocket) run(ctx context.Context) { }) }) - inactivityTimeout = time.After(ws.inactivityTimeout) + if len(activity) == 0 { + activity <- struct{}{} + } })) ws.conn.SetPongListener(func(ctx context.Context, b []byte) { if ctx.Err() != nil { @@ -289,22 +291,32 @@ func (ws *Websocket) run(ctx context.Context) { // Read loop for { var msg wrp.Message - ctx, cancel := context.WithTimeout(ctx, ws.inactivityTimeout) - typ, reader, err := ws.conn.Reader(ctx) - if errors.Is(err, context.DeadlineExceeded) { - select { - case <-inactivityTimeout: - // inactivityTimeout occurred, continue with ws.read()'s error handling (connection will be closed). - default: - // Ping was received during ws.conn.Reader(), i.e.: inactivityTimeout was reset. - // Reset inactivityTimeout again for the next ws.conn.Reader(). - inactivityTimeout = time.After(ws.inactivityTimeout) - cancel() - continue + ctx, cancel := context.WithCancelCause(ctx) + + // Monitor for activity. + go func() { + inactivityTimeout := time.After(ws.inactivityTimeout) + loop1: + for { + select { + case <-ctx.Done(): + break loop1 + case <-activity: + inactivityTimeout = time.After(ws.inactivityTimeout) + case <-inactivityTimeout: + // inactivityTimeout occurred, cancel the context. + cancel(context.DeadlineExceeded) + break loop1 + } } - } else if errors.Is(err, context.Canceled) { - // Parent context has been canceled. - cancel() + }() + + typ, reader, err := ws.conn.Reader(ctx) + ctxErr := context.Cause(ctx) + err = errors.Join(err, ctxErr) + // If ctxErr is context.Canceled then the parent context has been canceled. + if errors.Is(ctxErr, context.Canceled) { + cancel(nil) break } @@ -318,7 +330,7 @@ func (ws *Websocket) run(ctx context.Context) { } // Cancel ws.conn.Reader()'s context after wrp decoding. - cancel() + cancel(nil) if err != nil { ws.m.Lock() ws.conn = nil