diff --git a/.gitignore b/.gitignore index e5df055..928e16e 100644 --- a/.gitignore +++ b/.gitignore @@ -23,5 +23,6 @@ go.work xmidt-agent internal/jwtxt/cmd/example/* +internal/credentials/cmd/example/* *.dot diff --git a/go.mod b/go.mod index 1a05bd6..0a6fff1 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/alecthomas/kong v0.8.0 github.com/foxcpp/go-mockdns v1.0.0 github.com/golang-jwt/jwt/v5 v5.0.1-0.20230913133926-0cb4fa15e31b + github.com/google/uuid v1.3.1 github.com/goschtalt/goschtalt v0.22.1 github.com/goschtalt/yaml-decoder v0.0.1 github.com/goschtalt/yaml-encoder v0.0.3 @@ -20,7 +21,6 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/google/uuid v1.3.1 // indirect github.com/goschtalt/approx v1.0.0 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/miekg/dns v1.1.56 // indirect diff --git a/internal/credentials/cmd/example/main.go b/internal/credentials/cmd/example/main.go new file mode 100644 index 0000000..ba4bc53 --- /dev/null +++ b/internal/credentials/cmd/example/main.go @@ -0,0 +1,120 @@ +// SPDX-FileCopyrightText: 2023 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "os" + "time" + + "github.com/alecthomas/kong" + "github.com/xmidt-org/wrp-go/v3" + cred "github.com/xmidt-org/xmidt-agent/internal/credentials" + "github.com/xmidt-org/xmidt-agent/internal/credentials/event" +) + +type CLI struct { + URL string `long:"url" help:"URL of the credential service." required:"true"` + ID string `long:"id" help:"Device ID." default:"mac:112233445566"` + Private string `long:"private" help:"mTLS private key to use."` + Public string `long:"public" help:"mTLS public key to use."` + CA string `long:"ca" help:"mTLS CA to use."` + Timeout time.Duration `long:"timeout" help:"HTTP client timeout." default:"5s"` + RedirectMax int `long:"redirect-max" help:"Maximum number of redirects to follow." default:"10"` +} + +func main() { + var cli CLI + _ = kong.Parse(&cli, + kong.Name("example"), + kong.Description("Example of using the credentials package."), + kong.UsageOnError(), + ) + + client := http.DefaultClient + + if cli.Private != "" || cli.Public != "" || cli.CA != "" { + if cli.Private == "" || cli.Public == "" || cli.CA == "" { + panic("--private, --public and --ca must be specified together") + } + + cert, err := tls.LoadX509KeyPair(cli.Public, cli.Private) + if err != nil { + panic(err) + } + + caCert, err := os.ReadFile("ca.crt") + if err != nil { + panic(err) + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, + } + tr := &http.Transport{TLSClientConfig: tlsConfig} + + // Create an HTTP client with the custom transport + client.Transport = tr + } + + if cli.Timeout > 0 { + client.Timeout = cli.Timeout + } + + if cli.RedirectMax > 0 { + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if len(via) > cli.RedirectMax { + return fmt.Errorf("stopped after %d redirects", cli.RedirectMax) + } + return nil + } + } + + credentials, err := cred.New( + cred.URL(cli.URL), + cred.MacAddress(wrp.DeviceID(cli.ID)), + cred.HTTPClient(client), + cred.SerialNumber("1234567890"), + cred.HardwareModel("model"), + cred.HardwareManufacturer("manufacturer"), + cred.FirmwareVersion("version"), + cred.LastRebootReason("reason"), + cred.XmidtProtocol("protocol"), + cred.BootRetryWait(1), + cred.AddFetchListener( + event.FetchListenerFunc(func(fe event.Fetch) { + fmt.Println("Fetch:") + fmt.Printf(" At: %s\n", fe.At.Format(time.RFC3339)) + fmt.Printf(" Duration: %s\n", fe.Duration) + fmt.Printf(" UUID: %s\n", fe.UUID) + fmt.Printf(" StatusCode: %d\n", fe.StatusCode) + fmt.Printf(" RetryIn: %s\n", fe.RetryIn) + fmt.Printf(" Expiration: %s\n", fe.Expiration.Format(time.RFC3339)) + if fe.Err != nil { + fmt.Printf(" Err: %s\n", fe.Err) + } else { + fmt.Println(" Err: nil") + } + }), + ), + ) + if err != nil { + panic(err) + } + + credentials.Start() + defer credentials.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + credentials.WaitUntilFetched(ctx) +} diff --git a/internal/credentials/credentials.go b/internal/credentials/credentials.go new file mode 100644 index 0000000..03a58f9 --- /dev/null +++ b/internal/credentials/credentials.go @@ -0,0 +1,386 @@ +// SPDX-FileCopyrightText: 2023 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package credentials + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "sync" + "time" + + "github.com/google/uuid" + "github.com/xmidt-org/eventor" + "github.com/xmidt-org/wrp-go/v3" + "github.com/xmidt-org/xmidt-agent/internal/credentials/event" +) + +var ( + ErrInvalidInput = fmt.Errorf("invalid input") + ErrNilRequest = fmt.Errorf("nil request") + ErrNoToken = fmt.Errorf("no token") + ErrTokenExpired = fmt.Errorf("token expired") + ErrFetchNotAttempted = fmt.Errorf("fetch not attempted") + ErrFetchFailed = fmt.Errorf("fetch failed") +) + +const ( + DefaultRefetchPercent = 90.0 +) + +/* +Notes: + - The network interface is set via the http.Client. + - If v4, v6 or both are desired, it is set via the http.Client. + - The timeout is set via the http.Client. + - The maximum redirect count is set via the http.Client. + - mTLS is set via the http.Client. + - The TLS version is set via the http.Client. +*/ +type Credentials struct { + m sync.RWMutex + wg sync.WaitGroup + shutdown context.CancelFunc + fetched chan struct{} + valid chan struct{} + wakeup chan chan struct{} + nowFunc func() time.Time + fetchListeners eventor.Eventor[event.FetchListener] + decorateListeners eventor.Eventor[event.DecorateListener] + + // What we are using to fetch the credentials. + + url string + refetchPercent float64 + assumedLifetime time.Duration + client *http.Client + macAddress wrp.DeviceID + serialNumber string + hardwareModel string + hardwareManufacturer string + firmwareVersion string + lastRebootReason string + xmidtProtocol string + bootRetryWait time.Duration + lastReconnectReason func() string // dynamic + partnerID func() string // dynamic + + // What we are using to decorate the request. + token *xmidtToken +} + +// Option is the interface implemented by types that can be used to +// configure the credentials. +type Option interface { + apply(*Credentials) error +} + +// New creates a new credentials service object. +func New(opts ...Option) (*Credentials, error) { + required := []Option{ + urlVador(), + macAddressVador(), + serialNumberVador(), + hardwareModelVador(), + hardwareManufacturerVador(), + firmwareVersionVador(), + lastRebootReasonVador(), + xmidtProtocolVador(), + bootRetryWaitVador(), + } + + c := Credentials{ + client: http.DefaultClient, + fetched: make(chan struct{}), + valid: make(chan struct{}), + wakeup: make(chan chan struct{}), + nowFunc: time.Now, + refetchPercent: DefaultRefetchPercent, + lastReconnectReason: func() string { return "" }, + partnerID: func() string { return "" }, + } + + opts = append(opts, required...) + + for _, opt := range opts { + if opt == nil { + continue + } + + err := opt.apply(&c) + if err != nil { + return nil, err + } + } + + return &c, nil +} + +// Start starts the credentials service. +func (c *Credentials) Start() { + c.m.Lock() + defer c.m.Unlock() + + if c.shutdown != nil { + return + } + + var ctx context.Context + ctx, c.shutdown = context.WithCancel(context.Background()) + + go c.run(ctx) +} + +// Stop stops the credentials service. +func (c *Credentials) Stop() { + c.m.Lock() + shudown := c.shutdown + c.m.Unlock() + + if shudown != nil { + shudown() + } + c.wg.Wait() +} + +// WaitUntilFetched blocks until an attempt to fetch the credentials has been +// made or the context is canceled. +func (c *Credentials) WaitUntilFetched(ctx context.Context) { + // Fetched is never re-created, so we don't need to lock. + select { + case <-c.fetched: + case <-ctx.Done(): + } +} + +// WaitUntilValid blocks until the credentials are valid or the context is +// canceled. +func (c *Credentials) WaitUntilValid(ctx context.Context) { + c.m.RLock() + valid := c.valid + c.m.RUnlock() + + select { + case <-valid: + case <-ctx.Done(): + } +} + +// MarkInvalid marks the credentials as invalid and causes the service to +// immediately attempt to fetch new credentials. +func (c *Credentials) MarkInvalid(ctx context.Context) { + ch := make(chan struct{}) + + select { + case c.wakeup <- ch: + select { + case <-ch: + case <-ctx.Done(): + } + case <-ctx.Done(): + } + +} + +// Decorate decorates the request with the credentials. If the credentials +// are not valid, an error is returned. +func (c *Credentials) Decorate(req *http.Request) error { + var e event.Decorate + + if req == nil { + e.Err = ErrNilRequest + return c.dispatch(e) + } + + var token string + var expiresAt time.Time + + c.m.RLock() + if c.token != nil { + token = c.token.Token + expiresAt = c.token.ExpiresAt + } + c.m.RUnlock() + + if token == "" { + e.Err = ErrNoToken + return c.dispatch(e) + } + + e.Expiration = expiresAt + if c.nowFunc().After(expiresAt) { + e.Err = ErrTokenExpired + return c.dispatch(e) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + return c.dispatch(e) +} + +// fetch fetches the credentials from the server. This should only be called +// by the run() method. +func (c *Credentials) fetch(ctx context.Context) (*xmidtToken, time.Duration, error) { + var fe event.Fetch + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil) + if err != nil { + fe.Err = errors.Join(err, ErrFetchNotAttempted) + return nil, 0, c.dispatch(fe) + } + + tid, err := uuid.NewRandom() + if err != nil { + fe.Err = errors.Join(err, ErrFetchNotAttempted) + return nil, 0, c.dispatch(fe) + } + + fe.UUID = tid + + req.Header.Set("X-Midt-Boot-Retry-Wait", c.bootRetryWait.String()) + req.Header.Set("X-Midt-Mac-Address", c.macAddress.ID()) + req.Header.Set("X-Midt-Serial-Number", c.serialNumber) + req.Header.Set("X-Midt-Uuid", tid.String()) + req.Header.Set("X-Midt-Partner-Id", c.partnerID()) + req.Header.Set("X-Midt-Hardware-Model", c.hardwareModel) + req.Header.Set("X-Midt-Hardware-Manufacturer", c.hardwareManufacturer) + req.Header.Set("X-Midt-Firmware-Name", c.firmwareVersion) + req.Header.Set("X-Midt-Protocol", c.xmidtProtocol) + req.Header.Set("X-Midt-Last-Reboot-Reason", c.lastRebootReason) + req.Header.Set("X-Midt-Last-Reconnect-Reason", c.lastReconnectReason()) + + fe.At = time.Now() + resp, err := c.client.Do(req) + fe.Duration = time.Since(fe.At) + if err != nil { + fe.Err = errors.Join(err, ErrFetchFailed) + return nil, 0, c.dispatch(fe) + } + defer resp.Body.Close() + + fe.StatusCode = resp.StatusCode + if resp.StatusCode != http.StatusOK { + var retryIn time.Duration + if resp.StatusCode == http.StatusTooManyRequests { + if after, err := strconv.Atoi(resp.Header.Get("Retry-After")); err == nil { + retryIn = time.Duration(after) * time.Second + } + } + + fe.RetryIn = retryIn + fe.Err = errors.Join(err, ErrFetchFailed) + return nil, retryIn, c.dispatch(fe) + } + + var token xmidtToken + body, err := io.ReadAll(resp.Body) + if err != nil { + fe.Err = errors.Join(err, ErrFetchFailed) + return nil, 0, c.dispatch(fe) + } + token.Token = string(body) + + // One hundred years is forever. + token.ExpiresAt = c.nowFunc().Add(time.Hour * 24 * 365 * 100) + if c.assumedLifetime > 0 { + // If we have an assumed lifetime, use it. + token.ExpiresAt = c.nowFunc().Add(c.assumedLifetime) + } + + if expiration, err := http.ParseTime(resp.Header.Get("Expires")); err == nil { + // Even better, we were told when it expires. + token.ExpiresAt = expiration + } + + fe.Expiration = token.ExpiresAt + + return &token, 0, c.dispatch(fe) +} + +// run is the main loop for the credentials service. +func (c *Credentials) run(ctx context.Context) { + var ( + timer *time.Timer + fetched bool + valid bool + ) + + c.wg.Add(1) + defer c.wg.Done() + + for { + token, retryIn, err := c.fetch(ctx) + if !fetched { + close(c.fetched) + fetched = true + } + + // Assume we failed, so retry in 1 second or when the server suggested. + next := max(time.Second, retryIn) + + if err == nil && token != nil { + expires := token.ExpiresAt + + c.m.Lock() + c.token = token + c.m.Unlock() + + if !valid { + close(c.valid) + valid = true + } + + until := expires.Sub(c.nowFunc()) + if 0 < until { + // Add a timer to fetch the token again + next = time.Duration(float64(until) * c.refetchPercent / 100.0) + } + } + + timer = time.NewTimer(next) + defer timer.Stop() + + select { + case ch := <-c.wakeup: + if valid { + c.m.Lock() + c.valid = make(chan struct{}) + valid = false + c.m.Unlock() + } + ch <- struct{}{} + case <-timer.C: + case <-ctx.Done(): + return + } + } +} + +// dispatch dispatches the event to the listeners and returns the error that +// should be returned by the caller. +func (c *Credentials) dispatch(evnt any) error { + switch evnt := evnt.(type) { + case event.Fetch: + c.fetchListeners.Visit(func(listener event.FetchListener) { + listener.OnFetch(evnt) + }) + return evnt.Err + case event.Decorate: + c.decorateListeners.Visit(func(listener event.DecorateListener) { + listener.OnDecorate(evnt) + }) + return evnt.Err + } + + panic("unknown event type") +} + +// xmidtToken is the token returned from the server as well as the expiration +// time. +type xmidtToken struct { + Token string + ExpiresAt time.Time +} diff --git a/internal/credentials/credentials_test.go b/internal/credentials/credentials_test.go new file mode 100644 index 0000000..d5df461 --- /dev/null +++ b/internal/credentials/credentials_test.go @@ -0,0 +1,550 @@ +// SPDX-FileCopyrightText: 2023 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package credentials + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/xmidt-org/wrp-go/v3" + "github.com/xmidt-org/xmidt-agent/internal/credentials/event" +) + +func TestNew(t *testing.T) { + testClient := &http.Client{} + + simplest := []Option{ + URL("http://example.com"), + MacAddress(wrp.DeviceID("mac:112233445566")), + SerialNumber("1234567890"), + HardwareModel("model"), + HardwareManufacturer("manufacturer"), + FirmwareVersion("version"), + LastRebootReason("reason"), + XmidtProtocol("protocol"), + BootRetryWait(1), + } + + tests := []struct { + description string + opt Option + opts []Option + expectedErr error + check func(*assert.Assertions, *Credentials) + checks []func(*assert.Assertions, *Credentials) + optStr string + }{ + { + description: "nil option", + expectedErr: ErrInvalidInput, + }, { + description: "simplest config", + opts: simplest, + check: func(assert *assert.Assertions, c *Credentials) { + assert.Equal("http://example.com", c.url) + assert.Equal(wrp.DeviceID("mac:112233445566"), c.macAddress) + assert.Equal("1234567890", c.serialNumber) + assert.Equal("model", c.hardwareModel) + assert.Equal("manufacturer", c.hardwareManufacturer) + assert.Equal("version", c.firmwareVersion) + assert.Equal("reason", c.lastRebootReason) + assert.Equal("protocol", c.xmidtProtocol) + assert.Equal(time.Duration(1), c.bootRetryWait) + assert.Empty(c.partnerID()) + assert.Empty(c.lastReconnectReason()) + }, + }, { + description: "common config", + opts: append(simplest, []Option{ + HTTPClient(testClient), + RefetchPercent(50.0), + PartnerID(func() string { return "partner" }), + AssumedLifetime(24 * time.Hour), + LastReconnectReason(func() string { return "reconnect_reason" }), + }...), + check: func(assert *assert.Assertions, c *Credentials) { + assert.Equal("http://example.com", c.url) + assert.Equal(wrp.DeviceID("mac:112233445566"), c.macAddress) + assert.Equal("1234567890", c.serialNumber) + assert.Equal("model", c.hardwareModel) + assert.Equal("manufacturer", c.hardwareManufacturer) + assert.Equal("version", c.firmwareVersion) + assert.Equal("reason", c.lastRebootReason) + assert.Equal("protocol", c.xmidtProtocol) + assert.Equal(time.Duration(1), c.bootRetryWait) + assert.Equal(testClient, c.client) + assert.Equal(50.0, c.refetchPercent) + assert.Equal("partner", c.partnerID()) + assert.Equal(24*time.Hour, c.assumedLifetime) + assert.Equal("reconnect_reason", c.lastReconnectReason()) + }, + }, { + description: "invalid url", + opts: append(simplest, []Option{ + URL(""), + }...), + expectedErr: ErrInvalidInput, + }, { + description: "invalid mac address", + opts: append(simplest, []Option{ + MacAddress(wrp.DeviceID("")), + }...), + expectedErr: ErrInvalidInput, + }, { + description: "invalid serial number", + opts: append(simplest, []Option{ + SerialNumber(""), + }...), + expectedErr: ErrInvalidInput, + }, { + description: "invalid hardware model", + opts: append(simplest, []Option{ + HardwareModel(""), + }...), + expectedErr: ErrInvalidInput, + }, { + description: "invalid hardware manufacturer", + + opts: append(simplest, []Option{ + HardwareManufacturer(""), + }...), + expectedErr: ErrInvalidInput, + }, { + description: "invalid firmware version", + opts: append(simplest, []Option{ + FirmwareVersion(""), + }...), + expectedErr: ErrInvalidInput, + }, { + description: "invalid last reboot reason", + opts: append(simplest, []Option{ + LastRebootReason(""), + }...), + expectedErr: ErrInvalidInput, + }, { + description: "invalid xmidt protocol", + opts: append(simplest, []Option{ + XmidtProtocol(""), + }...), + expectedErr: ErrInvalidInput, + }, { + description: "invalid boot retry wait", + opts: append(simplest, []Option{ + BootRetryWait(0), + }...), + expectedErr: ErrInvalidInput, + }, { + description: "refetch percent (default)", + opts: append(simplest, []Option{ + RefetchPercent(0.0), + }...), + check: func(assert *assert.Assertions, c *Credentials) { + assert.Equal(DefaultRefetchPercent, c.refetchPercent) + }, + }, { + description: "invalid refetch percent (low)", + opts: append(simplest, []Option{ + RefetchPercent(-1.0), + }...), + expectedErr: ErrInvalidInput, + }, { + description: "invalid refetch percent (high)", + opts: append(simplest, []Option{ + RefetchPercent(100.1), + }...), + expectedErr: ErrInvalidInput, + }, { + description: "invalid http client", + opts: append(simplest, []Option{ + HTTPClient(nil), + }...), + check: func(assert *assert.Assertions, c *Credentials) { + assert.NotNil(c.client) + }, + }, { + description: "invalid partner id", + opts: append(simplest, []Option{ + PartnerID(nil), + }...), + check: func(assert *assert.Assertions, c *Credentials) { + assert.NotNil(c.partnerID) + }, + }, { + description: "invalid last reconnect reason", + opts: append(simplest, []Option{ + LastReconnectReason(nil), + }...), + check: func(assert *assert.Assertions, c *Credentials) { + assert.NotNil(c.lastReconnectReason) + }, + }, + } + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + + opts := append(tc.opts, tc.opt) + + got, err := New(opts...) + + checks := append(tc.checks, tc.check) + for _, check := range checks { + if check != nil { + check(assert, got) + } + } + + if tc.expectedErr == nil { + assert.NotNil(got) + assert.NoError(err) + return + } + + assert.Nil(got) + assert.ErrorIs(err, tc.expectedErr) + }) + } +} + +func TestEndToEnd429(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + server := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + r.Body.Close() + + w.Header().Add("Retry-After", "15") + w.WriteHeader(http.StatusTooManyRequests) + }, + ), + ) + defer server.Close() + + var called int + c, err := New( + URL(server.URL), + MacAddress(wrp.DeviceID("mac:112233445566")), + SerialNumber("1234567890"), + HardwareModel("model"), + HardwareManufacturer("manufacturer"), + FirmwareVersion("version"), + LastRebootReason("reason"), + XmidtProtocol("protocol"), + BootRetryWait(1), + AddFetchListener(event.FetchListenerFunc( + func(e event.Fetch) { + assert.Equal(15*time.Second, e.RetryIn) + assert.ErrorIs(e.Err, ErrFetchFailed) + called++ + })), + ) + + require.NoError(err) + require.NotNil(c) + + c.Start() + defer c.Stop() + + ctx := context.Background() + deadline, cancel := context.WithDeadline(ctx, time.Now().Add(1*time.Second)) + defer cancel() + c.WaitUntilFetched(deadline) + assert.Equal(1, called) +} + +func TestEndToEndWithExpires(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + when := time.Now().Add(1 * time.Hour) + + server := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + r.Body.Close() + + w.Header().Add("Expires", when.Format(http.TimeFormat)) + _, _ = w.Write([]byte(`token`)) + }, + ), + ) + defer server.Close() + + c, err := New( + URL(server.URL), + MacAddress(wrp.DeviceID("mac:112233445566")), + SerialNumber("1234567890"), + HardwareModel("model"), + HardwareManufacturer("manufacturer"), + FirmwareVersion("version"), + LastRebootReason("reason"), + XmidtProtocol("protocol"), + BootRetryWait(1), + AddFetchListener(event.FetchListenerFunc( + func(e event.Fetch) { + assert.Equal(when.Format(http.TimeFormat), e.Expiration.Format(http.TimeFormat)) + assert.NoError(e.Err) + })), + ) + + require.NoError(err) + require.NotNil(c) + + c.Start() + defer c.Stop() + + ctx := context.Background() + deadline, cancel := context.WithDeadline(ctx, time.Now().Add(1*time.Second)) + defer cancel() + c.WaitUntilValid(deadline) +} + +func TestEndToEndMarkInvalid(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + counter := 1 + + server := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + r.Body.Close() + + _, _ = w.Write([]byte(`token`)) + counter++ + }, + ), + ) + defer server.Close() + + called := 0 + c, err := New( + URL(server.URL), + MacAddress(wrp.DeviceID("mac:112233445566")), + SerialNumber("1234567890"), + HardwareModel("model"), + HardwareManufacturer("manufacturer"), + FirmwareVersion("version"), + LastRebootReason("reason"), + XmidtProtocol("protocol"), + BootRetryWait(1), + AddFetchListener(event.FetchListenerFunc( + func(e event.Fetch) { + fmt.Println("Fetch:") + fmt.Printf(" At: %s\n", e.At.Format(time.RFC3339Nano)) + fmt.Printf(" Duration: %s\n", e.Duration) + fmt.Printf(" UUID: %s\n", e.UUID) + fmt.Printf(" StatusCode: %d\n", e.StatusCode) + fmt.Printf(" RetryIn: %s\n", e.RetryIn) + fmt.Printf(" Expiration: %s\n", e.Expiration.Format(time.RFC3339)) + assert.NoError(e.Err) + called++ + })), + ) + + require.NoError(err) + require.NotNil(c) + + c.Start() + defer c.Stop() + + ctx := context.Background() + deadline, cancel := context.WithDeadline(ctx, time.Now().Add(1*time.Second)) + defer cancel() + c.WaitUntilValid(deadline) + + c.MarkInvalid(deadline) + + c.WaitUntilValid(deadline) + + assert.Equal(2, called) +} + +func TestEndToEnd(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + server := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + r.Body.Close() + + _, _ = w.Write([]byte(`token`)) + }, + ), + ) + defer server.Close() + + c, err := New( + URL(server.URL), + MacAddress(wrp.DeviceID("mac:112233445566")), + SerialNumber("1234567890"), + HardwareModel("model"), + HardwareManufacturer("manufacturer"), + FirmwareVersion("version"), + LastRebootReason("reason"), + XmidtProtocol("protocol"), + BootRetryWait(1), + AssumedLifetime(24*time.Hour), + AddFetchListener(event.FetchListenerFunc( + func(e event.Fetch) { + fmt.Println("Fetch:") + fmt.Printf(" At: %s\n", e.At.Format(time.RFC3339)) + fmt.Printf(" Duration: %s\n", e.Duration) + fmt.Printf(" UUID: %s\n", e.UUID) + fmt.Printf(" StatusCode: %d\n", e.StatusCode) + fmt.Printf(" RetryIn: %s\n", e.RetryIn) + fmt.Printf(" Expiration: %s\n", e.Expiration.Format(time.RFC3339)) + if e.Err != nil { + fmt.Printf(" Err: %s\n", e.Err) + } else { + fmt.Println(" Err: nil") + } + })), + AddDecorateListener(event.DecorateListenerFunc( + func(e event.Decorate) { + fmt.Println("Decorate:") + fmt.Printf(" Expiration: %s\n", e.Expiration.Format(time.RFC3339)) + if e.Err != nil { + fmt.Printf(" Err: %s\n", e.Err) + } else { + fmt.Println(" Err: nil") + } + })), + ) + + require.NoError(err) + require.NotNil(c) + + c.Start() + + // Multiple calls to Start is ok. + c.Start() + + ctx := context.Background() + deadline, cancel := context.WithDeadline(ctx, time.Now().Add(1*time.Second)) + c.WaitUntilFetched(deadline) + c.WaitUntilValid(deadline) + cancel() + + req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + assert.NoError(err) + assert.NotNil(req) + + err = c.Decorate(req) + assert.NoError(err) + assert.Equal("Bearer token", strings.TrimSpace(req.Header.Get("Authorization"))) + + // Decorate the a second time. + _ = c.Decorate(req) + + c.Stop() + + // Multiple calls to Stop is ok. + c.Stop() +} + +func TestContextExpires(t *testing.T) { + c, err := New( + URL("http://example.com"), + MacAddress(wrp.DeviceID("mac:112233445566")), + SerialNumber("1234567890"), + HardwareModel("model"), + HardwareManufacturer("manufacturer"), + FirmwareVersion("version"), + LastRebootReason("reason"), + XmidtProtocol("protocol"), + BootRetryWait(1), + ) + + require.NoError(t, err) + require.NotNil(t, c) + + ctx := context.Background() + deadline, cancel := context.WithTimeout(ctx, 1*time.Millisecond) + defer cancel() + c.WaitUntilFetched(deadline) + + deadline, cancel = context.WithTimeout(ctx, 1*time.Millisecond) + defer cancel() + c.WaitUntilValid(deadline) + + deadline, cancel = context.WithTimeout(ctx, 1*time.Millisecond) + defer cancel() + c.MarkInvalid(deadline) +} + +func TestDecorate(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + server := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + r.Body.Close() + + _, _ = w.Write([]byte(``)) + }, + ), + ) + defer server.Close() + + var count int + c, err := New( + URL(server.URL), + MacAddress(wrp.DeviceID("mac:112233445566")), + SerialNumber("1234567890"), + HardwareModel("model"), + HardwareManufacturer("manufacturer"), + FirmwareVersion("version"), + LastRebootReason("reason"), + XmidtProtocol("protocol"), + BootRetryWait(1), + AddFetchListener(event.FetchListenerFunc( + func(e event.Fetch) { + assert.NoError(e.Err) + })), + AddDecorateListener(event.DecorateListenerFunc( + func(e event.Decorate) { + switch count { + case 0: + assert.ErrorIs(e.Err, ErrNilRequest) + case 1: + assert.ErrorIs(e.Err, ErrNoToken) + default: + assert.Fail("too many calls to decorate") + } + count++ + })), + ) + + require.NoError(err) + require.NotNil(c) + + c.Start() + defer c.Stop() + + ctx := context.Background() + deadline, cancel := context.WithDeadline(ctx, time.Now().Add(1*time.Second)) + defer cancel() + c.WaitUntilFetched(deadline) + + err = c.Decorate(nil) + assert.ErrorIs(err, ErrNilRequest) + + req, _ := http.NewRequest(http.MethodGet, "https://example.com", nil) + err = c.Decorate(req) + assert.ErrorIs(err, ErrNoToken) + + assert.Equal(2, count) +} diff --git a/internal/credentials/event/events.go b/internal/credentials/event/events.go new file mode 100644 index 0000000..91607f1 --- /dev/null +++ b/internal/credentials/event/events.go @@ -0,0 +1,79 @@ +// SPDX-FileCopyrightText: 2023 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package event + +import ( + "time" + + "github.com/google/uuid" +) + +// CancelListenerFunc is the interface that provides a method to cancel +// a listener. +type CancelListenerFunc func() + +// Fetch is the event that is sent when the credentials are fetched. +type Fetch struct { + // At holds the time when the fetch request was made. + At time.Time + + // Duration is the time waited for the token/response. + Duration time.Duration + + // UUID is the UUID of the request. + UUID uuid.UUID + + // StatusCode is the status code returned from the SAT service. + StatusCode int + + // RetryIn is the time to wait before retrying the request. Any value + // less than or equal to zero means the server did not specify a + // recommended retry time. + RetryIn time.Duration + + // Expiration is the time the token expires. + Expiration time.Time + + // Error is the error returned from the SAT service. + Err error +} + +// FetchListener is the interface that must be implemented by types that +// want to receive Fetch notifications. +type FetchListener interface { + OnFetch(Fetch) +} + +// FetchListenerFunc is a function type that implements FetchListener. +// It can be used as an adapter for functions that need to implement the +// FetchListener interface. +type FetchListenerFunc func(Fetch) + +func (f FetchListenerFunc) OnFetch(e Fetch) { + f(e) +} + +// Decorate is the event that is sent when the request is decorated. +type Decorate struct { + // Expiration is the time the token expires. + Expiration time.Time + + // Error is the error returned from the SAT service. + Err error +} + +// DecorateListener is the interface that must be implemented by types that +// want to receive Decorate notifications. +type DecorateListener interface { + OnDecorate(Decorate) +} + +// DecorateListenerFunc is a function type that implements DecorateListener. +// It can be used as an adapter for functions that need to implement the +// DecorateListener interface. +type DecorateListenerFunc func(Decorate) + +func (f DecorateListenerFunc) OnDecorate(e Decorate) { + f(e) +} diff --git a/internal/credentials/internal_options.go b/internal/credentials/internal_options.go new file mode 100644 index 0000000..1294284 --- /dev/null +++ b/internal/credentials/internal_options.go @@ -0,0 +1,96 @@ +// SPDX-FileCopyrightText: 2023 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package credentials + +import "fmt" + +func urlVador() Option { + return optionFunc( + func(c *Credentials) error { + if c.url == "" { + return fmt.Errorf("%w URL is missing", ErrInvalidInput) + } + return nil + }) +} + +func macAddressVador() Option { + return optionFunc( + func(c *Credentials) error { + if len(c.macAddress) == 0 { + return fmt.Errorf("%w mac address is missing", ErrInvalidInput) + } + return nil + }) +} + +func serialNumberVador() Option { + return optionFunc( + func(c *Credentials) error { + if c.serialNumber == "" { + return fmt.Errorf("%w serial number is missing", ErrInvalidInput) + } + return nil + }) +} + +func hardwareModelVador() Option { + return optionFunc( + func(c *Credentials) error { + if c.hardwareModel == "" { + return fmt.Errorf("%w hardware model is missing", ErrInvalidInput) + } + return nil + }) +} + +func hardwareManufacturerVador() Option { + return optionFunc( + func(c *Credentials) error { + if c.hardwareManufacturer == "" { + return fmt.Errorf("%w hardware manufacturer is missing", ErrInvalidInput) + } + return nil + }) +} + +func firmwareVersionVador() Option { + return optionFunc( + func(c *Credentials) error { + if c.firmwareVersion == "" { + return fmt.Errorf("%w firmware version is missing", ErrInvalidInput) + } + return nil + }) +} + +func lastRebootReasonVador() Option { + return optionFunc( + func(c *Credentials) error { + if c.lastRebootReason == "" { + return fmt.Errorf("%w last reboot reason is missing", ErrInvalidInput) + } + return nil + }) +} + +func xmidtProtocolVador() Option { + return optionFunc( + func(c *Credentials) error { + if c.xmidtProtocol == "" { + return fmt.Errorf("%w xmidt protocol is missing", ErrInvalidInput) + } + return nil + }) +} + +func bootRetryWaitVador() Option { + return optionFunc( + func(c *Credentials) error { + if c.bootRetryWait == 0 { + return fmt.Errorf("%w boot retry wait is missing", ErrInvalidInput) + } + return nil + }) +} diff --git a/internal/credentials/options.go b/internal/credentials/options.go new file mode 100644 index 0000000..57874d9 --- /dev/null +++ b/internal/credentials/options.go @@ -0,0 +1,201 @@ +// SPDX-FileCopyrightText: 2023 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package credentials + +import ( + "net/http" + "time" + + "github.com/xmidt-org/wrp-go/v3" + "github.com/xmidt-org/xmidt-agent/internal/credentials/event" +) + +type optionFunc func(*Credentials) error + +func (f optionFunc) apply(c *Credentials) error { + return f(c) +} + +type nilOptionFunc func(*Credentials) + +func (f nilOptionFunc) apply(c *Credentials) error { + f(c) + return nil +} + +// URL is the URL of the credential service. +func URL(url string) Option { + return nilOptionFunc( + func(c *Credentials) { + c.url = url + }) +} + +// HTTPClient is the HTTP client used to fetch the credentials. +func HTTPClient(client *http.Client) Option { + return nilOptionFunc( + func(c *Credentials) { + if client == nil { + client = http.DefaultClient + } + c.client = client + }) +} + +// RefetchPercent is the percentage of the lifetime of the credentials +// that must pass before a refetch is attempted. The accepted range is 0.0 to +// 100.0. If 0.0 is specified the default is used. The default is 90.0. +func RefetchPercent(percent float64) Option { + return optionFunc( + func(c *Credentials) error { + if percent < 0.0 || percent > 100.0 { + return ErrInvalidInput + } + + c.refetchPercent = percent + + if c.refetchPercent == 0.0 { + c.refetchPercent = DefaultRefetchPercent + } + return nil + }) +} + +// AssumedLifetime is the lifetime of the credentials that is assumed if the +// credentials service does not return a lifetime. A value of zero means that +// no assumed lifetime is used. The default is zero. +func AssumedLifetime(lifetime time.Duration) Option { + return nilOptionFunc( + func(c *Credentials) { + c.assumedLifetime = lifetime + }) +} + +// MacAddress is the MAC address of the device. +func MacAddress(macAddress wrp.DeviceID) Option { + return nilOptionFunc( + func(c *Credentials) { + c.macAddress = macAddress + }) +} + +// SerialNumber is the serial number of the device. +func SerialNumber(serialNumber string) Option { + return nilOptionFunc( + func(c *Credentials) { + c.serialNumber = serialNumber + }) +} + +// HardwareModel is the hardware model of the device. +func HardwareModel(hardwareModel string) Option { + return nilOptionFunc( + func(c *Credentials) { + c.hardwareModel = hardwareModel + }) +} + +// HardwareManufacturer is the hardware manufacturer of the device. +func HardwareManufacturer(hardwareManufacturer string) Option { + return nilOptionFunc( + func(c *Credentials) { + c.hardwareManufacturer = hardwareManufacturer + }) +} + +// FirmwareVersion is the firmware version of the device. +func FirmwareVersion(firmwareVersion string) Option { + return nilOptionFunc( + func(c *Credentials) { + c.firmwareVersion = firmwareVersion + }) +} + +// LastRebootReason is the reason for the most recent reboot of the device. +func LastRebootReason(lastRebootReason string) Option { + return nilOptionFunc( + func(c *Credentials) { + c.lastRebootReason = lastRebootReason + }) +} + +// XmidtProtocol is the protocol version used by the device to communicate with +// the Xmidt cluster. +func XmidtProtocol(xmidtProtocol string) Option { + return nilOptionFunc( + func(c *Credentials) { + c.xmidtProtocol = xmidtProtocol + }) +} + +// BootRetryWait is the time to wait before retrying the request. Any value +// less than or equal to zero is treated as zero. +func BootRetryWait(bootRetryWait time.Duration) Option { + return nilOptionFunc( + func(c *Credentials) { + c.bootRetryWait = bootRetryWait + }) +} + +// LastReconnectReason is the reason for the most recent reconnect of the +// device. This is a dynamic value that is obtained by calling the function +// provided. +func LastReconnectReason(lastReconnectReason func() string) Option { + return nilOptionFunc( + func(c *Credentials) { + if lastReconnectReason == nil { + lastReconnectReason = func() string { return "" } + } + c.lastReconnectReason = lastReconnectReason + }) +} + +// PartnerID is the partner ID of the device. This is a dynamic value that is +// obtained by calling the function provided. +func PartnerID(partnerID func() string) Option { + return nilOptionFunc( + func(c *Credentials) { + if partnerID == nil { + partnerID = func() string { return "" } + } + c.partnerID = partnerID + }) +} + +// NowFunc is the function used to obtain the current time. +func NowFunc(nowFunc func() time.Time) Option { + return nilOptionFunc( + func(c *Credentials) { + if nowFunc == nil { + nowFunc = time.Now + } + c.nowFunc = nowFunc + }) +} + +// AddFetchListener adds a listener for fetch events. If the optional cancel +// parameter is provided, it is set to a function that can be used to cancel +// the listener. +func AddFetchListener(listener event.FetchListener, cancel ...*event.CancelListenerFunc) Option { + return nilOptionFunc( + func(c *Credentials) { + cncl := c.fetchListeners.Add(listener) + if len(cancel) > 0 && cancel[0] != nil { + *cancel[0] = event.CancelListenerFunc(cncl) + } + }) +} + +// AddDecorateListener adds a listener for decorate events. If the optional +// cancel parameter is provided, it is set to a function that can be used to +// cancel the listener. +func AddDecorateListener(listener event.DecorateListener, cancel ...*event.CancelListenerFunc) Option { + return nilOptionFunc( + func(c *Credentials) { + cncl := c.decorateListeners.Add(listener) + if len(cancel) > 0 && cancel[0] != nil { + *cancel[0] = event.CancelListenerFunc(cncl) + } + }) +}