diff --git a/cmd/crowdsec-cli/clibouncer/bouncers.go b/cmd/crowdsec-cli/clibouncer/bouncers.go index 0c0fc8851c9..89e91b63911 100644 --- a/cmd/crowdsec-cli/clibouncer/bouncers.go +++ b/cmd/crowdsec-cli/clibouncer/bouncers.go @@ -1,6 +1,7 @@ package clibouncer import ( + "context" "encoding/csv" "encoding/json" "errors" @@ -159,11 +160,11 @@ func (cli *cliBouncers) listCSV(out io.Writer, bouncers ent.Bouncers) error { return nil } -func (cli *cliBouncers) List(out io.Writer, db *database.Client) error { +func (cli *cliBouncers) List(ctx context.Context, out io.Writer, db *database.Client) error { // XXX: must use the provided db object, the one in the struct might be nil // (calling List directly skips the PersistentPreRunE) - bouncers, err := db.ListBouncers() + bouncers, err := db.ListBouncers(ctx) if err != nil { return fmt.Errorf("unable to list bouncers: %w", err) } @@ -199,8 +200,8 @@ func (cli *cliBouncers) newListCmd() *cobra.Command { Example: `cscli bouncers list`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.List(color.Output, cli.db) + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.List(cmd.Context(), color.Output, cli.db) }, } @@ -271,6 +272,7 @@ func (cli *cliBouncers) validBouncerID(cmd *cobra.Command, args []string, toComp var err error cfg := cli.cfg() + ctx := cmd.Context() // need to load config and db because PersistentPreRunE is not called for completions @@ -279,13 +281,13 @@ func (cli *cliBouncers) validBouncerID(cmd *cobra.Command, args []string, toComp return nil, cobra.ShellCompDirectiveNoFileComp } - cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig) + cli.db, err = require.DBClient(ctx, cfg.DbConfig) if err != nil { cobra.CompError("unable to list bouncers " + err.Error()) return nil, cobra.ShellCompDirectiveNoFileComp } - bouncers, err := cli.db.ListBouncers() + bouncers, err := cli.db.ListBouncers(ctx) if err != nil { cobra.CompError("unable to list bouncers " + err.Error()) return nil, cobra.ShellCompDirectiveNoFileComp diff --git a/cmd/crowdsec-cli/clisupport/support.go b/cmd/crowdsec-cli/clisupport/support.go index e9837b03fe7..7e41518805a 100644 --- a/cmd/crowdsec-cli/clisupport/support.go +++ b/cmd/crowdsec-cli/clisupport/support.go @@ -189,7 +189,7 @@ func (cli *cliSupport) dumpHubItems(zw *zip.Writer, hub *cwhub.Hub) error { return nil } -func (cli *cliSupport) dumpBouncers(zw *zip.Writer, db *database.Client) error { +func (cli *cliSupport) dumpBouncers(ctx context.Context, zw *zip.Writer, db *database.Client) error { log.Info("Collecting bouncers") if db == nil { @@ -199,7 +199,7 @@ func (cli *cliSupport) dumpBouncers(zw *zip.Writer, db *database.Client) error { out := new(bytes.Buffer) cb := clibouncer.New(cli.cfg) - if err := cb.List(out, db); err != nil { + if err := cb.List(ctx, out, db); err != nil { return err } @@ -525,7 +525,7 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error { log.Warnf("could not collect hub information: %s", err) } - if err = cli.dumpBouncers(zipWriter, db); err != nil { + if err = cli.dumpBouncers(ctx, zipWriter, db); err != nil { log.Warnf("could not collect bouncers information: %s", err) } diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index 5c6a550a6a0..91a0a8273f7 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -23,7 +23,7 @@ type dbPayload struct { Metrics []*models.DetailedMetrics `json:"metrics"` } -func (a *apic) GetUsageMetrics() (*models.AllMetrics, []int, error) { +func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int, error) { allMetrics := &models.AllMetrics{} metricsIds := make([]int, 0) @@ -32,7 +32,7 @@ func (a *apic) GetUsageMetrics() (*models.AllMetrics, []int, error) { return nil, nil, err } - bouncers, err := a.dbClient.ListBouncers() + bouncers, err := a.dbClient.ListBouncers(ctx) if err != nil { return nil, nil, err } @@ -185,7 +185,7 @@ func (a *apic) MarkUsageMetricsAsSent(ids []int) error { return a.dbClient.MarkUsageMetricsAsSent(ids) } -func (a *apic) GetMetrics() (*models.Metrics, error) { +func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) { machines, err := a.dbClient.ListMachines() if err != nil { return nil, err @@ -202,7 +202,7 @@ func (a *apic) GetMetrics() (*models.Metrics, error) { } } - bouncers, err := a.dbClient.ListBouncers() + bouncers, err := a.dbClient.ListBouncers(ctx) if err != nil { return nil, err } @@ -254,6 +254,8 @@ func (a *apic) fetchMachineIDs() ([]string, error) { func (a *apic) SendMetrics(stop chan (bool)) { defer trace.CatchPanic("lapi/metricsToAPIC") + ctx := context.TODO() + // verify the list of machines every interval const checkInt = 20 * time.Second @@ -311,7 +313,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { case <-metTicker.C: metTicker.Stop() - metrics, err := a.GetMetrics() + metrics, err := a.GetMetrics(ctx) if err != nil { log.Errorf("unable to get metrics (%s)", err) } @@ -340,6 +342,8 @@ func (a *apic) SendMetrics(stop chan (bool)) { func (a *apic) SendUsageMetrics() { defer trace.CatchPanic("lapi/usageMetricsToAPIC") + ctx := context.TODO() + firstRun := true log.Debugf("Start sending usage metrics to CrowdSec Central API (interval: %s once, then %s)", a.usageMetricsIntervalFirst, a.usageMetricsInterval) @@ -358,7 +362,7 @@ func (a *apic) SendUsageMetrics() { ticker.Reset(a.usageMetricsInterval) } - metrics, metricsId, err := a.GetUsageMetrics() + metrics, metricsId, err := a.GetUsageMetrics(ctx) if err != nil { log.Errorf("unable to get usage metrics: %s", err) continue diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 105d295dd0d..182bf18532f 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -292,9 +292,11 @@ func TestAPICHandleDeletedDecisions(t *testing.T) { } func TestAPICGetMetrics(t *testing.T) { + ctx := context.Background() + cleanUp := func(api *apic) { - api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background()) - api.dbClient.Ent.Machine.Delete().ExecX(context.Background()) + api.dbClient.Ent.Bouncer.Delete().ExecX(ctx) + api.dbClient.Ent.Machine.Delete().ExecX(ctx) } tests := []struct { name string @@ -377,7 +379,7 @@ func TestAPICGetMetrics(t *testing.T) { ExecX(context.Background()) } - foundMetrics, err := apiClient.GetMetrics() + foundMetrics, err := apiClient.GetMetrics(ctx) require.NoError(t, err) assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers) diff --git a/pkg/database/bouncers.go b/pkg/database/bouncers.go index a7378bbb203..6ff308ff786 100644 --- a/pkg/database/bouncers.go +++ b/pkg/database/bouncers.go @@ -59,8 +59,8 @@ func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { return result, nil } -func (c *Client) ListBouncers() ([]*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().All(c.CTX) +func (c *Client) ListBouncers(ctx context.Context) ([]*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "listing bouncers: %s", err) }