Skip to content

Commit

Permalink
context propagation: pkg/database/bouncers
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc committed Sep 18, 2024
1 parent 94f3abe commit b3109de
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 41 deletions.
26 changes: 13 additions & 13 deletions cmd/crowdsec-cli/clibouncer/bouncers.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func (cli *cliBouncers) newListCmd() *cobra.Command {
return cmd
}

func (cli *cliBouncers) add(bouncerName string, key string) error {
func (cli *cliBouncers) add(ctx context.Context, bouncerName string, key string) error {
var err error

keyLength := 32
Expand All @@ -220,7 +220,7 @@ func (cli *cliBouncers) add(bouncerName string, key string) error {
}
}

_, err = cli.db.CreateBouncer(bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType)
_, err = cli.db.CreateBouncer(ctx, bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType)
if err != nil {
return fmt.Errorf("unable to create bouncer: %w", err)
}
Expand Down Expand Up @@ -254,8 +254,8 @@ func (cli *cliBouncers) newAddCmd() *cobra.Command {
cscli bouncers add MyBouncerName --key <random-key>`,
Args: cobra.ExactArgs(1),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, args []string) error {
return cli.add(args[0], key)
RunE: func(cmd *cobra.Command, args []string) error {
return cli.add(cmd.Context(), args[0], key)
},
}

Expand Down Expand Up @@ -304,9 +304,9 @@ func (cli *cliBouncers) validBouncerID(cmd *cobra.Command, args []string, toComp
return ret, cobra.ShellCompDirectiveNoFileComp
}

func (cli *cliBouncers) delete(bouncers []string, ignoreMissing bool) error {
func (cli *cliBouncers) delete(ctx context.Context, bouncers []string, ignoreMissing bool) error {
for _, bouncerID := range bouncers {
if err := cli.db.DeleteBouncer(bouncerID); err != nil {
if err := cli.db.DeleteBouncer(ctx, bouncerID); err != nil {
var notFoundErr *database.BouncerNotFoundError
if ignoreMissing && errors.As(err, &notFoundErr) {
return nil
Expand All @@ -332,8 +332,8 @@ func (cli *cliBouncers) newDeleteCmd() *cobra.Command {
Aliases: []string{"remove"},
DisableAutoGenTag: true,
ValidArgsFunction: cli.validBouncerID,
RunE: func(_ *cobra.Command, args []string) error {
return cli.delete(args, ignoreMissing)
RunE: func(cmd *cobra.Command, args []string) error {
return cli.delete(cmd.Context(), args, ignoreMissing)
},
}

Expand All @@ -343,7 +343,7 @@ func (cli *cliBouncers) newDeleteCmd() *cobra.Command {
return cmd
}

func (cli *cliBouncers) prune(duration time.Duration, force bool) error {
func (cli *cliBouncers) prune(ctx context.Context, duration time.Duration, force bool) error {
if duration < 2*time.Minute {
if yes, err := ask.YesNo(
"The duration you provided is less than 2 minutes. "+
Expand All @@ -355,7 +355,7 @@ func (cli *cliBouncers) prune(duration time.Duration, force bool) error {
}
}

bouncers, err := cli.db.QueryBouncersInactiveSince(time.Now().UTC().Add(-duration))
bouncers, err := cli.db.QueryBouncersInactiveSince(ctx, time.Now().UTC().Add(-duration))
if err != nil {
return fmt.Errorf("unable to query bouncers: %w", err)
}
Expand All @@ -378,7 +378,7 @@ func (cli *cliBouncers) prune(duration time.Duration, force bool) error {
}
}

deleted, err := cli.db.BulkDeleteBouncers(bouncers)
deleted, err := cli.db.BulkDeleteBouncers(ctx, bouncers)
if err != nil {
return fmt.Errorf("unable to prune bouncers: %w", err)
}
Expand All @@ -403,8 +403,8 @@ func (cli *cliBouncers) newPruneCmd() *cobra.Command {
DisableAutoGenTag: true,
Example: `cscli bouncers prune -d 45m
cscli bouncers prune -d 45m --force`,
RunE: func(_ *cobra.Command, _ []string) error {
return cli.prune(duration, force)
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.prune(cmd.Context(), duration, force)
},
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/apiserver/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string {
apiKey, err := middlewares.GenerateAPIKey(keyLength)
require.NoError(t, err)

_, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType)
_, err = dbClient.CreateBouncer(ctx, "test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType)
require.NoError(t, err)

return apiKey
Expand Down
8 changes: 6 additions & 2 deletions pkg/apiserver/controllers/v1/decisions.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ func (c *Controller) GetDecision(gctx *gin.Context) {
data []*ent.Decision
)

ctx := gctx.Request.Context()

bouncerInfo, err := getBouncerFromContext(gctx)
if err != nil {
gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"})
Expand Down Expand Up @@ -73,7 +75,7 @@ func (c *Controller) GetDecision(gctx *gin.Context) {
}

if bouncerInfo.LastPull == nil || time.Now().UTC().Sub(*bouncerInfo.LastPull) >= time.Minute {
if err := c.DBClient.UpdateBouncerLastPull(time.Now().UTC(), bouncerInfo.ID); err != nil {
if err := c.DBClient.UpdateBouncerLastPull(ctx, time.Now().UTC(), bouncerInfo.ID); err != nil {
log.Errorf("failed to update bouncer last pull: %v", err)
}
}
Expand Down Expand Up @@ -370,6 +372,8 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
func (c *Controller) StreamDecision(gctx *gin.Context) {
var err error

ctx := gctx.Request.Context()

streamStartTime := time.Now().UTC()

bouncerInfo, err := getBouncerFromContext(gctx)
Expand Down Expand Up @@ -400,7 +404,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) {

if err == nil {
//Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions
if err := c.DBClient.UpdateBouncerLastPull(streamStartTime, bouncerInfo.ID); err != nil {
if err := c.DBClient.UpdateBouncerLastPull(ctx, streamStartTime, bouncerInfo.ID); err != nil {
log.Errorf("unable to update bouncer '%s' pull: %v", bouncerInfo.Name, err)
}
}
Expand Down
18 changes: 12 additions & 6 deletions pkg/apiserver/middlewares/v1/api_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer {
return nil
}

ctx := c.Request.Context()

extractedCN, err := a.TlsAuth.ValidateCert(c)
if err != nil {
logger.Warn(err)
Expand All @@ -73,7 +75,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer {
logger = logger.WithField("cn", extractedCN)

bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP())
bouncer, err := a.DbClient.SelectBouncerByName(bouncerName)
bouncer, err := a.DbClient.SelectBouncerByName(ctx, bouncerName)

// This is likely not the proper way, but isNotFound does not seem to work
if err != nil && strings.Contains(err.Error(), "bouncer not found") {
Expand All @@ -87,7 +89,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer {

logger.Infof("Creating bouncer %s", bouncerName)

bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType)
bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType)
if err != nil {
logger.Errorf("while creating bouncer db entry: %s", err)
return nil
Expand All @@ -112,9 +114,11 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer {
return nil
}

ctx := c.Request.Context()

hashStr := HashSHA512(val[0])

bouncer, err := a.DbClient.SelectBouncer(hashStr)
bouncer, err := a.DbClient.SelectBouncer(ctx, hashStr)
if err != nil {
logger.Errorf("while fetching bouncer info: %s", err)
return nil
Expand All @@ -132,6 +136,8 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
return func(c *gin.Context) {
var bouncer *ent.Bouncer

ctx := c.Request.Context()

clientIP := c.ClientIP()

logger := log.WithField("ip", clientIP)
Expand All @@ -153,7 +159,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
logger = logger.WithField("name", bouncer.Name)

if bouncer.IPAddress == "" {
if err := a.DbClient.UpdateBouncerIP(clientIP, bouncer.ID); err != nil {
if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil {
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
Expand All @@ -166,7 +172,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
if bouncer.IPAddress != clientIP && bouncer.IPAddress != "" && c.Request.Method != http.MethodHead {
log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, clientIP, bouncer.IPAddress)

if err := a.DbClient.UpdateBouncerIP(clientIP, bouncer.ID); err != nil {
if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil {
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
Expand All @@ -182,7 +188,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
}

if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] {
if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil {
if err := a.DbClient.UpdateBouncerTypeAndVersion(ctx, useragent[0], useragent[1], bouncer.ID); err != nil {
logger.Errorf("failed to update bouncer version and type: %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"})
c.Abort()
Expand Down
2 changes: 1 addition & 1 deletion pkg/apiserver/usage_metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func TestRCMetrics(t *testing.T) {
assert.Equal(t, tt.expectedStatusCode, w.Code)
assert.Contains(t, w.Body.String(), tt.expectedResponse)

bouncer, _ := dbClient.SelectBouncerByName("test")
bouncer, _ := dbClient.SelectBouncerByName(ctx, "test")
metrics, _ := dbClient.GetBouncerUsageMetricsByName(ctx, "test")

assert.Len(t, metrics, tt.expectedMetricsCount)
Expand Down
36 changes: 18 additions & 18 deletions pkg/database/bouncers.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,17 @@ func (c *Client) BouncerUpdateBaseMetrics(ctx context.Context, bouncerName strin
return nil
}

func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) {
result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(c.CTX)
func (c *Client) SelectBouncer(ctx context.Context, apiKeyHash string) (*ent.Bouncer, error) {
result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(ctx)
if err != nil {
return nil, err
}

return result, nil
}

func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) {
result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(c.CTX)
func (c *Client) SelectBouncerByName(ctx context.Context, bouncerName string) (*ent.Bouncer, error) {
result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(ctx)
if err != nil {
return nil, err
}
Expand All @@ -68,14 +68,14 @@ func (c *Client) ListBouncers(ctx context.Context) ([]*ent.Bouncer, error) {
return result, nil
}

func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) {
func (c *Client) CreateBouncer(ctx context.Context, name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) {
bouncer, err := c.Ent.Bouncer.
Create().
SetName(name).
SetAPIKey(apiKey).
SetRevoked(false).
SetAuthType(authType).
Save(c.CTX)
Save(ctx)
if err != nil {
if ent.IsConstraintError(err) {
return nil, fmt.Errorf("bouncer %s already exists", name)
Expand All @@ -87,11 +87,11 @@ func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string, authTy
return bouncer, nil
}

func (c *Client) DeleteBouncer(name string) error {
func (c *Client) DeleteBouncer(ctx context.Context, name string) error {
nbDeleted, err := c.Ent.Bouncer.
Delete().
Where(bouncer.NameEQ(name)).
Exec(c.CTX)
Exec(ctx)
if err != nil {
return err
}
Expand All @@ -103,50 +103,50 @@ func (c *Client) DeleteBouncer(name string) error {
return nil
}

func (c *Client) BulkDeleteBouncers(bouncers []*ent.Bouncer) (int, error) {
func (c *Client) BulkDeleteBouncers(ctx context.Context, bouncers []*ent.Bouncer) (int, error) {
ids := make([]int, len(bouncers))
for i, b := range bouncers {
ids[i] = b.ID
}

nbDeleted, err := c.Ent.Bouncer.Delete().Where(bouncer.IDIn(ids...)).Exec(c.CTX)
nbDeleted, err := c.Ent.Bouncer.Delete().Where(bouncer.IDIn(ids...)).Exec(ctx)
if err != nil {
return nbDeleted, fmt.Errorf("unable to delete bouncers: %w", err)
}

return nbDeleted, nil
}

func (c *Client) UpdateBouncerLastPull(lastPull time.Time, id int) error {
func (c *Client) UpdateBouncerLastPull(ctx context.Context, lastPull time.Time, id int) error {
_, err := c.Ent.Bouncer.UpdateOneID(id).
SetLastPull(lastPull).
Save(c.CTX)
Save(ctx)
if err != nil {
return fmt.Errorf("unable to update machine last pull in database: %w", err)
}

return nil
}

func (c *Client) UpdateBouncerIP(ipAddr string, id int) error {
_, err := c.Ent.Bouncer.UpdateOneID(id).SetIPAddress(ipAddr).Save(c.CTX)
func (c *Client) UpdateBouncerIP(ctx context.Context, ipAddr string, id int) error {
_, err := c.Ent.Bouncer.UpdateOneID(id).SetIPAddress(ipAddr).Save(ctx)
if err != nil {
return fmt.Errorf("unable to update bouncer ip address in database: %w", err)
}

return nil
}

func (c *Client) UpdateBouncerTypeAndVersion(bType string, version string, id int) error {
_, err := c.Ent.Bouncer.UpdateOneID(id).SetVersion(version).SetType(bType).Save(c.CTX)
func (c *Client) UpdateBouncerTypeAndVersion(ctx context.Context, bType string, version string, id int) error {
_, err := c.Ent.Bouncer.UpdateOneID(id).SetVersion(version).SetType(bType).Save(ctx)
if err != nil {
return fmt.Errorf("unable to update bouncer type and version in database: %w", err)
}

return nil
}

func (c *Client) QueryBouncersInactiveSince(t time.Time) ([]*ent.Bouncer, error) {
func (c *Client) QueryBouncersInactiveSince(ctx context.Context, t time.Time) ([]*ent.Bouncer, error) {
return c.Ent.Bouncer.Query().Where(
// poor man's coalesce
bouncer.Or(
Expand All @@ -156,5 +156,5 @@ func (c *Client) QueryBouncersInactiveSince(t time.Time) ([]*ent.Bouncer, error)
bouncer.CreatedAtLT(t),
),
),
).All(c.CTX)
).All(ctx)
}

0 comments on commit b3109de

Please sign in to comment.