Skip to content

Commit

Permalink
chore: update based on pr feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
denopink committed Mar 28, 2024
1 parent 2dd2a46 commit 7303ee8
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 221 deletions.
69 changes: 54 additions & 15 deletions cmd/xmidt-agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,29 +30,29 @@ type Config struct {
}

type Websocket struct {
// (optional) EnableDefaults determines whether or not to replace the zero values of the config struct `Websocket` with defaults .
EnableDefaults bool
// RegistrationAPI is the device registration url path
RegistrationAPI string
// Disable determines whether or not to disable xmidt-agent's websocket
Disable bool
// URLPath is the device registration url path
URLPath string
// AdditionalHeaders are any additional headers for the WS connection.
AdditionalHeaders http.Header
// (optional) FetchURLTimeout is the timeout for the fetching the WS url. If this is not set, the default is 30 seconds.
// FetchURLTimeout is the timeout for fetching the WS URL. If this is not set, the default is 30 seconds.
FetchURLTimeout time.Duration
// (optional) PingInterval is the ping interval allowed for the WS connection.
// PingInterval is the ping interval allowed for the WS connection.
PingInterval time.Duration
// (optional) PingTimeout is the ping timeout for the WS connection.
// PingTimeout is the ping timeout for the WS connection.
PingTimeout time.Duration
// (optional) ConnectTimeout is the connect timeout for the WS connection.
// ConnectTimeout is the connect timeout for the WS connection.
ConnectTimeout time.Duration
// (optional) KeepAliveInterval is the keep alive interval for the WS connection.
// KeepAliveInterval is the keep alive interval for the WS connection.
KeepAliveInterval time.Duration
// (optional) IdleConnTimeout is the idle connection timeout for the WS connection.
// IdleConnTimeout is the idle connection timeout for the WS connection.
IdleConnTimeout time.Duration
// (optional) TLSHandshakeTimeout is the TLS handshake timeout for the WS connection.
// TLSHandshakeTimeout is the TLS handshake timeout for the WS connection.
TLSHandshakeTimeout time.Duration
// (optional) ExpectContinueTimeout is the expect continue timeout for the WS connection.
// ExpectContinueTimeout is the expect continue timeout for the WS connection.
ExpectContinueTimeout time.Duration
// (optional) MaxMessageBytes is the largest allowable message to send or receive.
// MaxMessageBytes is the largest allowable message to send or receive.
MaxMessageBytes int64
// (optional) DisableV4 determines whether or not to disable IPv4 for the WS connection.
// If this is not set, the default is false (IPv4 is enabled).
Expand All @@ -62,9 +62,9 @@ type Websocket struct {
// If this is not set, the default is false (IPv6 is enabled).
// Either V4 or V6 can be disabled, but not both.
DisableV6 bool
// (optional) RetryPolicy sets the retry policy factory used for delaying between retry attempts for reconnection.
// RetryPolicy sets the retry policy factory used for delaying between retry attempts for reconnection.
RetryPolicy retry.Config
// (optional) Once sets whether or not to only attempt to connect once.
// Once sets whether or not to only attempt to connect once.
Once bool
}

Expand Down Expand Up @@ -251,4 +251,43 @@ var defaultConfig = Config{
},
},
},
Websocket: Websocket{
URLPath: "api/v2/device",
FetchURLTimeout: 30 * time.Second,
PingInterval: 30 * time.Second,
PingTimeout: 90 * time.Second,
ConnectTimeout: 30 * time.Second,
KeepAliveInterval: 30 * time.Second,
IdleConnTimeout: 10 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
MaxMessageBytes: 256 * 1024,
/*
This retry policy gives us a very good approximation of the prior
policy. The important things about this policy are:
1. The backoff increases up to the max.
2. There is jitter that spreads the load so windows do not overlap.
iteration | parodus | this implementation
----------+-----------+----------------
0 | 0-1s | 0.666 - 1.333
1 | 1s-3s | 1.333 - 2.666
2 | 3s-7s | 2.666 - 5.333
3 | 7s-15s | 5.333 - 10.666
4 | 15s-31s | 10.666 - 21.333
5 | 31s-63s | 21.333 - 42.666
6 | 63s-127s | 42.666 - 85.333
7 | 127s-255s | 85.333 - 170.666
8 | 255s-511s | 170.666 - 341.333
9 | 255s-511s | 341.333
n | 255s-511s | 341.333
*/
RetryPolicy: retry.Config{
Interval: time.Second,
Multiplier: 2.0,
Jitter: 1.0 / 3.0,
MaxInterval: 341*time.Second + 333*time.Millisecond,
},
},
}
2 changes: 0 additions & 2 deletions cmd/xmidt-agent/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ func provideCredentials(in credsIn) (*credentials.Credentials, error) {
logger := in.Logger.Named("credentials")

opts := []credentials.Option{
// `Required` allows the xmidt-agent to send disconnect events for auth related errors
credentials.Required(),
credentials.URL(in.Creds.URL),
credentials.HTTPClient(client),
credentials.MacAddress(in.ID.DeviceID),
Expand Down
7 changes: 4 additions & 3 deletions cmd/xmidt-agent/instructions.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ type instructionsIn struct {
type instructionsOut struct {
fx.Out
JWTXT *jwtxt.Instructions
BaseURL string `name:"websocket.baseurl"`
DeviceID wrp.DeviceID `name:"websocket.DeviceID"`
DeviceID wrp.DeviceID
}

func provideInstructions(in instructionsIn) (instructionsOut, error) {
Expand Down Expand Up @@ -93,5 +92,7 @@ func provideInstructions(in instructionsIn) (instructionsOut, error) {

jwtxt, err := jwtxt.New(opts...)

return instructionsOut{JWTXT: jwtxt, BaseURL: in.Service.URL, DeviceID: in.ID.DeviceID}, err
return instructionsOut{
JWTXT: jwtxt,
DeviceID: in.ID.DeviceID}, err
}
18 changes: 16 additions & 2 deletions cmd/xmidt-agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,24 +210,33 @@ func provideLogger(in LoggerIn) (*zap.Logger, error) {
}

func onStart(cred *credentials.Credentials, ws *websocket.Websocket, logger *zap.Logger) func(context.Context) error {
logger = logger.Named("on_start")

return func(ctx context.Context) error {
defer func() {
if r := recover(); nil != r {
logger.Error("stacktrace from panic", zap.String("stacktrace", string(debug.Stack())), zap.Any("panic", r))
}
}()

deadline, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond))
if ws == nil {
logger.Debug("websocket disabled")
return nil
}

ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
// blocks until an attempt to fetch the credentials has been made or the context is canceled
cred.WaitUntilFetched(deadline)
cred.WaitUntilFetched(ctx)
ws.Start()

return nil
}
}

func onStop(ws *websocket.Websocket, shutdowner fx.Shutdowner, cancelList []event.CancelFunc, logger *zap.Logger) func(context.Context) error {
logger = logger.Named("on_stop")

return func(_ context.Context) error {
defer func() {
if r := recover(); nil != r {
Expand All @@ -239,6 +248,11 @@ func onStop(ws *websocket.Websocket, shutdowner fx.Shutdowner, cancelList []even
}
}()

if ws == nil {
logger.Debug("websocket disabled")
return nil
}

ws.Stop()
for _, c := range cancelList {
c()
Expand Down
124 changes: 14 additions & 110 deletions cmd/xmidt-agent/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ import (
"errors"
"fmt"
"net/url"
"reflect"
"time"

"github.com/xmidt-org/retry"
"github.com/xmidt-org/wrp-go/v3"
"github.com/xmidt-org/xmidt-agent/internal/credentials"
"github.com/xmidt-org/xmidt-agent/internal/jwtxt"
Expand All @@ -18,29 +16,14 @@ import (
"go.uber.org/zap"
)

const (
registrationAPIDefault = "api/v2/device"
expectContinueTimeoutDefault = 1 * time.Second
idleConnTimeoutDefault = 10 * time.Second
tlsHandshakeTimeoutDefault
fetchUrlTimeoutDefault = 30 * time.Second
pingIntervalDefault
connectionTimeoutDefault
keepAliveIntervalDefault
pingTimeoutDefault = 90 * time.Second
maxMessageBytesDefault = 256 * 1024
)

var (
// ErrWebsocketConfig is a sentinel error value whose message is
// "websocket configuration error".
ErrWebsocketConfig = errors.New("websocket configuration error")
)

type wsIn struct {
fx.In
// Note, BaseURL is pulled from the XmidtService configuration
BaseURL string `name:"websocket.baseurl"`
// Note, DeviceID is pulled from the Identity configuration
DeviceID wrp.DeviceID `name:"websocket.DeviceID"`
DeviceID wrp.DeviceID
Logger *zap.Logger
CLI *CLI
JWTXT *jwtxt.Instructions
Expand All @@ -55,108 +38,29 @@ type wsOut struct {
}

func provideWS(in wsIn) (wsOut, error) {
enableDefaults := in.Websocket.EnableDefaults
registrationAPI := in.Websocket.RegistrationAPI
if registrationAPI == "" && enableDefaults {
registrationAPI = registrationAPIDefault
if in.Websocket.Disable {
return wsOut{}, nil
}

opts := []websocket.Option{
websocket.DeviceID(in.DeviceID),
websocket.FetchURL(fetchURL(registrationAPI, in.JWTXT.Endpoint)),
websocket.FetchURLTimeout(in.Websocket.FetchURLTimeout),
websocket.FetchURL(fetchURL(in.Websocket.URLPath, in.JWTXT.Endpoint)),
websocket.PingInterval(in.Websocket.PingInterval),
websocket.PingTimeout(in.Websocket.PingTimeout),
websocket.ConnectTimeout(in.Websocket.ConnectTimeout),
websocket.KeepAliveInterval(in.Websocket.KeepAliveInterval),
websocket.IdleConnTimeout(in.Websocket.IdleConnTimeout),
websocket.TLSHandshakeTimeout(in.Websocket.TLSHandshakeTimeout),
websocket.ExpectContinueTimeout(in.Websocket.ExpectContinueTimeout),
websocket.MaxMessageBytes(in.Websocket.MaxMessageBytes),
websocket.CredentialsDecorator(in.Cred.Decorate),
websocket.AdditionalHeaders(in.Websocket.AdditionalHeaders),
websocket.NowFunc(time.Now),
websocket.WithIPv6(!in.Websocket.DisableV6),
websocket.WithIPv4(!in.Websocket.DisableV4),
websocket.Once(in.Websocket.Once),
}

if reflect.ValueOf(in.Websocket.RetryPolicy).IsZero() && enableDefaults {
opts = append(opts,
/*
This retry policy gives us a very good approximation of the prior
policy. The important things about this policy are:
1. The backoff increases up to the max.
2. There is jitter that spreads the load so windows do not overlap.
iteration | parodus | this implementation
----------+-----------+----------------
0 | 0-1s | 0.666 - 1.333
1 | 1s-3s | 1.333 - 2.666
2 | 3s-7s | 2.666 - 5.333
3 | 7s-15s | 5.333 - 10.666
4 | 15s-31s | 10.666 - 21.333
5 | 31s-63s | 21.333 - 42.666
6 | 63s-127s | 42.666 - 85.333
7 | 127s-255s | 85.333 - 170.666
8 | 255s-511s | 170.666 - 341.333
9 | 255s-511s | 341.333
n | 255s-511s | 341.333
*/
websocket.RetryPolicy(&retry.Config{
Interval: time.Second,
Multiplier: 2.0,
Jitter: 1.0 / 3.0,
MaxInterval: 341*time.Second + 333*time.Millisecond,
}))
} else {
opts = append(opts, websocket.RetryPolicy(in.Websocket.RetryPolicy))
}

if in.Websocket.FetchURLTimeout == 0 && enableDefaults {
opts = append(opts, websocket.FetchURLTimeout(fetchUrlTimeoutDefault))
} else {
opts = append(opts, websocket.FetchURLTimeout(in.Websocket.FetchURLTimeout))
}

if in.Websocket.PingInterval == 0 && enableDefaults {
opts = append(opts, websocket.PingInterval(pingIntervalDefault))
} else {
opts = append(opts, websocket.PingInterval(in.Websocket.PingInterval))
}

if in.Websocket.PingTimeout == 0 && enableDefaults {
opts = append(opts, websocket.PingTimeout(pingTimeoutDefault))
} else {
opts = append(opts, websocket.PingTimeout(in.Websocket.PingTimeout))
}

if in.Websocket.ConnectTimeout == 0 && enableDefaults {
opts = append(opts, websocket.ConnectTimeout(connectionTimeoutDefault))
} else {
opts = append(opts, websocket.ConnectTimeout(in.Websocket.ConnectTimeout))
}

if in.Websocket.KeepAliveInterval == 0 && enableDefaults {
opts = append(opts, websocket.KeepAliveInterval(keepAliveIntervalDefault))
} else {
opts = append(opts, websocket.KeepAliveInterval(in.Websocket.KeepAliveInterval))
}

if in.Websocket.IdleConnTimeout == 0 && enableDefaults {
opts = append(opts, websocket.IdleConnTimeout(idleConnTimeoutDefault))
} else {
opts = append(opts, websocket.IdleConnTimeout(in.Websocket.IdleConnTimeout))
}

if in.Websocket.TLSHandshakeTimeout == 0 && enableDefaults {
opts = append(opts, websocket.TLSHandshakeTimeout(tlsHandshakeTimeoutDefault))
} else {
opts = append(opts, websocket.TLSHandshakeTimeout(in.Websocket.TLSHandshakeTimeout))
}

if in.Websocket.ExpectContinueTimeout == 0 && enableDefaults {
opts = append(opts, websocket.ExpectContinueTimeout(expectContinueTimeoutDefault))
} else {
opts = append(opts, websocket.ExpectContinueTimeout(in.Websocket.ExpectContinueTimeout))
}

if in.Websocket.MaxMessageBytes == 0 && enableDefaults {
opts = append(opts, websocket.MaxMessageBytes(maxMessageBytesDefault))
} else {
opts = append(opts, websocket.MaxMessageBytes(in.Websocket.MaxMessageBytes))
websocket.RetryPolicy(in.Websocket.RetryPolicy),
}

var (
Expand Down
3 changes: 1 addition & 2 deletions internal/credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"github.com/xmidt-org/wrp-go/v3"
"github.com/xmidt-org/xmidt-agent/internal/credentials/event"
"github.com/xmidt-org/xmidt-agent/internal/fs"
"github.com/xmidt-org/xmidt-agent/internal/reason"
)

var (
Expand Down Expand Up @@ -281,7 +280,7 @@ func (c *Credentials) fetch(ctx context.Context) (*xmidtInfo, time.Duration, err
resp, err := c.client.Do(req)
fe.Duration = time.Since(fe.At)
if err != nil {
fe.Err = errors.Join(fmt.Errorf("%s: %s", err, reason.GetDoErrReason(err)), ErrFetchFailed)
fe.Err = errors.Join(err, ErrFetchFailed)
return nil, 0, c.dispatch(fe)
}
defer resp.Body.Close()
Expand Down
Loading

0 comments on commit 7303ee8

Please sign in to comment.