From 5a55058b0ad1dc2ea94f9e2865acca663754d0e8 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 13 Sep 2024 09:31:51 +0200 Subject: [PATCH 1/6] context propagation: explicit ctx parameter in unit tests This is done so we can later enable a context linter. --- pkg/acquisition/modules/loki/loki_test.go | 16 ++++++--- pkg/apiserver/alerts_test.go | 31 ++++++++++++------ pkg/apiserver/api_key_test.go | 9 +++-- pkg/apiserver/apiserver_test.go | 16 ++++++--- pkg/apiserver/jwt_test.go | 17 ++++++---- pkg/apiserver/machines_test.go | 40 +++++++++++++++-------- 6 files changed, 87 insertions(+), 42 deletions(-) diff --git a/pkg/acquisition/modules/loki/loki_test.go b/pkg/acquisition/modules/loki/loki_test.go index 5f41cd4c62e..ce86a1c36d9 100644 --- a/pkg/acquisition/modules/loki/loki_test.go +++ b/pkg/acquisition/modules/loki/loki_test.go @@ -261,7 +261,7 @@ func TestConfigureDSN(t *testing.T) { } } -func feedLoki(logger *log.Entry, n int, title string) error { +func feedLoki(ctx context.Context, logger *log.Entry, n int, title string) error { streams := LogStreams{ Streams: []LogStream{ { @@ -286,7 +286,7 @@ func feedLoki(logger *log.Entry, n int, title string) error { return err } - req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:3100/loki/api/v1/push", bytes.NewBuffer(buff)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://127.0.0.1:3100/loki/api/v1/push", bytes.NewBuffer(buff)) if err != nil { return err } @@ -349,7 +349,9 @@ since: 1h t.Fatalf("Unexpected error : %s", err) } - err = feedLoki(subLogger, 20, title) + ctx := context.Background() + + err = feedLoki(ctx, subLogger, 20, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } @@ -421,6 +423,8 @@ query: > }, } + ctx := context.Background() + for _, ts := range tests { t.Run(ts.name, func(t *testing.T) { logger := log.New() @@ -472,7 +476,7 @@ query: > } }) - err = feedLoki(subLogger, ts.expectedLines, title) + err = feedLoki(ctx, subLogger, ts.expectedLines, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } @@ -525,7 +529,9 @@ query: > time.Sleep(time.Second * 2) - err = feedLoki(subLogger, 1, title) + ctx := context.Background() + + err = feedLoki(ctx, subLogger, 1, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index 891eb3a8f4a..92067554d65 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "encoding/json" "fmt" "net/http" @@ -45,8 +46,9 @@ func (l *LAPI) InsertAlertFromFile(t *testing.T, path string) *httptest.Response } func (l *LAPI) RecordResponse(t *testing.T, verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder { + ctx := context.Background() w := httptest.NewRecorder() - req, err := http.NewRequest(verb, url, body) + req, err := http.NewRequestWithContext(ctx, verb, url, body) require.NoError(t, err) switch authType { @@ -74,8 +76,9 @@ func LoginToTestAPI(t *testing.T, router *gin.Engine, config csconfig.Config) mo body := CreateTestMachine(t, router, "") ValidateMachine(t, "test", config.API.Server.DbConfig) + ctx := context.Background() w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -355,9 +358,11 @@ func TestCreateAlertErrors(t *testing.T) { lapi := SetupLAPITest(t) alertContent := GetAlertReaderFromFile(t, "./tests/alert_sample.json") + ctx := context.Background() + //test invalid bearer w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/alerts", alertContent) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/alerts", alertContent) req.Header.Add("User-Agent", UserAgent) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "ratata")) lapi.router.ServeHTTP(w, req) @@ -365,7 +370,7 @@ func TestCreateAlertErrors(t *testing.T) { //test invalid bearer w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/alerts", alertContent) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/alerts", alertContent) req.Header.Add("User-Agent", UserAgent) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", lapi.loginResp.Token+"s")) lapi.router.ServeHTTP(w, req) @@ -376,9 +381,11 @@ func TestDeleteAlert(t *testing.T) { lapi := SetupLAPITest(t) lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + ctx := context.Background() + // Fail Delete Alert w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.2:4242" lapi.router.ServeHTTP(w, req) @@ -387,7 +394,7 @@ func TestDeleteAlert(t *testing.T) { // Delete Alert w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.1:4242" lapi.router.ServeHTTP(w, req) @@ -399,9 +406,11 @@ func TestDeleteAlertByID(t *testing.T) { lapi := SetupLAPITest(t) lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + ctx := context.Background() + // Fail Delete Alert w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.2:4242" lapi.router.ServeHTTP(w, req) @@ -410,7 +419,7 @@ func TestDeleteAlertByID(t *testing.T) { // Delete Alert w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.1:4242" lapi.router.ServeHTTP(w, req) @@ -439,9 +448,11 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { loginResp: loginResp, } + ctx := context.Background() + assertAlertDeleteFailedFromIP := func(ip string) { w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, loginResp) req.RemoteAddr = ip + ":1234" @@ -453,7 +464,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { assertAlertDeletedFromIP := func(ip string) { w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, loginResp) req.RemoteAddr = ip + ":1234" diff --git a/pkg/apiserver/api_key_test.go b/pkg/apiserver/api_key_test.go index 883ff21298d..10e75ae47f1 100644 --- a/pkg/apiserver/api_key_test.go +++ b/pkg/apiserver/api_key_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "net/http" "net/http/httptest" "strings" @@ -12,11 +13,13 @@ import ( func TestAPIKey(t *testing.T) { router, config := NewAPITest(t) + ctx := context.Background() + APIKey := CreateTestBouncer(t, config.API.Server.DbConfig) // Login with empty token w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -25,7 +28,7 @@ func TestAPIKey(t *testing.T) { // Login with invalid token w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Api-Key", "a1b2c3d4e5f6") router.ServeHTTP(w, req) @@ -35,7 +38,7 @@ func TestAPIKey(t *testing.T) { // Login with valid token w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Api-Key", APIKey) router.ServeHTTP(w, req) diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index f48791ebcb8..89c75f35d21 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -278,8 +278,10 @@ func CreateTestMachine(t *testing.T, router *gin.Engine, token string) string { body := string(b) + ctx := context.Background() + w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Set("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -323,8 +325,10 @@ func TestWithWrongFlushConfig(t *testing.T) { func TestUnknownPath(t *testing.T) { router, _ := NewAPITest(t) + ctx := context.Background() + w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/test", nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test", nil) req.Header.Set("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -380,8 +384,10 @@ func TestLoggingDebugToFileConfig(t *testing.T) { require.NoError(t, err) require.NotNil(t, api) + ctx := context.Background() + w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/test42", nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) api.router.ServeHTTP(w, req) assert.Equal(t, 404, w.Code) @@ -430,8 +436,10 @@ func TestLoggingErrorToFileConfig(t *testing.T) { require.NoError(t, err) require.NotNil(t, api) + ctx := context.Background() + w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/test42", nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) api.router.ServeHTTP(w, req) assert.Equal(t, http.StatusNotFound, w.Code) diff --git a/pkg/apiserver/jwt_test.go b/pkg/apiserver/jwt_test.go index aa6e84e416b..7ef010ae12b 100644 --- a/pkg/apiserver/jwt_test.go +++ b/pkg/apiserver/jwt_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "net/http" "net/http/httptest" "strings" @@ -12,11 +13,13 @@ import ( func TestLogin(t *testing.T) { router, config := NewAPITest(t) + ctx := context.Background() + body := CreateTestMachine(t, router, "") // Login with machine not validated yet w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -25,7 +28,7 @@ func TestLogin(t *testing.T) { // Login with machine not exist w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1", "password": "test1"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1", "password": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -34,7 +37,7 @@ func TestLogin(t *testing.T) { // Login with invalid body w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("test")) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader("test")) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -43,7 +46,7 @@ func TestLogin(t *testing.T) { // Login with invalid format w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -55,7 +58,7 @@ func TestLogin(t *testing.T) { // Login with invalid password w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test1"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -64,7 +67,7 @@ func TestLogin(t *testing.T) { // Login with valid machine w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -74,7 +77,7 @@ func TestLogin(t *testing.T) { // Login with valid machine + scenarios w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test", "scenarios": ["crowdsecurity/test", "crowdsecurity/test2"]}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test", "scenarios": ["crowdsecurity/test", "crowdsecurity/test2"]}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) diff --git a/pkg/apiserver/machines_test.go b/pkg/apiserver/machines_test.go index 041a6bee528..61677ca3dc4 100644 --- a/pkg/apiserver/machines_test.go +++ b/pkg/apiserver/machines_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -16,9 +17,11 @@ import ( func TestCreateMachine(t *testing.T) { router, _ := NewAPITest(t) + ctx := context.Background() + // Create machine with invalid format w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader("test")) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader("test")) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -27,7 +30,7 @@ func TestCreateMachine(t *testing.T) { // Create machine with invalid input w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(`{"test": "test"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(`{"test": "test"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -41,7 +44,7 @@ func TestCreateMachine(t *testing.T) { body := string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -52,6 +55,9 @@ func TestCreateMachine(t *testing.T) { func TestCreateMachineWithForwardedFor(t *testing.T) { router, config := NewAPITestForwardedFor(t) router.TrustedPlatform = "X-Real-IP" + + ctx := context.Background() + // Create machine b, err := json.Marshal(MachineTest) require.NoError(t, err) @@ -59,7 +65,7 @@ func TestCreateMachineWithForwardedFor(t *testing.T) { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Real-Ip", "1.1.1.1") router.ServeHTTP(w, req) @@ -75,6 +81,8 @@ func TestCreateMachineWithForwardedFor(t *testing.T) { func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { router, config := NewAPITest(t) + ctx := context.Background() + // Create machine b, err := json.Marshal(MachineTest) require.NoError(t, err) @@ -82,7 +90,7 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Real-IP", "1.1.1.1") router.ServeHTTP(w, req) @@ -100,6 +108,8 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { func TestCreateMachineWithoutForwardedFor(t *testing.T) { router, config := NewAPITestForwardedFor(t) + ctx := context.Background() + // Create machine b, err := json.Marshal(MachineTest) require.NoError(t, err) @@ -107,7 +117,7 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -124,15 +134,17 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) { func TestCreateMachineAlreadyExist(t *testing.T) { router, _ := NewAPITest(t) + ctx := context.Background() + body := CreateTestMachine(t, router, "") w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -143,6 +155,8 @@ func TestCreateMachineAlreadyExist(t *testing.T) { func TestAutoRegistration(t *testing.T) { router, _ := NewAPITest(t) + ctx := context.Background() + //Invalid registration token / valid source IP regReq := MachineTest regReq.RegistrationToken = invalidRegistrationToken @@ -152,7 +166,7 @@ func TestAutoRegistration(t *testing.T) { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.RemoteAddr = "127.0.0.1:4242" router.ServeHTTP(w, req) @@ -168,7 +182,7 @@ func TestAutoRegistration(t *testing.T) { body = string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.RemoteAddr = "42.42.42.42:4242" router.ServeHTTP(w, req) @@ -184,7 +198,7 @@ func TestAutoRegistration(t *testing.T) { body = string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.RemoteAddr = "42.42.42.42:4242" router.ServeHTTP(w, req) @@ -200,7 +214,7 @@ func TestAutoRegistration(t *testing.T) { body = string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.RemoteAddr = "127.0.0.1:4242" router.ServeHTTP(w, req) @@ -216,7 +230,7 @@ func TestAutoRegistration(t *testing.T) { body = string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.RemoteAddr = "127.0.0.1:4242" router.ServeHTTP(w, req) From de1346d4601fabe71316b91287fb39fe38962765 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 13 Sep 2024 10:30:53 +0200 Subject: [PATCH 2/6] context propagation: pass context to NewAPIC() --- cmd/crowdsec-cli/clipapi/papi.go | 12 +++++++----- pkg/apiserver/apic.go | 4 ++-- pkg/apiserver/apic_test.go | 4 +++- pkg/apiserver/apiserver.go | 2 +- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/cmd/crowdsec-cli/clipapi/papi.go b/cmd/crowdsec-cli/clipapi/papi.go index 747b8c01b9b..c0f08157f31 100644 --- a/cmd/crowdsec-cli/clipapi/papi.go +++ b/cmd/crowdsec-cli/clipapi/papi.go @@ -59,7 +59,7 @@ func (cli *cliPapi) NewCommand() *cobra.Command { func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Client) error { cfg := cli.cfg() - apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) + apic, err := apiserver.NewAPIC(ctx, cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) if err != nil { return fmt.Errorf("unable to initialize API client: %w", err) } @@ -118,11 +118,11 @@ func (cli *cliPapi) newStatusCmd() *cobra.Command { return cmd } -func (cli *cliPapi) sync(out io.Writer, db *database.Client) error { +func (cli *cliPapi) sync(ctx context.Context, out io.Writer, db *database.Client) error { cfg := cli.cfg() t := tomb.Tomb{} - apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) + apic, err := apiserver.NewAPIC(ctx, cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) if err != nil { return fmt.Errorf("unable to initialize API client: %w", err) } @@ -159,12 +159,14 @@ func (cli *cliPapi) newSyncCmd() *cobra.Command { DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, _ []string) error { cfg := cli.cfg() - db, err := require.DBClient(cmd.Context(), cfg.DbConfig) + ctx := cmd.Context() + + db, err := require.DBClient(ctx, cfg.DbConfig) if err != nil { return err } - return cli.sync(color.Output, db) + return cli.sync(ctx, color.Output, db) }, } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 73061637ad9..3ed2e12ea54 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -174,7 +174,7 @@ func alertToSignal(alert *models.Alert, scenarioTrust string, shareContext bool) return signal } -func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) { +func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) { var err error ret := &apic{ @@ -237,7 +237,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con return ret, fmt.Errorf("get scenario in db: %w", err) } - authResp, _, err := ret.apiClient.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + authResp, _, err := ret.apiClient.Auth.AuthenticateWatcher(ctx, models.WatcherAuthRequest{ MachineID: &config.Credentials.Login, Password: &password, Scenarios: scenarios, diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 51887006ad4..328d5c4ae09 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -230,6 +230,8 @@ func TestNewAPIC(t *testing.T) { }, } + ctx := context.Background() + for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { setConfig() @@ -246,7 +248,7 @@ func TestNewAPIC(t *testing.T) { ), )) tc.action() - _, err := NewAPIC(testConfig, tc.args.dbClient, tc.args.consoleConfig, nil) + _, err := NewAPIC(ctx, testConfig, tc.args.dbClient, tc.args.consoleConfig, nil) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 42dcb219379..8bf406e0a79 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -249,7 +249,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { if config.OnlineClient != nil && config.OnlineClient.Credentials != nil { log.Printf("Loading CAPI manager") - apiClient, err = NewAPIC(config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists) + apiClient, err = NewAPIC(ctx, config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists) if err != nil { return nil, err } From 6bfd6f1c09db165b5c1c2018a0566a24e34faf55 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 13 Sep 2024 10:45:41 +0200 Subject: [PATCH 3/6] context propagation: drop field S3Source.ctx, pass explicitly --- pkg/acquisition/modules/s3/s3.go | 41 +++++++++++++++++--------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/pkg/acquisition/modules/s3/s3.go b/pkg/acquisition/modules/s3/s3.go index 9ef4d2ba757..f4aee37d859 100644 --- a/pkg/acquisition/modules/s3/s3.go +++ b/pkg/acquisition/modules/s3/s3.go @@ -55,7 +55,6 @@ type S3Source struct { readerChan chan S3Object t *tomb.Tomb out chan types.Event - ctx aws.Context cancel context.CancelFunc } @@ -182,7 +181,7 @@ func (s *S3Source) newSQSClient() error { return nil } -func (s *S3Source) readManager() { +func (s *S3Source) readManager(ctx context.Context) { logger := s.logger.WithField("method", "readManager") for { select { @@ -192,7 +191,7 @@ func (s *S3Source) readManager() { return case s3Object := <-s.readerChan: logger.Debugf("Reading file %s/%s", s3Object.Bucket, s3Object.Key) - err := s.readFile(s3Object.Bucket, s3Object.Key) + err := s.readFile(ctx, s3Object.Bucket, s3Object.Key) if err != nil { logger.Errorf("Error while reading file: %s", err) } @@ -200,13 +199,13 @@ func (s *S3Source) readManager() { } } -func (s *S3Source) getBucketContent() ([]*s3.Object, error) { +func (s *S3Source) getBucketContent(ctx context.Context) ([]*s3.Object, error) { logger := s.logger.WithField("method", "getBucketContent") logger.Debugf("Getting bucket content for %s", s.Config.BucketName) bucketObjects := make([]*s3.Object, 0) var continuationToken *string for { - out, err := s.s3Client.ListObjectsV2WithContext(s.ctx, &s3.ListObjectsV2Input{ + out, err := s.s3Client.ListObjectsV2WithContext(ctx, &s3.ListObjectsV2Input{ Bucket: aws.String(s.Config.BucketName), Prefix: aws.String(s.Config.Prefix), ContinuationToken: continuationToken, @@ -227,7 +226,7 @@ func (s *S3Source) getBucketContent() ([]*s3.Object, error) { return bucketObjects, nil } -func (s *S3Source) listPoll() error { +func (s *S3Source) listPoll(ctx context.Context) error { logger := s.logger.WithField("method", "listPoll") ticker := time.NewTicker(time.Duration(s.Config.PollingInterval) * time.Second) lastObjectDate := time.Now() @@ -241,7 +240,7 @@ func (s *S3Source) listPoll() error { return nil case <-ticker.C: newObject := false - bucketObjects, err := s.getBucketContent() + bucketObjects, err := s.getBucketContent(ctx) if err != nil { logger.Errorf("Error while getting bucket content: %s", err) continue @@ -323,7 +322,7 @@ func (s *S3Source) extractBucketAndPrefix(message *string) (string, string, erro } } -func (s *S3Source) sqsPoll() error { +func (s *S3Source) sqsPoll(ctx context.Context) error { logger := s.logger.WithField("method", "sqsPoll") for { select { @@ -333,7 +332,7 @@ func (s *S3Source) sqsPoll() error { return nil default: logger.Trace("Polling SQS queue") - out, err := s.sqsClient.ReceiveMessageWithContext(s.ctx, &sqs.ReceiveMessageInput{ + out, err := s.sqsClient.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ QueueUrl: aws.String(s.Config.SQSName), MaxNumberOfMessages: aws.Int64(10), WaitTimeSeconds: aws.Int64(20), //Probably no need to make it configurable ? @@ -376,7 +375,7 @@ func (s *S3Source) sqsPoll() error { } } -func (s *S3Source) readFile(bucket string, key string) error { +func (s *S3Source) readFile(ctx context.Context, bucket string, key string) error { //TODO: Handle SSE-C var scanner *bufio.Scanner @@ -386,7 +385,7 @@ func (s *S3Source) readFile(bucket string, key string) error { "key": key, }) - output, err := s.s3Client.GetObjectWithContext(s.ctx, &s3.GetObjectInput{ + output, err := s.s3Client.GetObjectWithContext(ctx, &s3.GetObjectInput{ Bucket: aws.String(bucket), Key: aws.String(key), }) @@ -642,24 +641,26 @@ func (s *S3Source) GetName() string { } func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { + var ctx context.Context + s.logger.Infof("starting acquisition of %s/%s/%s", s.Config.BucketName, s.Config.Prefix, s.Config.Key) s.out = out - s.ctx, s.cancel = context.WithCancel(context.Background()) + ctx, s.cancel = context.WithCancel(context.Background()) s.Config.UseTimeMachine = true s.t = t if s.Config.Key != "" { - err := s.readFile(s.Config.BucketName, s.Config.Key) + err := s.readFile(ctx, s.Config.BucketName, s.Config.Key) if err != nil { return err } } else { //No key, get everything in the bucket based on the prefix - objects, err := s.getBucketContent() + objects, err := s.getBucketContent(ctx) if err != nil { return err } for _, object := range objects { - err := s.readFile(s.Config.BucketName, *object.Key) + err := s.readFile(ctx, s.Config.BucketName, *object.Key) if err != nil { return err } @@ -670,18 +671,20 @@ func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error } func (s *S3Source) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { + var ctx context.Context + s.t = t s.out = out s.readerChan = make(chan S3Object, 100) //FIXME: does this needs to be buffered? - s.ctx, s.cancel = context.WithCancel(context.Background()) + ctx, s.cancel = context.WithCancel(context.Background()) s.logger.Infof("starting acquisition of %s/%s", s.Config.BucketName, s.Config.Prefix) t.Go(func() error { - s.readManager() + s.readManager(ctx) return nil }) if s.Config.PollingMethod == PollMethodSQS { t.Go(func() error { - err := s.sqsPoll() + err := s.sqsPoll(ctx) if err != nil { return err } @@ -689,7 +692,7 @@ func (s *S3Source) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) erro }) } else { t.Go(func() error { - err := s.listPoll() + err := s.listPoll(ctx) if err != nil { return err } From 5273f66eee33e79f0a2824b6fbfd55f22eb48f44 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 13 Sep 2024 13:16:53 +0200 Subject: [PATCH 4/6] context propagation: pkg/database/flush --- cmd/crowdsec-cli/clialert/alerts.go | 6 ++-- pkg/apiserver/apiserver.go | 2 +- pkg/database/flush.go | 43 +++++++++++++++-------------- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/cmd/crowdsec-cli/clialert/alerts.go b/cmd/crowdsec-cli/clialert/alerts.go index 757a84927e5..0bd0c0c1574 100644 --- a/cmd/crowdsec-cli/clialert/alerts.go +++ b/cmd/crowdsec-cli/clialert/alerts.go @@ -575,15 +575,17 @@ func (cli *cliAlerts) newFlushCmd() *cobra.Command { DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, _ []string) error { cfg := cli.cfg() + ctx := cmd.Context() + if err := require.LAPI(cfg); err != nil { return err } - db, err := require.DBClient(cmd.Context(), cfg.DbConfig) + db, err := require.DBClient(ctx, cfg.DbConfig) if err != nil { return err } log.Info("Flushing alerts. !! This may take a long time !!") - err = db.FlushAlerts(maxAge, maxItems) + err = db.FlushAlerts(ctx, maxAge, maxItems) if err != nil { return fmt.Errorf("unable to flush alerts: %w", err) } diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 42dcb219379..c73e854fc11 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -170,7 +170,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { } if config.DbConfig.Flush != nil { - flushScheduler, err = dbClient.StartFlushScheduler(config.DbConfig.Flush) + flushScheduler, err = dbClient.StartFlushScheduler(ctx, config.DbConfig.Flush) if err != nil { return nil, err } diff --git a/pkg/database/flush.go b/pkg/database/flush.go index 5d53d10c942..46c8edfa308 100644 --- a/pkg/database/flush.go +++ b/pkg/database/flush.go @@ -1,6 +1,7 @@ package database import ( + "context" "errors" "fmt" "time" @@ -26,7 +27,7 @@ const ( flushInterval = 1 * time.Minute ) -func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) { +func (c *Client) StartFlushScheduler(ctx context.Context, config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) { maxItems := 0 maxAge := "" @@ -45,7 +46,7 @@ func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Sched // Init & Start cronjob every minute for alerts scheduler := gocron.NewScheduler(time.UTC) - job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, maxAge, maxItems) + job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, ctx, maxAge, maxItems) if err != nil { return nil, fmt.Errorf("while starting FlushAlerts scheduler: %w", err) } @@ -100,14 +101,14 @@ func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Sched } } - baJob, err := scheduler.Every(flushInterval).Do(c.FlushAgentsAndBouncers, config.AgentsGC, config.BouncersGC) + baJob, err := scheduler.Every(flushInterval).Do(c.FlushAgentsAndBouncers, ctx, config.AgentsGC, config.BouncersGC) if err != nil { return nil, fmt.Errorf("while starting FlushAgentsAndBouncers scheduler: %w", err) } baJob.SingletonMode() - metricsJob, err := scheduler.Every(flushInterval).Do(c.flushMetrics, config.MetricsMaxAge) + metricsJob, err := scheduler.Every(flushInterval).Do(c.flushMetrics, ctx, config.MetricsMaxAge) if err != nil { return nil, fmt.Errorf("while starting flushMetrics scheduler: %w", err) } @@ -120,7 +121,7 @@ func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Sched } // flushMetrics deletes metrics older than maxAge, regardless if they have been pushed to CAPI or not -func (c *Client) flushMetrics(maxAge *time.Duration) { +func (c *Client) flushMetrics(ctx context.Context, maxAge *time.Duration) { if maxAge == nil { maxAge = ptr.Of(defaultMetricsMaxAge) } @@ -129,7 +130,7 @@ func (c *Client) flushMetrics(maxAge *time.Duration) { deleted, err := c.Ent.Metric.Delete().Where( metric.ReceivedAtLTE(time.Now().UTC().Add(-*maxAge)), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { c.Log.Errorf("while flushing metrics: %s", err) return @@ -140,10 +141,10 @@ func (c *Client) flushMetrics(maxAge *time.Duration) { } } -func (c *Client) FlushOrphans() { +func (c *Client) FlushOrphans(ctx context.Context) { /* While it has only been linked to some very corner-case bug : https://github.com/crowdsecurity/crowdsec/issues/778 */ /* We want to take care of orphaned events for which the parent alert/decision has been deleted */ - eventsCount, err := c.Ent.Event.Delete().Where(event.Not(event.HasOwner())).Exec(c.CTX) + eventsCount, err := c.Ent.Event.Delete().Where(event.Not(event.HasOwner())).Exec(ctx) if err != nil { c.Log.Warningf("error while deleting orphan events: %s", err) return @@ -154,7 +155,7 @@ func (c *Client) FlushOrphans() { } eventsCount, err = c.Ent.Decision.Delete().Where( - decision.Not(decision.HasOwner())).Where(decision.UntilLTE(time.Now().UTC())).Exec(c.CTX) + decision.Not(decision.HasOwner())).Where(decision.UntilLTE(time.Now().UTC())).Exec(ctx) if err != nil { c.Log.Warningf("error while deleting orphan decisions: %s", err) return @@ -165,7 +166,7 @@ func (c *Client) FlushOrphans() { } } -func (c *Client) flushBouncers(authType string, duration *time.Duration) { +func (c *Client) flushBouncers(ctx context.Context, authType string, duration *time.Duration) { if duration == nil { return } @@ -174,7 +175,7 @@ func (c *Client) flushBouncers(authType string, duration *time.Duration) { bouncer.LastPullLTE(time.Now().UTC().Add(-*duration)), ).Where( bouncer.AuthTypeEQ(authType), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { c.Log.Errorf("while auto-deleting expired bouncers (%s): %s", authType, err) return @@ -185,7 +186,7 @@ func (c *Client) flushBouncers(authType string, duration *time.Duration) { } } -func (c *Client) flushAgents(authType string, duration *time.Duration) { +func (c *Client) flushAgents(ctx context.Context, authType string, duration *time.Duration) { if duration == nil { return } @@ -194,7 +195,7 @@ func (c *Client) flushAgents(authType string, duration *time.Duration) { machine.LastHeartbeatLTE(time.Now().UTC().Add(-*duration)), machine.Not(machine.HasAlerts()), machine.AuthTypeEQ(authType), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { c.Log.Errorf("while auto-deleting expired machines (%s): %s", authType, err) return @@ -205,23 +206,23 @@ func (c *Client) flushAgents(authType string, duration *time.Duration) { } } -func (c *Client) FlushAgentsAndBouncers(agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error { +func (c *Client) FlushAgentsAndBouncers(ctx context.Context, agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error { log.Debug("starting FlushAgentsAndBouncers") if agentsCfg != nil { - c.flushAgents(types.TlsAuthType, agentsCfg.CertDuration) - c.flushAgents(types.PasswordAuthType, agentsCfg.LoginPasswordDuration) + c.flushAgents(ctx, types.TlsAuthType, agentsCfg.CertDuration) + c.flushAgents(ctx, types.PasswordAuthType, agentsCfg.LoginPasswordDuration) } if bouncersCfg != nil { - c.flushBouncers(types.TlsAuthType, bouncersCfg.CertDuration) - c.flushBouncers(types.ApiKeyAuthType, bouncersCfg.ApiDuration) + c.flushBouncers(ctx, types.TlsAuthType, bouncersCfg.CertDuration) + c.flushBouncers(ctx, types.ApiKeyAuthType, bouncersCfg.ApiDuration) } return nil } -func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { +func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) error { var ( deletedByAge int deletedByNbItem int @@ -235,7 +236,7 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { } c.Log.Debug("Flushing orphan alerts") - c.FlushOrphans() + c.FlushOrphans(ctx) c.Log.Debug("Done flushing orphan alerts") totalAlerts, err = c.TotalAlerts() @@ -287,7 +288,7 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { if maxid > 0 { // This may lead to orphan alerts (at least on MySQL), but the next time the flush job will run, they will be deleted - deletedByNbItem, err = c.Ent.Alert.Delete().Where(alert.IDLT(maxid)).Exec(c.CTX) + deletedByNbItem, err = c.Ent.Alert.Delete().Where(alert.IDLT(maxid)).Exec(ctx) if err != nil { c.Log.Errorf("FlushAlerts: Could not delete alerts: %s", err) return fmt.Errorf("could not delete alerts: %w", err) From c0c2b8b066f777d0c398e94231f54da8f89a3ab4 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 13 Sep 2024 13:28:40 +0200 Subject: [PATCH 5/6] context propagation: bouncer list --- cmd/crowdsec-cli/clibouncer/bouncers.go | 14 ++++++++------ cmd/crowdsec-cli/clisupport/support.go | 6 +++--- pkg/apiserver/apic_metrics.go | 16 ++++++++++------ pkg/apiserver/apic_test.go | 8 +++++--- pkg/database/bouncers.go | 4 ++-- 5 files changed, 28 insertions(+), 20 deletions(-) diff --git a/cmd/crowdsec-cli/clibouncer/bouncers.go b/cmd/crowdsec-cli/clibouncer/bouncers.go index 0d1484bcc6b..177ed6c9fb6 100644 --- a/cmd/crowdsec-cli/clibouncer/bouncers.go +++ b/cmd/crowdsec-cli/clibouncer/bouncers.go @@ -1,6 +1,7 @@ package clibouncer import ( + "context" "encoding/csv" "encoding/json" "errors" @@ -159,11 +160,11 @@ func (cli *cliBouncers) listCSV(out io.Writer, bouncers ent.Bouncers) error { return nil } -func (cli *cliBouncers) List(out io.Writer, db *database.Client) error { +func (cli *cliBouncers) List(ctx context.Context, out io.Writer, db *database.Client) error { // XXX: must use the provided db object, the one in the struct might be nil // (calling List directly skips the PersistentPreRunE) - bouncers, err := db.ListBouncers() + bouncers, err := db.ListBouncers(ctx) if err != nil { return fmt.Errorf("unable to list bouncers: %w", err) } @@ -199,8 +200,8 @@ func (cli *cliBouncers) newListCmd() *cobra.Command { Example: `cscli bouncers list`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.List(color.Output, cli.db) + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.List(cmd.Context(), color.Output, cli.db) }, } @@ -271,6 +272,7 @@ func (cli *cliBouncers) validBouncerID(cmd *cobra.Command, args []string, toComp var err error cfg := cli.cfg() + ctx := cmd.Context() // need to load config and db because PersistentPreRunE is not called for completions @@ -279,13 +281,13 @@ func (cli *cliBouncers) validBouncerID(cmd *cobra.Command, args []string, toComp return nil, cobra.ShellCompDirectiveNoFileComp } - cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig) + cli.db, err = require.DBClient(ctx, cfg.DbConfig) if err != nil { cobra.CompError("unable to list bouncers " + err.Error()) return nil, cobra.ShellCompDirectiveNoFileComp } - bouncers, err := cli.db.ListBouncers() + bouncers, err := cli.db.ListBouncers(ctx) if err != nil { cobra.CompError("unable to list bouncers " + err.Error()) return nil, cobra.ShellCompDirectiveNoFileComp diff --git a/cmd/crowdsec-cli/clisupport/support.go b/cmd/crowdsec-cli/clisupport/support.go index e9837b03fe7..7e41518805a 100644 --- a/cmd/crowdsec-cli/clisupport/support.go +++ b/cmd/crowdsec-cli/clisupport/support.go @@ -189,7 +189,7 @@ func (cli *cliSupport) dumpHubItems(zw *zip.Writer, hub *cwhub.Hub) error { return nil } -func (cli *cliSupport) dumpBouncers(zw *zip.Writer, db *database.Client) error { +func (cli *cliSupport) dumpBouncers(ctx context.Context, zw *zip.Writer, db *database.Client) error { log.Info("Collecting bouncers") if db == nil { @@ -199,7 +199,7 @@ func (cli *cliSupport) dumpBouncers(zw *zip.Writer, db *database.Client) error { out := new(bytes.Buffer) cb := clibouncer.New(cli.cfg) - if err := cb.List(out, db); err != nil { + if err := cb.List(ctx, out, db); err != nil { return err } @@ -525,7 +525,7 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error { log.Warnf("could not collect hub information: %s", err) } - if err = cli.dumpBouncers(zipWriter, db); err != nil { + if err = cli.dumpBouncers(ctx, zipWriter, db); err != nil { log.Warnf("could not collect bouncers information: %s", err) } diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index 176984f1ad6..380690379a6 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -23,7 +23,7 @@ type dbPayload struct { Metrics []*models.DetailedMetrics `json:"metrics"` } -func (a *apic) GetUsageMetrics() (*models.AllMetrics, []int, error) { +func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int, error) { allMetrics := &models.AllMetrics{} metricsIds := make([]int, 0) @@ -32,7 +32,7 @@ func (a *apic) GetUsageMetrics() (*models.AllMetrics, []int, error) { return nil, nil, err } - bouncers, err := a.dbClient.ListBouncers() + bouncers, err := a.dbClient.ListBouncers(ctx) if err != nil { return nil, nil, err } @@ -185,7 +185,7 @@ func (a *apic) MarkUsageMetricsAsSent(ids []int) error { return a.dbClient.MarkUsageMetricsAsSent(ids) } -func (a *apic) GetMetrics() (*models.Metrics, error) { +func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) { machines, err := a.dbClient.ListMachines() if err != nil { return nil, err @@ -202,7 +202,7 @@ func (a *apic) GetMetrics() (*models.Metrics, error) { } } - bouncers, err := a.dbClient.ListBouncers() + bouncers, err := a.dbClient.ListBouncers(ctx) if err != nil { return nil, err } @@ -254,6 +254,8 @@ func (a *apic) fetchMachineIDs() ([]string, error) { func (a *apic) SendMetrics(stop chan (bool)) { defer trace.CatchPanic("lapi/metricsToAPIC") + ctx := context.TODO() + // verify the list of machines every interval const checkInt = 20 * time.Second @@ -311,7 +313,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { case <-metTicker.C: metTicker.Stop() - metrics, err := a.GetMetrics() + metrics, err := a.GetMetrics(ctx) if err != nil { log.Errorf("unable to get metrics (%s)", err) } @@ -340,6 +342,8 @@ func (a *apic) SendMetrics(stop chan (bool)) { func (a *apic) SendUsageMetrics() { defer trace.CatchPanic("lapi/usageMetricsToAPIC") + ctx := context.TODO() + firstRun := true log.Debugf("Start sending usage metrics to CrowdSec Central API (interval: %s once, then %s)", a.usageMetricsIntervalFirst, a.usageMetricsInterval) @@ -358,7 +362,7 @@ func (a *apic) SendUsageMetrics() { ticker.Reset(a.usageMetricsInterval) } - metrics, metricsId, err := a.GetUsageMetrics() + metrics, metricsId, err := a.GetUsageMetrics(ctx) if err != nil { log.Errorf("unable to get usage metrics: %s", err) continue diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 51887006ad4..65fba1f3e15 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -290,9 +290,11 @@ func TestAPICHandleDeletedDecisions(t *testing.T) { } func TestAPICGetMetrics(t *testing.T) { + ctx := context.Background() + cleanUp := func(api *apic) { - api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background()) - api.dbClient.Ent.Machine.Delete().ExecX(context.Background()) + api.dbClient.Ent.Bouncer.Delete().ExecX(ctx) + api.dbClient.Ent.Machine.Delete().ExecX(ctx) } tests := []struct { name string @@ -375,7 +377,7 @@ func TestAPICGetMetrics(t *testing.T) { ExecX(context.Background()) } - foundMetrics, err := apiClient.GetMetrics() + foundMetrics, err := apiClient.GetMetrics(ctx) require.NoError(t, err) assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers) diff --git a/pkg/database/bouncers.go b/pkg/database/bouncers.go index a7378bbb203..6ff308ff786 100644 --- a/pkg/database/bouncers.go +++ b/pkg/database/bouncers.go @@ -59,8 +59,8 @@ func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { return result, nil } -func (c *Client) ListBouncers() ([]*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().All(c.CTX) +func (c *Client) ListBouncers(ctx context.Context) ([]*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "listing bouncers: %s", err) } From c50fb39e0605ec5b0d5d52b3629874bd7129dd5d Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 10:35:06 +0200 Subject: [PATCH 6/6] context propagation: pkg/database/config --- cmd/crowdsec-cli/clipapi/papi.go | 2 +- pkg/apiserver/apic.go | 28 ++++++++++++++-------------- pkg/apiserver/apic_test.go | 24 +++++++++++++++--------- pkg/apiserver/apiserver.go | 18 ++++++++++-------- pkg/apiserver/papi.go | 8 ++++---- pkg/apiserver/papi_cmd.go | 7 +++++-- pkg/database/config.go | 12 ++++++------ 7 files changed, 55 insertions(+), 44 deletions(-) diff --git a/cmd/crowdsec-cli/clipapi/papi.go b/cmd/crowdsec-cli/clipapi/papi.go index c0f08157f31..b8101a0fb34 100644 --- a/cmd/crowdsec-cli/clipapi/papi.go +++ b/cmd/crowdsec-cli/clipapi/papi.go @@ -74,7 +74,7 @@ func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Clie return fmt.Errorf("unable to get PAPI permissions: %w", err) } - lastTimestampStr, err := db.GetConfigItem(apiserver.PapiPullKey) + lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey) if err != nil { lastTimestampStr = ptr.Of("never") } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 3ed2e12ea54..b5384c6cc5c 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -614,7 +614,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio // we receive a list of decisions and links for blocklist and we need to create a list of alerts : // one alert for "community blocklist" // one alert per list we're subscribed to -func (a *apic) PullTop(forcePull bool) error { +func (a *apic) PullTop(ctx context.Context, forcePull bool) error { var err error // A mutex with TryLock would be a bit simpler @@ -655,7 +655,7 @@ func (a *apic) PullTop(forcePull bool) error { log.Infof("Starting community-blocklist update") - data, _, err := a.apiClient.Decisions.GetStreamV3(context.Background(), apiclient.DecisionsStreamOpts{Startup: a.startup}) + data, _, err := a.apiClient.Decisions.GetStreamV3(ctx, apiclient.DecisionsStreamOpts{Startup: a.startup}) if err != nil { return fmt.Errorf("get stream: %w", err) } @@ -700,7 +700,7 @@ func (a *apic) PullTop(forcePull bool) error { } // update blocklists - if err := a.UpdateBlocklists(data.Links, addCounters, forcePull); err != nil { + if err := a.UpdateBlocklists(ctx, data.Links, addCounters, forcePull); err != nil { return fmt.Errorf("while updating blocklists: %w", err) } @@ -708,9 +708,9 @@ func (a *apic) PullTop(forcePull bool) error { } // we receive a link to a blocklist, we pull the content of the blocklist and we create one alert -func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error { +func (a *apic) PullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, forcePull bool) error { addCounters, _ := makeAddAndDeleteCounters() - if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{ + if err := a.UpdateBlocklists(ctx, &modelscapi.GetDecisionsStreamResponseLinks{ Blocklists: []*modelscapi.BlocklistLink{blocklist}, }, addCounters, forcePull); err != nil { return fmt.Errorf("while pulling blocklist: %w", err) @@ -820,7 +820,7 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo return false, nil } -func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error { +func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error { if blocklist.Scope == nil { log.Warningf("blocklist has no scope") return nil @@ -848,13 +848,13 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap ) if !forcePull { - lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName) + lastPullTimestamp, err = a.dbClient.GetConfigItem(ctx, blocklistConfigItemName) if err != nil { return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) } } - decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp) + decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(ctx, blocklist, lastPullTimestamp) if err != nil { return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err) } @@ -869,7 +869,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap return nil } - err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) + err = a.dbClient.SetConfigItem(ctx, blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) if err != nil { return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) } @@ -892,7 +892,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap return nil } -func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error { +func (a *apic) UpdateBlocklists(ctx context.Context, links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error { if links == nil { return nil } @@ -908,7 +908,7 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink } for _, blocklist := range links.Blocklists { - if err := a.updateBlocklist(defaultClient, blocklist, addCounters, forcePull); err != nil { + if err := a.updateBlocklist(ctx, defaultClient, blocklist, addCounters, forcePull); err != nil { return err } } @@ -931,7 +931,7 @@ func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int } } -func (a *apic) Pull() error { +func (a *apic) Pull(ctx context.Context) error { defer trace.CatchPanic("lapi/pullFromAPIC") toldOnce := false @@ -955,7 +955,7 @@ func (a *apic) Pull() error { time.Sleep(1 * time.Second) } - if err := a.PullTop(false); err != nil { + if err := a.PullTop(ctx, false); err != nil { log.Errorf("capi pull top: %s", err) } @@ -967,7 +967,7 @@ func (a *apic) Pull() error { case <-ticker.C: ticker.Reset(a.pullInterval) - if err := a.PullTop(false); err != nil { + if err := a.PullTop(ctx, false); err != nil { log.Errorf("capi pull top: %s", err) continue } diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 182bf18532f..97943b495e5 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -550,6 +550,7 @@ func TestFillAlertsWithDecisions(t *testing.T) { } func TestAPICWhitelists(t *testing.T) { + ctx := context.Background() api := getAPIC(t) // one whitelist on IP, one on CIDR api.whitelists = &csconfig.CapiWhitelist{} @@ -685,7 +686,7 @@ func TestAPICWhitelists(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) assertTotalDecisionCount(t, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing @@ -736,6 +737,7 @@ func TestAPICWhitelists(t *testing.T) { } func TestAPICPullTop(t *testing.T) { + ctx := context.Background() api := getAPIC(t) api.dbClient.Ent.Decision.Create(). SetOrigin(types.CAPIOrigin). @@ -826,7 +828,7 @@ func TestAPICPullTop(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) assertTotalDecisionCount(t, api.dbClient, 5) @@ -860,6 +862,7 @@ func TestAPICPullTop(t *testing.T) { } func TestAPICPullTopBLCacheFirstCall(t *testing.T) { + ctx := context.Background() // no decision in db, no last modified parameter. api := getAPIC(t) @@ -913,11 +916,11 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) blocklistConfigItemName := "blocklist:blocklist1:last_pull" - lastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + lastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName) require.NoError(t, err) assert.NotEqual(t, "", *lastPullTimestamp) @@ -927,14 +930,15 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { return httpmock.NewStringResponse(304, ""), nil }) - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) - secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + secondLastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName) require.NoError(t, err) assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp) } func TestAPICPullTopBLCacheForceCall(t *testing.T) { + ctx := context.Background() api := getAPIC(t) httpmock.Activate() @@ -1005,11 +1009,12 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) } func TestAPICPullBlocklistCall(t *testing.T) { + ctx := context.Background() api := getAPIC(t) httpmock.Activate() @@ -1032,7 +1037,7 @@ func TestAPICPullBlocklistCall(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullBlocklist(&modelscapi.BlocklistLink{ + err = api.PullBlocklist(ctx, &modelscapi.BlocklistLink{ URL: ptr.Of("http://api.crowdsec.net/blocklist1"), Name: ptr.Of("blocklist1"), Scope: ptr.Of("Ip"), @@ -1134,6 +1139,7 @@ func TestAPICPush(t *testing.T) { } func TestAPICPull(t *testing.T) { + ctx := context.Background() api := getAPIC(t) tests := []struct { name string @@ -1204,7 +1210,7 @@ func TestAPICPull(t *testing.T) { go func() { logrus.SetOutput(&buf) - if err := api.Pull(); err != nil { + if err := api.Pull(ctx); err != nil { panic(err) } }() diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 95d18ccb028..6b5d6803be9 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -310,8 +310,8 @@ func (s *APIServer) apicPush() error { return nil } -func (s *APIServer) apicPull() error { - if err := s.apic.Pull(); err != nil { +func (s *APIServer) apicPull(ctx context.Context) error { + if err := s.apic.Pull(ctx); err != nil { log.Errorf("capi pull: %s", err) return err } @@ -319,8 +319,8 @@ func (s *APIServer) apicPull() error { return nil } -func (s *APIServer) papiPull() error { - if err := s.papi.Pull(); err != nil { +func (s *APIServer) papiPull(ctx context.Context) error { + if err := s.papi.Pull(ctx); err != nil { log.Errorf("papi pull: %s", err) return err } @@ -337,16 +337,16 @@ func (s *APIServer) papiSync() error { return nil } -func (s *APIServer) initAPIC() { +func (s *APIServer) initAPIC(ctx context.Context) { s.apic.pushTomb.Go(s.apicPush) - s.apic.pullTomb.Go(s.apicPull) + s.apic.pullTomb.Go(func() error { return s.apicPull(ctx) }) // csConfig.API.Server.ConsoleConfig.ShareCustomScenarios if s.apic.apiClient.IsEnrolled() { if s.consoleConfig.IsPAPIEnabled() { if s.papi.URL != "" { log.Info("Starting PAPI decision receiver") - s.papi.pullTomb.Go(s.papiPull) + s.papi.pullTomb.Go(func() error { return s.papiPull(ctx) }) s.papi.syncTomb.Go(s.papiSync) } else { log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. Run cscli console enable console_management to add it.") @@ -381,8 +381,10 @@ func (s *APIServer) Run(apiReady chan bool) error { TLSConfig: tlsCfg, } + ctx := context.TODO() + if s.apic != nil { - s.initAPIC() + s.initAPIC(ctx) } s.httpServerTomb.Go(func() error { diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index 89ad93930a1..7dd6b346aa9 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -230,13 +230,13 @@ func (p *Papi) PullOnce(since time.Time, sync bool) error { } // PullPAPI is the long polling client for real-time decisions from PAPI -func (p *Papi) Pull() error { +func (p *Papi) Pull(ctx context.Context) error { defer trace.CatchPanic("lapi/PullPAPI") p.Logger.Infof("Starting Polling API Pull") lastTimestamp := time.Time{} - lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey) + lastTimestampStr, err := p.DBClient.GetConfigItem(ctx, PapiPullKey) if err != nil { p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err) } @@ -248,7 +248,7 @@ func (p *Papi) Pull() error { return fmt.Errorf("failed to serialize last timestamp: %w", err) } - if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { + if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil { p.Logger.Errorf("error setting papi pull last key: %s", err) } else { p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime)) @@ -277,7 +277,7 @@ func (p *Papi) Pull() error { continue } - if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { + if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil { return fmt.Errorf("failed to update last timestamp: %w", err) } diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index a1137161698..943eb4139de 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "encoding/json" "fmt" "time" @@ -215,17 +216,19 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) } + ctx := context.TODO() + if forcePullMsg.Blocklist == nil { p.Logger.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists") - err = p.apic.PullTop(true) + err = p.apic.PullTop(ctx, true) if err != nil { return fmt.Errorf("failed to force pull operation: %w", err) } } else { p.Logger.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name) - err = p.apic.PullBlocklist(&modelscapi.BlocklistLink{ + err = p.apic.PullBlocklist(ctx, &modelscapi.BlocklistLink{ Name: &forcePullMsg.Blocklist.Name, URL: &forcePullMsg.Blocklist.Url, Remediation: &forcePullMsg.Blocklist.Remediation, diff --git a/pkg/database/config.go b/pkg/database/config.go index 8c3578ad596..7c341d3ecda 100644 --- a/pkg/database/config.go +++ b/pkg/database/config.go @@ -1,14 +1,15 @@ package database import ( + "context" "github.com/pkg/errors" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" ) -func (c *Client) GetConfigItem(key string) (*string, error) { - result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(c.CTX) +func (c *Client) GetConfigItem(ctx context.Context, key string) (*string, error) { + result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(ctx) if err != nil && ent.IsNotFound(err) { return nil, nil } @@ -19,11 +20,10 @@ func (c *Client) GetConfigItem(key string) (*string, error) { return &result.Value, nil } -func (c *Client) SetConfigItem(key string, value string) error { - - nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(c.CTX) +func (c *Client) SetConfigItem(ctx context.Context, key string, value string) error { + nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(ctx) if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { //not found, create - err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(c.CTX) + err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(ctx) if err != nil { return errors.Wrapf(QueryFail, "insert config item: %s", err) }