diff --git a/cmd/crowdsec-cli/clibouncer/bouncers.go b/cmd/crowdsec-cli/clibouncer/bouncers.go index 89e91b63911..226fbb7e922 100644 --- a/cmd/crowdsec-cli/clibouncer/bouncers.go +++ b/cmd/crowdsec-cli/clibouncer/bouncers.go @@ -208,7 +208,7 @@ func (cli *cliBouncers) newListCmd() *cobra.Command { return cmd } -func (cli *cliBouncers) add(bouncerName string, key string) error { +func (cli *cliBouncers) add(ctx context.Context, bouncerName string, key string) error { var err error keyLength := 32 @@ -220,7 +220,7 @@ func (cli *cliBouncers) add(bouncerName string, key string) error { } } - _, err = cli.db.CreateBouncer(bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType) + _, err = cli.db.CreateBouncer(ctx, bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType) if err != nil { return fmt.Errorf("unable to create bouncer: %w", err) } @@ -254,8 +254,8 @@ func (cli *cliBouncers) newAddCmd() *cobra.Command { cscli bouncers add MyBouncerName --key `, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - return cli.add(args[0], key) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.add(cmd.Context(), args[0], key) }, } @@ -304,9 +304,9 @@ func (cli *cliBouncers) validBouncerID(cmd *cobra.Command, args []string, toComp return ret, cobra.ShellCompDirectiveNoFileComp } -func (cli *cliBouncers) delete(bouncers []string, ignoreMissing bool) error { +func (cli *cliBouncers) delete(ctx context.Context, bouncers []string, ignoreMissing bool) error { for _, bouncerID := range bouncers { - if err := cli.db.DeleteBouncer(bouncerID); err != nil { + if err := cli.db.DeleteBouncer(ctx, bouncerID); err != nil { var notFoundErr *database.BouncerNotFoundError if ignoreMissing && errors.As(err, ¬FoundErr) { return nil @@ -332,8 +332,8 @@ func (cli *cliBouncers) newDeleteCmd() *cobra.Command { Aliases: []string{"remove"}, DisableAutoGenTag: true, ValidArgsFunction: cli.validBouncerID, - RunE: func(_ *cobra.Command, args []string) error { - return cli.delete(args, ignoreMissing) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.delete(cmd.Context(), args, ignoreMissing) }, } @@ -343,7 +343,7 @@ func (cli *cliBouncers) newDeleteCmd() *cobra.Command { return cmd } -func (cli *cliBouncers) prune(duration time.Duration, force bool) error { +func (cli *cliBouncers) prune(ctx context.Context, duration time.Duration, force bool) error { if duration < 2*time.Minute { if yes, err := ask.YesNo( "The duration you provided is less than 2 minutes. "+ @@ -355,7 +355,7 @@ func (cli *cliBouncers) prune(duration time.Duration, force bool) error { } } - bouncers, err := cli.db.QueryBouncersInactiveSince(time.Now().UTC().Add(-duration)) + bouncers, err := cli.db.QueryBouncersInactiveSince(ctx, time.Now().UTC().Add(-duration)) if err != nil { return fmt.Errorf("unable to query bouncers: %w", err) } @@ -378,7 +378,7 @@ func (cli *cliBouncers) prune(duration time.Duration, force bool) error { } } - deleted, err := cli.db.BulkDeleteBouncers(bouncers) + deleted, err := cli.db.BulkDeleteBouncers(ctx, bouncers) if err != nil { return fmt.Errorf("unable to prune bouncers: %w", err) } @@ -403,8 +403,8 @@ func (cli *cliBouncers) newPruneCmd() *cobra.Command { DisableAutoGenTag: true, Example: `cscli bouncers prune -d 45m cscli bouncers prune -d 45m --force`, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.prune(duration, force) + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.prune(cmd.Context(), duration, force) }, } diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 9e12d27cb36..2e349a2344b 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -297,7 +297,7 @@ func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string { apiKey, err := middlewares.GenerateAPIKey(keyLength) require.NoError(t, err) - _, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) + _, err = dbClient.CreateBouncer(ctx, "test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) require.NoError(t, err) return apiKey diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index 54e9b0290cc..139280ab497 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -43,6 +43,8 @@ func (c *Controller) GetDecision(gctx *gin.Context) { data []*ent.Decision ) + ctx := gctx.Request.Context() + bouncerInfo, err := getBouncerFromContext(gctx) if err != nil { gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) @@ -73,7 +75,7 @@ func (c *Controller) GetDecision(gctx *gin.Context) { } if bouncerInfo.LastPull == nil || time.Now().UTC().Sub(*bouncerInfo.LastPull) >= time.Minute { - if err := c.DBClient.UpdateBouncerLastPull(time.Now().UTC(), bouncerInfo.ID); err != nil { + if err := c.DBClient.UpdateBouncerLastPull(ctx, time.Now().UTC(), bouncerInfo.ID); err != nil { log.Errorf("failed to update bouncer last pull: %v", err) } } @@ -370,6 +372,8 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en func (c *Controller) StreamDecision(gctx *gin.Context) { var err error + ctx := gctx.Request.Context() + streamStartTime := time.Now().UTC() bouncerInfo, err := getBouncerFromContext(gctx) @@ -400,7 +404,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { if err == nil { //Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions - if err := c.DBClient.UpdateBouncerLastPull(streamStartTime, bouncerInfo.ID); err != nil { + if err := c.DBClient.UpdateBouncerLastPull(ctx, streamStartTime, bouncerInfo.ID); err != nil { log.Errorf("unable to update bouncer '%s' pull: %v", bouncerInfo.Name, err) } } diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index e822666db0f..d438c9b15a4 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -64,6 +64,8 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { return nil } + ctx := c.Request.Context() + extractedCN, err := a.TlsAuth.ValidateCert(c) if err != nil { logger.Warn(err) @@ -73,7 +75,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger = logger.WithField("cn", extractedCN) bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) - bouncer, err := a.DbClient.SelectBouncerByName(bouncerName) + bouncer, err := a.DbClient.SelectBouncerByName(ctx, bouncerName) // This is likely not the proper way, but isNotFound does not seem to work if err != nil && strings.Contains(err.Error(), "bouncer not found") { @@ -87,7 +89,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Infof("Creating bouncer %s", bouncerName) - bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) + bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) if err != nil { logger.Errorf("while creating bouncer db entry: %s", err) return nil @@ -112,9 +114,11 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer { return nil } + ctx := c.Request.Context() + hashStr := HashSHA512(val[0]) - bouncer, err := a.DbClient.SelectBouncer(hashStr) + bouncer, err := a.DbClient.SelectBouncer(ctx, hashStr) if err != nil { logger.Errorf("while fetching bouncer info: %s", err) return nil @@ -132,6 +136,8 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { return func(c *gin.Context) { var bouncer *ent.Bouncer + ctx := c.Request.Context() + clientIP := c.ClientIP() logger := log.WithField("ip", clientIP) @@ -153,7 +159,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { logger = logger.WithField("name", bouncer.Name) if bouncer.IPAddress == "" { - if err := a.DbClient.UpdateBouncerIP(clientIP, bouncer.ID); err != nil { + if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() @@ -166,7 +172,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { if bouncer.IPAddress != clientIP && bouncer.IPAddress != "" && c.Request.Method != http.MethodHead { log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, clientIP, bouncer.IPAddress) - if err := a.DbClient.UpdateBouncerIP(clientIP, bouncer.ID); err != nil { + if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() @@ -182,7 +188,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { } if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] { - if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil { + if err := a.DbClient.UpdateBouncerTypeAndVersion(ctx, useragent[0], useragent[1], bouncer.ID); err != nil { logger.Errorf("failed to update bouncer version and type: %s", err) c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) c.Abort() diff --git a/pkg/apiserver/usage_metrics_test.go b/pkg/apiserver/usage_metrics_test.go index b335738ea16..9c86cc63414 100644 --- a/pkg/apiserver/usage_metrics_test.go +++ b/pkg/apiserver/usage_metrics_test.go @@ -371,7 +371,7 @@ func TestRCMetrics(t *testing.T) { assert.Equal(t, tt.expectedStatusCode, w.Code) assert.Contains(t, w.Body.String(), tt.expectedResponse) - bouncer, _ := dbClient.SelectBouncerByName("test") + bouncer, _ := dbClient.SelectBouncerByName(ctx, "test") metrics, _ := dbClient.GetBouncerUsageMetricsByName(ctx, "test") assert.Len(t, metrics, tt.expectedMetricsCount) diff --git a/pkg/database/bouncers.go b/pkg/database/bouncers.go index 6ff308ff786..04ef830ae72 100644 --- a/pkg/database/bouncers.go +++ b/pkg/database/bouncers.go @@ -41,8 +41,8 @@ func (c *Client) BouncerUpdateBaseMetrics(ctx context.Context, bouncerName strin return nil } -func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(c.CTX) +func (c *Client) SelectBouncer(ctx context.Context, apiKeyHash string) (*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(ctx) if err != nil { return nil, err } @@ -50,8 +50,8 @@ func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) { return result, nil } -func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(c.CTX) +func (c *Client) SelectBouncerByName(ctx context.Context, bouncerName string) (*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(ctx) if err != nil { return nil, err } @@ -68,14 +68,14 @@ func (c *Client) ListBouncers(ctx context.Context) ([]*ent.Bouncer, error) { return result, nil } -func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) { +func (c *Client) CreateBouncer(ctx context.Context, name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) { bouncer, err := c.Ent.Bouncer. Create(). SetName(name). SetAPIKey(apiKey). SetRevoked(false). SetAuthType(authType). - Save(c.CTX) + Save(ctx) if err != nil { if ent.IsConstraintError(err) { return nil, fmt.Errorf("bouncer %s already exists", name) @@ -87,11 +87,11 @@ func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string, authTy return bouncer, nil } -func (c *Client) DeleteBouncer(name string) error { +func (c *Client) DeleteBouncer(ctx context.Context, name string) error { nbDeleted, err := c.Ent.Bouncer. Delete(). Where(bouncer.NameEQ(name)). - Exec(c.CTX) + Exec(ctx) if err != nil { return err } @@ -103,13 +103,13 @@ func (c *Client) DeleteBouncer(name string) error { return nil } -func (c *Client) BulkDeleteBouncers(bouncers []*ent.Bouncer) (int, error) { +func (c *Client) BulkDeleteBouncers(ctx context.Context, bouncers []*ent.Bouncer) (int, error) { ids := make([]int, len(bouncers)) for i, b := range bouncers { ids[i] = b.ID } - nbDeleted, err := c.Ent.Bouncer.Delete().Where(bouncer.IDIn(ids...)).Exec(c.CTX) + nbDeleted, err := c.Ent.Bouncer.Delete().Where(bouncer.IDIn(ids...)).Exec(ctx) if err != nil { return nbDeleted, fmt.Errorf("unable to delete bouncers: %w", err) } @@ -117,10 +117,10 @@ func (c *Client) BulkDeleteBouncers(bouncers []*ent.Bouncer) (int, error) { return nbDeleted, nil } -func (c *Client) UpdateBouncerLastPull(lastPull time.Time, id int) error { +func (c *Client) UpdateBouncerLastPull(ctx context.Context, lastPull time.Time, id int) error { _, err := c.Ent.Bouncer.UpdateOneID(id). SetLastPull(lastPull). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine last pull in database: %w", err) } @@ -128,8 +128,8 @@ func (c *Client) UpdateBouncerLastPull(lastPull time.Time, id int) error { return nil } -func (c *Client) UpdateBouncerIP(ipAddr string, id int) error { - _, err := c.Ent.Bouncer.UpdateOneID(id).SetIPAddress(ipAddr).Save(c.CTX) +func (c *Client) UpdateBouncerIP(ctx context.Context, ipAddr string, id int) error { + _, err := c.Ent.Bouncer.UpdateOneID(id).SetIPAddress(ipAddr).Save(ctx) if err != nil { return fmt.Errorf("unable to update bouncer ip address in database: %w", err) } @@ -137,8 +137,8 @@ func (c *Client) UpdateBouncerIP(ipAddr string, id int) error { return nil } -func (c *Client) UpdateBouncerTypeAndVersion(bType string, version string, id int) error { - _, err := c.Ent.Bouncer.UpdateOneID(id).SetVersion(version).SetType(bType).Save(c.CTX) +func (c *Client) UpdateBouncerTypeAndVersion(ctx context.Context, bType string, version string, id int) error { + _, err := c.Ent.Bouncer.UpdateOneID(id).SetVersion(version).SetType(bType).Save(ctx) if err != nil { return fmt.Errorf("unable to update bouncer type and version in database: %w", err) } @@ -146,7 +146,7 @@ func (c *Client) UpdateBouncerTypeAndVersion(bType string, version string, id in return nil } -func (c *Client) QueryBouncersInactiveSince(t time.Time) ([]*ent.Bouncer, error) { +func (c *Client) QueryBouncersInactiveSince(ctx context.Context, t time.Time) ([]*ent.Bouncer, error) { return c.Ent.Bouncer.Query().Where( // poor man's coalesce bouncer.Or( @@ -156,5 +156,5 @@ func (c *Client) QueryBouncersInactiveSince(t time.Time) ([]*ent.Bouncer, error) bouncer.CreatedAtLT(t), ), ), - ).All(c.CTX) + ).All(ctx) }