From 58844ab1120b375aebf5afae9e35ad99b968dbf8 Mon Sep 17 00:00:00 2001 From: marco Date: Mon, 16 Sep 2024 09:20:22 +0200 Subject: [PATCH 01/28] refact alerts: log messages --- pkg/database/alerts.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index 3e3e480c7d6..12669b00e64 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -456,14 +456,14 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ startAtTime, err := time.Parse(time.RFC3339, *alertItem.StartAt) if err != nil { - c.Log.Errorf("CreateAlertBulk: Failed to parse startAtTime '%s', defaulting to now: %s", *alertItem.StartAt, err) + c.Log.Errorf("creating alert: Failed to parse startAtTime '%s', defaulting to now: %s", *alertItem.StartAt, err) startAtTime = time.Now().UTC() } stopAtTime, err := time.Parse(time.RFC3339, *alertItem.StopAt) if err != nil { - c.Log.Errorf("CreateAlertBulk: Failed to parse stopAtTime '%s', defaulting to now: %s", *alertItem.StopAt, err) + c.Log.Errorf("creating alert: Failed to parse stopAtTime '%s', defaulting to now: %s", *alertItem.StopAt, err) stopAtTime = time.Now().UTC() } @@ -483,7 +483,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ for i, eventItem := range alertItem.Events { ts, err := time.Parse(time.RFC3339, *eventItem.Timestamp) if err != nil { - c.Log.Errorf("CreateAlertBulk: Failed to parse event timestamp '%s', defaulting to now: %s", *eventItem.Timestamp, err) + c.Log.Errorf("creating alert: Failed to parse event timestamp '%s', defaulting to now: %s", *eventItem.Timestamp, err) ts = time.Now().UTC() } @@ -694,7 +694,7 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str return nil, fmt.Errorf("machine '%s': %w", machineID, err) } - c.Log.Debugf("CreateAlertBulk: Machine Id %s doesn't exist", machineID) + c.Log.Debugf("creating alert: machine %s doesn't exist", machineID) owner = nil } From 2f04f1511e15e4ad964340ad39e0ef9c4819c9b3 Mon Sep 17 00:00:00 2001 From: marco Date: Mon, 16 Sep 2024 09:57:06 +0200 Subject: [PATCH 02/28] refact: AlertPredicatesFromFilter --- .golangci.yml | 2 +- pkg/database/alerts.go | 304 ++++++++++++++++++++++------------------- 2 files changed, 166 insertions(+), 140 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 78b666d25b4..9af7eed0dd2 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -103,7 +103,7 @@ linters-settings: disabled: true - name: cyclomatic # lower this after refactoring - arguments: [41] + arguments: [39] - name: defer disabled: true - name: empty-block diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index 12669b00e64..3dfb0dc8197 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -724,6 +724,160 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str return alertIDs, nil } +func handleSimulatedFilter(filter map[string][]string, predicates *[]predicate.Alert) { + /* the simulated filter is a bit different : if it's not present *or* set to false, specifically exclude records with simulated to true */ + if v, ok := filter["simulated"]; ok && v[0] == "false" { + *predicates = append(*predicates, alert.SimulatedEQ(false)) + } +} + +func handleOriginFilter(filter map[string][]string, predicates *[]predicate.Alert) { + if _, ok := filter["origin"]; ok { + filter["include_capi"] = []string{"true"} + } +} + +func handleScopeFilter(scope string, predicates *[]predicate.Alert) { + if strings.ToLower(scope) == "ip" { + scope = types.Ip + } else if strings.ToLower(scope) == "range" { + scope = types.Range + } + + *predicates = append(*predicates, alert.SourceScopeEQ(scope)) +} + +func handleTimeFilters(param, value string, predicates *[]predicate.Alert) error { + duration, err := ParseDuration(value) + if err != nil { + return fmt.Errorf("while parsing duration: %w", err) + } + + timePoint := time.Now().UTC().Add(-duration) + if timePoint.IsZero() { + return fmt.Errorf("empty time now() - %s", timePoint.String()) + } + + switch param { + case "since": + *predicates = append(*predicates, alert.StartedAtGTE(timePoint)) + case "created_before": + *predicates = append(*predicates, alert.CreatedAtLTE(timePoint)) + case "until": + *predicates = append(*predicates, alert.StartedAtLTE(timePoint)) + } + + return nil +} + +func handleIPv4Predicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) { + if contains { // decision contains {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + alert.HasDecisionsWith(decision.StartIPLTE(start_ip)), + alert.HasDecisionsWith(decision.EndIPGTE(end_ip)), + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + )) + } else { // decision is contained within {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + alert.HasDecisionsWith(decision.StartIPGTE(start_ip)), + alert.HasDecisionsWith(decision.EndIPLTE(end_ip)), + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + )) + } +} + +func handleIPv6Predicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) { + if contains { // decision contains {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + // matching addr size + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + alert.Or( + // decision.start_ip < query.start_ip + alert.HasDecisionsWith(decision.StartIPLT(start_ip)), + alert.And( + // decision.start_ip == query.start_ip + alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), + // decision.start_suffix <= query.start_suffix + alert.HasDecisionsWith(decision.StartSuffixLTE(start_sfx)), + ), + ), + alert.Or( + // decision.end_ip > query.end_ip + alert.HasDecisionsWith(decision.EndIPGT(end_ip)), + alert.And( + // decision.end_ip == query.end_ip + alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), + // decision.end_suffix >= query.end_suffix + alert.HasDecisionsWith(decision.EndSuffixGTE(end_sfx)), + ), + ), + )) + } else { // decision is contained within {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + // matching addr size + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + alert.Or( + // decision.start_ip > query.start_ip + alert.HasDecisionsWith(decision.StartIPGT(start_ip)), + alert.And( + // decision.start_ip == query.start_ip + alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), + // decision.start_suffix >= query.start_suffix + alert.HasDecisionsWith(decision.StartSuffixGTE(start_sfx)), + ), + ), + alert.Or( + // decision.end_ip < query.end_ip + alert.HasDecisionsWith(decision.EndIPLT(end_ip)), + alert.And( + // decision.end_ip == query.end_ip + alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), + // decision.end_suffix <= query.end_suffix + alert.HasDecisionsWith(decision.EndSuffixLTE(end_sfx)), + ), + ), + )) + } +} + +func handleIPPredicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) error { + if ip_sz == 4 { + handleIPv4Predicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, predicates) + } else if ip_sz == 16 { + handleIPv6Predicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, predicates) + } else if ip_sz != 0 { + return errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) + } + + return nil +} + +func handleIncludeCapiFilter(value string, predicates *[]predicate.Alert) error { + if value == "false" { + *predicates = append(*predicates, alert.And( + // do not show alerts with active decisions having origin CAPI or lists + alert.And( + alert.Not(alert.HasDecisionsWith(decision.OriginEQ(types.CAPIOrigin))), + alert.Not(alert.HasDecisionsWith(decision.OriginEQ(types.ListOrigin))), + ), + alert.Not( + alert.And( + // do not show neither alerts with no decisions if the Source Scope is lists: or CAPI + alert.Not(alert.HasDecisions()), + alert.Or( + alert.SourceScopeHasPrefix(types.ListOrigin+":"), + alert.SourceScopeEQ(types.CommunityBlocklistPullSourceScope), + ), + ), + ), + )) + } else if value != "true" { + log.Errorf("invalid bool '%s' for include_capi", value) + } + + return nil +} + func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, error) { predicates := make([]predicate.Alert, 0) @@ -739,16 +893,8 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer)*/ - /*the simulated filter is a bit different : if it's not present *or* set to false, specifically exclude records with simulated to true */ - if v, ok := filter["simulated"]; ok { - if v[0] == "false" { - predicates = append(predicates, alert.SimulatedEQ(false)) - } - } - - if _, ok := filter["origin"]; ok { - filter["include_capi"] = []string{"true"} - } + handleSimulatedFilter(filter, &predicates) + handleOriginFilter(filter, &predicates) for param, value := range filter { switch param { @@ -758,14 +904,7 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e return nil, errors.Wrapf(InvalidFilter, "invalid contains value : %s", err) } case "scope": - scope := value[0] - if strings.ToLower(scope) == "ip" { - scope = types.Ip - } else if strings.ToLower(scope) == "range" { - scope = types.Range - } - - predicates = append(predicates, alert.SourceScopeEQ(scope)) + handleScopeFilter(value[0], &predicates) case "value": predicates = append(predicates, alert.SourceValueEQ(value[0])) case "scenario": @@ -775,68 +914,18 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e if err != nil { return nil, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", value[0], err) } - case "since": - duration, err := ParseDuration(value[0]) - if err != nil { - return nil, fmt.Errorf("while parsing duration: %w", err) - } - - since := time.Now().UTC().Add(-duration) - if since.IsZero() { - return nil, fmt.Errorf("empty time now() - %s", since.String()) - } - - predicates = append(predicates, alert.StartedAtGTE(since)) - case "created_before": - duration, err := ParseDuration(value[0]) - if err != nil { - return nil, fmt.Errorf("while parsing duration: %w", err) - } + case "since", "created_before", "until": + if err := handleTimeFilters(param, value[0], &predicates); err != nil { + return nil, err - since := time.Now().UTC().Add(-duration) - if since.IsZero() { - return nil, fmt.Errorf("empty time now() - %s", since.String()) } - - predicates = append(predicates, alert.CreatedAtLTE(since)) - case "until": - duration, err := ParseDuration(value[0]) - if err != nil { - return nil, fmt.Errorf("while parsing duration: %w", err) - } - - until := time.Now().UTC().Add(-duration) - if until.IsZero() { - return nil, fmt.Errorf("empty time now() - %s", until.String()) - } - - predicates = append(predicates, alert.StartedAtLTE(until)) case "decision_type": predicates = append(predicates, alert.HasDecisionsWith(decision.TypeEQ(value[0]))) case "origin": predicates = append(predicates, alert.HasDecisionsWith(decision.OriginEQ(value[0]))) case "include_capi": // allows to exclude one or more specific origins - if value[0] == "false" { - predicates = append(predicates, alert.And( - // do not show alerts with active decisions having origin CAPI or lists - alert.And( - alert.Not(alert.HasDecisionsWith(decision.OriginEQ(types.CAPIOrigin))), - alert.Not(alert.HasDecisionsWith(decision.OriginEQ(types.ListOrigin))), - ), - alert.Not( - alert.And( - // do not show neither alerts with no decisions if the Source Scope is lists: or CAPI - alert.Not(alert.HasDecisions()), - alert.Or( - alert.SourceScopeHasPrefix(types.ListOrigin+":"), - alert.SourceScopeEQ(types.CommunityBlocklistPullSourceScope), - ), - ), - ), - ), - ) - } else if value[0] != "true" { - log.Errorf("Invalid bool '%s' for include_capi", value[0]) + if err = handleIncludeCapiFilter(value[0], &predicates); err != nil { + return nil, err } case "has_active_decision": if hasActiveDecision, err = strconv.ParseBool(value[0]); err != nil { @@ -861,72 +950,9 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e } } - if ip_sz == 4 { - if contains { /*decision contains {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - alert.HasDecisionsWith(decision.StartIPLTE(start_ip)), - alert.HasDecisionsWith(decision.EndIPGTE(end_ip)), - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - )) - } else { /*decision is contained within {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - alert.HasDecisionsWith(decision.StartIPGTE(start_ip)), - alert.HasDecisionsWith(decision.EndIPLTE(end_ip)), - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - )) - } - } else if ip_sz == 16 { - if contains { /*decision contains {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - // matching addr size - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - alert.Or( - // decision.start_ip < query.start_ip - alert.HasDecisionsWith(decision.StartIPLT(start_ip)), - alert.And( - // decision.start_ip == query.start_ip - alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), - // decision.start_suffix <= query.start_suffix - alert.HasDecisionsWith(decision.StartSuffixLTE(start_sfx)), - )), - alert.Or( - // decision.end_ip > query.end_ip - alert.HasDecisionsWith(decision.EndIPGT(end_ip)), - alert.And( - // decision.end_ip == query.end_ip - alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), - // decision.end_suffix >= query.end_suffix - alert.HasDecisionsWith(decision.EndSuffixGTE(end_sfx)), - ), - ), - )) - } else { /*decision is contained within {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - // matching addr size - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - alert.Or( - // decision.start_ip > query.start_ip - alert.HasDecisionsWith(decision.StartIPGT(start_ip)), - alert.And( - // decision.start_ip == query.start_ip - alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), - // decision.start_suffix >= query.start_suffix - alert.HasDecisionsWith(decision.StartSuffixGTE(start_sfx)), - )), - alert.Or( - // decision.end_ip < query.end_ip - alert.HasDecisionsWith(decision.EndIPLT(end_ip)), - alert.And( - // decision.end_ip == query.end_ip - alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), - // decision.end_suffix <= query.end_suffix - alert.HasDecisionsWith(decision.EndSuffixLTE(end_sfx)), - ), - ), - )) - } - } else if ip_sz != 0 { - return nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) + if err := handleIPPredicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, &predicates); err != nil { + return nil, err + } return predicates, nil From 965e1c6cd71b6d1b748cec8c4c0c3a60446fe508 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 13 Sep 2024 09:31:51 +0200 Subject: [PATCH 03/28] context propagation: explicit ctx parameter in unit tests This is done so we can later enable a context linter. --- pkg/acquisition/modules/loki/loki_test.go | 16 ++++++--- pkg/apiserver/alerts_test.go | 31 ++++++++++++------ pkg/apiserver/api_key_test.go | 9 +++-- pkg/apiserver/apiserver_test.go | 16 ++++++--- pkg/apiserver/jwt_test.go | 17 ++++++---- pkg/apiserver/machines_test.go | 40 +++++++++++++++-------- 6 files changed, 87 insertions(+), 42 deletions(-) diff --git a/pkg/acquisition/modules/loki/loki_test.go b/pkg/acquisition/modules/loki/loki_test.go index 5f41cd4c62e..ce86a1c36d9 100644 --- a/pkg/acquisition/modules/loki/loki_test.go +++ b/pkg/acquisition/modules/loki/loki_test.go @@ -261,7 +261,7 @@ func TestConfigureDSN(t *testing.T) { } } -func feedLoki(logger *log.Entry, n int, title string) error { +func feedLoki(ctx context.Context, logger *log.Entry, n int, title string) error { streams := LogStreams{ Streams: []LogStream{ { @@ -286,7 +286,7 @@ func feedLoki(logger *log.Entry, n int, title string) error { return err } - req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:3100/loki/api/v1/push", bytes.NewBuffer(buff)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://127.0.0.1:3100/loki/api/v1/push", bytes.NewBuffer(buff)) if err != nil { return err } @@ -349,7 +349,9 @@ since: 1h t.Fatalf("Unexpected error : %s", err) } - err = feedLoki(subLogger, 20, title) + ctx := context.Background() + + err = feedLoki(ctx, subLogger, 20, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } @@ -421,6 +423,8 @@ query: > }, } + ctx := context.Background() + for _, ts := range tests { t.Run(ts.name, func(t *testing.T) { logger := log.New() @@ -472,7 +476,7 @@ query: > } }) - err = feedLoki(subLogger, ts.expectedLines, title) + err = feedLoki(ctx, subLogger, ts.expectedLines, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } @@ -525,7 +529,9 @@ query: > time.Sleep(time.Second * 2) - err = feedLoki(subLogger, 1, title) + ctx := context.Background() + + err = feedLoki(ctx, subLogger, 1, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index 891eb3a8f4a..92067554d65 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "encoding/json" "fmt" "net/http" @@ -45,8 +46,9 @@ func (l *LAPI) InsertAlertFromFile(t *testing.T, path string) *httptest.Response } func (l *LAPI) RecordResponse(t *testing.T, verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder { + ctx := context.Background() w := httptest.NewRecorder() - req, err := http.NewRequest(verb, url, body) + req, err := http.NewRequestWithContext(ctx, verb, url, body) require.NoError(t, err) switch authType { @@ -74,8 +76,9 @@ func LoginToTestAPI(t *testing.T, router *gin.Engine, config csconfig.Config) mo body := CreateTestMachine(t, router, "") ValidateMachine(t, "test", config.API.Server.DbConfig) + ctx := context.Background() w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -355,9 +358,11 @@ func TestCreateAlertErrors(t *testing.T) { lapi := SetupLAPITest(t) alertContent := GetAlertReaderFromFile(t, "./tests/alert_sample.json") + ctx := context.Background() + //test invalid bearer w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/alerts", alertContent) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/alerts", alertContent) req.Header.Add("User-Agent", UserAgent) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "ratata")) lapi.router.ServeHTTP(w, req) @@ -365,7 +370,7 @@ func TestCreateAlertErrors(t *testing.T) { //test invalid bearer w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/alerts", alertContent) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/alerts", alertContent) req.Header.Add("User-Agent", UserAgent) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", lapi.loginResp.Token+"s")) lapi.router.ServeHTTP(w, req) @@ -376,9 +381,11 @@ func TestDeleteAlert(t *testing.T) { lapi := SetupLAPITest(t) lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + ctx := context.Background() + // Fail Delete Alert w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.2:4242" lapi.router.ServeHTTP(w, req) @@ -387,7 +394,7 @@ func TestDeleteAlert(t *testing.T) { // Delete Alert w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.1:4242" lapi.router.ServeHTTP(w, req) @@ -399,9 +406,11 @@ func TestDeleteAlertByID(t *testing.T) { lapi := SetupLAPITest(t) lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + ctx := context.Background() + // Fail Delete Alert w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.2:4242" lapi.router.ServeHTTP(w, req) @@ -410,7 +419,7 @@ func TestDeleteAlertByID(t *testing.T) { // Delete Alert w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.1:4242" lapi.router.ServeHTTP(w, req) @@ -439,9 +448,11 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { loginResp: loginResp, } + ctx := context.Background() + assertAlertDeleteFailedFromIP := func(ip string) { w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, loginResp) req.RemoteAddr = ip + ":1234" @@ -453,7 +464,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { assertAlertDeletedFromIP := func(ip string) { w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, loginResp) req.RemoteAddr = ip + ":1234" diff --git a/pkg/apiserver/api_key_test.go b/pkg/apiserver/api_key_test.go index 883ff21298d..10e75ae47f1 100644 --- a/pkg/apiserver/api_key_test.go +++ b/pkg/apiserver/api_key_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "net/http" "net/http/httptest" "strings" @@ -12,11 +13,13 @@ import ( func TestAPIKey(t *testing.T) { router, config := NewAPITest(t) + ctx := context.Background() + APIKey := CreateTestBouncer(t, config.API.Server.DbConfig) // Login with empty token w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -25,7 +28,7 @@ func TestAPIKey(t *testing.T) { // Login with invalid token w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Api-Key", "a1b2c3d4e5f6") router.ServeHTTP(w, req) @@ -35,7 +38,7 @@ func TestAPIKey(t *testing.T) { // Login with valid token w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Api-Key", APIKey) router.ServeHTTP(w, req) diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index f48791ebcb8..89c75f35d21 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -278,8 +278,10 @@ func CreateTestMachine(t *testing.T, router *gin.Engine, token string) string { body := string(b) + ctx := context.Background() + w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Set("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -323,8 +325,10 @@ func TestWithWrongFlushConfig(t *testing.T) { func TestUnknownPath(t *testing.T) { router, _ := NewAPITest(t) + ctx := context.Background() + w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/test", nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test", nil) req.Header.Set("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -380,8 +384,10 @@ func TestLoggingDebugToFileConfig(t *testing.T) { require.NoError(t, err) require.NotNil(t, api) + ctx := context.Background() + w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/test42", nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) api.router.ServeHTTP(w, req) assert.Equal(t, 404, w.Code) @@ -430,8 +436,10 @@ func TestLoggingErrorToFileConfig(t *testing.T) { require.NoError(t, err) require.NotNil(t, api) + ctx := context.Background() + w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/test42", nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) api.router.ServeHTTP(w, req) assert.Equal(t, http.StatusNotFound, w.Code) diff --git a/pkg/apiserver/jwt_test.go b/pkg/apiserver/jwt_test.go index aa6e84e416b..7ef010ae12b 100644 --- a/pkg/apiserver/jwt_test.go +++ b/pkg/apiserver/jwt_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "net/http" "net/http/httptest" "strings" @@ -12,11 +13,13 @@ import ( func TestLogin(t *testing.T) { router, config := NewAPITest(t) + ctx := context.Background() + body := CreateTestMachine(t, router, "") // Login with machine not validated yet w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -25,7 +28,7 @@ func TestLogin(t *testing.T) { // Login with machine not exist w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1", "password": "test1"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1", "password": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -34,7 +37,7 @@ func TestLogin(t *testing.T) { // Login with invalid body w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("test")) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader("test")) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -43,7 +46,7 @@ func TestLogin(t *testing.T) { // Login with invalid format w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -55,7 +58,7 @@ func TestLogin(t *testing.T) { // Login with invalid password w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test1"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -64,7 +67,7 @@ func TestLogin(t *testing.T) { // Login with valid machine w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -74,7 +77,7 @@ func TestLogin(t *testing.T) { // Login with valid machine + scenarios w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test", "scenarios": ["crowdsecurity/test", "crowdsecurity/test2"]}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test", "scenarios": ["crowdsecurity/test", "crowdsecurity/test2"]}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) diff --git a/pkg/apiserver/machines_test.go b/pkg/apiserver/machines_test.go index 041a6bee528..61677ca3dc4 100644 --- a/pkg/apiserver/machines_test.go +++ b/pkg/apiserver/machines_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -16,9 +17,11 @@ import ( func TestCreateMachine(t *testing.T) { router, _ := NewAPITest(t) + ctx := context.Background() + // Create machine with invalid format w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader("test")) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader("test")) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -27,7 +30,7 @@ func TestCreateMachine(t *testing.T) { // Create machine with invalid input w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(`{"test": "test"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(`{"test": "test"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -41,7 +44,7 @@ func TestCreateMachine(t *testing.T) { body := string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -52,6 +55,9 @@ func TestCreateMachine(t *testing.T) { func TestCreateMachineWithForwardedFor(t *testing.T) { router, config := NewAPITestForwardedFor(t) router.TrustedPlatform = "X-Real-IP" + + ctx := context.Background() + // Create machine b, err := json.Marshal(MachineTest) require.NoError(t, err) @@ -59,7 +65,7 @@ func TestCreateMachineWithForwardedFor(t *testing.T) { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Real-Ip", "1.1.1.1") router.ServeHTTP(w, req) @@ -75,6 +81,8 @@ func TestCreateMachineWithForwardedFor(t *testing.T) { func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { router, config := NewAPITest(t) + ctx := context.Background() + // Create machine b, err := json.Marshal(MachineTest) require.NoError(t, err) @@ -82,7 +90,7 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Real-IP", "1.1.1.1") router.ServeHTTP(w, req) @@ -100,6 +108,8 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { func TestCreateMachineWithoutForwardedFor(t *testing.T) { router, config := NewAPITestForwardedFor(t) + ctx := context.Background() + // Create machine b, err := json.Marshal(MachineTest) require.NoError(t, err) @@ -107,7 +117,7 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -124,15 +134,17 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) { func TestCreateMachineAlreadyExist(t *testing.T) { router, _ := NewAPITest(t) + ctx := context.Background() + body := CreateTestMachine(t, router, "") w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -143,6 +155,8 @@ func TestCreateMachineAlreadyExist(t *testing.T) { func TestAutoRegistration(t *testing.T) { router, _ := NewAPITest(t) + ctx := context.Background() + //Invalid registration token / valid source IP regReq := MachineTest regReq.RegistrationToken = invalidRegistrationToken @@ -152,7 +166,7 @@ func TestAutoRegistration(t *testing.T) { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.RemoteAddr = "127.0.0.1:4242" router.ServeHTTP(w, req) @@ -168,7 +182,7 @@ func TestAutoRegistration(t *testing.T) { body = string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.RemoteAddr = "42.42.42.42:4242" router.ServeHTTP(w, req) @@ -184,7 +198,7 @@ func TestAutoRegistration(t *testing.T) { body = string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.RemoteAddr = "42.42.42.42:4242" router.ServeHTTP(w, req) @@ -200,7 +214,7 @@ func TestAutoRegistration(t *testing.T) { body = string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.RemoteAddr = "127.0.0.1:4242" router.ServeHTTP(w, req) @@ -216,7 +230,7 @@ func TestAutoRegistration(t *testing.T) { body = string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.RemoteAddr = "127.0.0.1:4242" router.ServeHTTP(w, req) From 4f10b2d29cd28bb28f670e915805eab1df0d526f Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 13 Sep 2024 10:30:53 +0200 Subject: [PATCH 04/28] context propagation: pass context to NewAPIC() --- cmd/crowdsec-cli/clipapi/papi.go | 12 +++++++----- pkg/apiserver/apic.go | 4 ++-- pkg/apiserver/apic_test.go | 4 +++- pkg/apiserver/apiserver.go | 2 +- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/cmd/crowdsec-cli/clipapi/papi.go b/cmd/crowdsec-cli/clipapi/papi.go index 747b8c01b9b..c0f08157f31 100644 --- a/cmd/crowdsec-cli/clipapi/papi.go +++ b/cmd/crowdsec-cli/clipapi/papi.go @@ -59,7 +59,7 @@ func (cli *cliPapi) NewCommand() *cobra.Command { func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Client) error { cfg := cli.cfg() - apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) + apic, err := apiserver.NewAPIC(ctx, cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) if err != nil { return fmt.Errorf("unable to initialize API client: %w", err) } @@ -118,11 +118,11 @@ func (cli *cliPapi) newStatusCmd() *cobra.Command { return cmd } -func (cli *cliPapi) sync(out io.Writer, db *database.Client) error { +func (cli *cliPapi) sync(ctx context.Context, out io.Writer, db *database.Client) error { cfg := cli.cfg() t := tomb.Tomb{} - apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) + apic, err := apiserver.NewAPIC(ctx, cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) if err != nil { return fmt.Errorf("unable to initialize API client: %w", err) } @@ -159,12 +159,14 @@ func (cli *cliPapi) newSyncCmd() *cobra.Command { DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, _ []string) error { cfg := cli.cfg() - db, err := require.DBClient(cmd.Context(), cfg.DbConfig) + ctx := cmd.Context() + + db, err := require.DBClient(ctx, cfg.DbConfig) if err != nil { return err } - return cli.sync(color.Output, db) + return cli.sync(ctx, color.Output, db) }, } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 73061637ad9..3ed2e12ea54 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -174,7 +174,7 @@ func alertToSignal(alert *models.Alert, scenarioTrust string, shareContext bool) return signal } -func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) { +func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) { var err error ret := &apic{ @@ -237,7 +237,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con return ret, fmt.Errorf("get scenario in db: %w", err) } - authResp, _, err := ret.apiClient.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + authResp, _, err := ret.apiClient.Auth.AuthenticateWatcher(ctx, models.WatcherAuthRequest{ MachineID: &config.Credentials.Login, Password: &password, Scenarios: scenarios, diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 058e25079e0..105d295dd0d 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -230,6 +230,8 @@ func TestNewAPIC(t *testing.T) { }, } + ctx := context.Background() + for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { setConfig() @@ -246,7 +248,7 @@ func TestNewAPIC(t *testing.T) { ), )) tc.action() - _, err := NewAPIC(testConfig, tc.args.dbClient, tc.args.consoleConfig, nil) + _, err := NewAPIC(ctx, testConfig, tc.args.dbClient, tc.args.consoleConfig, nil) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 42dcb219379..8bf406e0a79 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -249,7 +249,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { if config.OnlineClient != nil && config.OnlineClient.Credentials != nil { log.Printf("Loading CAPI manager") - apiClient, err = NewAPIC(config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists) + apiClient, err = NewAPIC(ctx, config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists) if err != nil { return nil, err } From 66c3f0ff4e7dabcd567ea1be932301ece1dd228f Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 13 Sep 2024 10:45:41 +0200 Subject: [PATCH 05/28] context propagation: drop field S3Source.ctx, pass explicitly --- pkg/acquisition/modules/s3/s3.go | 41 +++++++++++++++++--------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/pkg/acquisition/modules/s3/s3.go b/pkg/acquisition/modules/s3/s3.go index a9835ab4974..38a3a835c47 100644 --- a/pkg/acquisition/modules/s3/s3.go +++ b/pkg/acquisition/modules/s3/s3.go @@ -55,7 +55,6 @@ type S3Source struct { readerChan chan S3Object t *tomb.Tomb out chan types.Event - ctx aws.Context cancel context.CancelFunc } @@ -184,7 +183,7 @@ func (s *S3Source) newSQSClient() error { return nil } -func (s *S3Source) readManager() { +func (s *S3Source) readManager(ctx context.Context) { logger := s.logger.WithField("method", "readManager") for { select { @@ -194,7 +193,7 @@ func (s *S3Source) readManager() { return case s3Object := <-s.readerChan: logger.Debugf("Reading file %s/%s", s3Object.Bucket, s3Object.Key) - err := s.readFile(s3Object.Bucket, s3Object.Key) + err := s.readFile(ctx, s3Object.Bucket, s3Object.Key) if err != nil { logger.Errorf("Error while reading file: %s", err) } @@ -202,13 +201,13 @@ func (s *S3Source) readManager() { } } -func (s *S3Source) getBucketContent() ([]*s3.Object, error) { +func (s *S3Source) getBucketContent(ctx context.Context) ([]*s3.Object, error) { logger := s.logger.WithField("method", "getBucketContent") logger.Debugf("Getting bucket content for %s", s.Config.BucketName) bucketObjects := make([]*s3.Object, 0) var continuationToken *string for { - out, err := s.s3Client.ListObjectsV2WithContext(s.ctx, &s3.ListObjectsV2Input{ + out, err := s.s3Client.ListObjectsV2WithContext(ctx, &s3.ListObjectsV2Input{ Bucket: aws.String(s.Config.BucketName), Prefix: aws.String(s.Config.Prefix), ContinuationToken: continuationToken, @@ -229,7 +228,7 @@ func (s *S3Source) getBucketContent() ([]*s3.Object, error) { return bucketObjects, nil } -func (s *S3Source) listPoll() error { +func (s *S3Source) listPoll(ctx context.Context) error { logger := s.logger.WithField("method", "listPoll") ticker := time.NewTicker(time.Duration(s.Config.PollingInterval) * time.Second) lastObjectDate := time.Now() @@ -243,7 +242,7 @@ func (s *S3Source) listPoll() error { return nil case <-ticker.C: newObject := false - bucketObjects, err := s.getBucketContent() + bucketObjects, err := s.getBucketContent(ctx) if err != nil { logger.Errorf("Error while getting bucket content: %s", err) continue @@ -325,7 +324,7 @@ func (s *S3Source) extractBucketAndPrefix(message *string) (string, string, erro } } -func (s *S3Source) sqsPoll() error { +func (s *S3Source) sqsPoll(ctx context.Context) error { logger := s.logger.WithField("method", "sqsPoll") for { select { @@ -335,7 +334,7 @@ func (s *S3Source) sqsPoll() error { return nil default: logger.Trace("Polling SQS queue") - out, err := s.sqsClient.ReceiveMessageWithContext(s.ctx, &sqs.ReceiveMessageInput{ + out, err := s.sqsClient.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ QueueUrl: aws.String(s.Config.SQSName), MaxNumberOfMessages: aws.Int64(10), WaitTimeSeconds: aws.Int64(20), //Probably no need to make it configurable ? @@ -378,7 +377,7 @@ func (s *S3Source) sqsPoll() error { } } -func (s *S3Source) readFile(bucket string, key string) error { +func (s *S3Source) readFile(ctx context.Context, bucket string, key string) error { //TODO: Handle SSE-C var scanner *bufio.Scanner @@ -388,7 +387,7 @@ func (s *S3Source) readFile(bucket string, key string) error { "key": key, }) - output, err := s.s3Client.GetObjectWithContext(s.ctx, &s3.GetObjectInput{ + output, err := s.s3Client.GetObjectWithContext(ctx, &s3.GetObjectInput{ Bucket: aws.String(bucket), Key: aws.String(key), }) @@ -645,24 +644,26 @@ func (s *S3Source) GetName() string { } func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { + var ctx context.Context + s.logger.Infof("starting acquisition of %s/%s/%s", s.Config.BucketName, s.Config.Prefix, s.Config.Key) s.out = out - s.ctx, s.cancel = context.WithCancel(context.Background()) + ctx, s.cancel = context.WithCancel(context.Background()) s.Config.UseTimeMachine = true s.t = t if s.Config.Key != "" { - err := s.readFile(s.Config.BucketName, s.Config.Key) + err := s.readFile(ctx, s.Config.BucketName, s.Config.Key) if err != nil { return err } } else { //No key, get everything in the bucket based on the prefix - objects, err := s.getBucketContent() + objects, err := s.getBucketContent(ctx) if err != nil { return err } for _, object := range objects { - err := s.readFile(s.Config.BucketName, *object.Key) + err := s.readFile(ctx, s.Config.BucketName, *object.Key) if err != nil { return err } @@ -673,18 +674,20 @@ func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error } func (s *S3Source) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { + var ctx context.Context + s.t = t s.out = out s.readerChan = make(chan S3Object, 100) //FIXME: does this needs to be buffered? - s.ctx, s.cancel = context.WithCancel(context.Background()) + ctx, s.cancel = context.WithCancel(context.Background()) s.logger.Infof("starting acquisition of %s/%s", s.Config.BucketName, s.Config.Prefix) t.Go(func() error { - s.readManager() + s.readManager(ctx) return nil }) if s.Config.PollingMethod == PollMethodSQS { t.Go(func() error { - err := s.sqsPoll() + err := s.sqsPoll(ctx) if err != nil { return err } @@ -692,7 +695,7 @@ func (s *S3Source) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) erro }) } else { t.Go(func() error { - err := s.listPoll() + err := s.listPoll(ctx) if err != nil { return err } From 77c269a2ee240695558c9c3fa74ae9ae216d6587 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 13 Sep 2024 13:16:53 +0200 Subject: [PATCH 06/28] context propagation: pkg/database/flush --- cmd/crowdsec-cli/clialert/alerts.go | 6 ++-- pkg/apiserver/apiserver.go | 2 +- pkg/database/flush.go | 43 +++++++++++++++-------------- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/cmd/crowdsec-cli/clialert/alerts.go b/cmd/crowdsec-cli/clialert/alerts.go index 0965e1e13d0..dbb7ca14db5 100644 --- a/cmd/crowdsec-cli/clialert/alerts.go +++ b/cmd/crowdsec-cli/clialert/alerts.go @@ -575,15 +575,17 @@ func (cli *cliAlerts) newFlushCmd() *cobra.Command { DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, _ []string) error { cfg := cli.cfg() + ctx := cmd.Context() + if err := require.LAPI(cfg); err != nil { return err } - db, err := require.DBClient(cmd.Context(), cfg.DbConfig) + db, err := require.DBClient(ctx, cfg.DbConfig) if err != nil { return err } log.Info("Flushing alerts. !! This may take a long time !!") - err = db.FlushAlerts(maxAge, maxItems) + err = db.FlushAlerts(ctx, maxAge, maxItems) if err != nil { return fmt.Errorf("unable to flush alerts: %w", err) } diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 8bf406e0a79..95d18ccb028 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -170,7 +170,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { } if config.DbConfig.Flush != nil { - flushScheduler, err = dbClient.StartFlushScheduler(config.DbConfig.Flush) + flushScheduler, err = dbClient.StartFlushScheduler(ctx, config.DbConfig.Flush) if err != nil { return nil, err } diff --git a/pkg/database/flush.go b/pkg/database/flush.go index 5d53d10c942..46c8edfa308 100644 --- a/pkg/database/flush.go +++ b/pkg/database/flush.go @@ -1,6 +1,7 @@ package database import ( + "context" "errors" "fmt" "time" @@ -26,7 +27,7 @@ const ( flushInterval = 1 * time.Minute ) -func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) { +func (c *Client) StartFlushScheduler(ctx context.Context, config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) { maxItems := 0 maxAge := "" @@ -45,7 +46,7 @@ func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Sched // Init & Start cronjob every minute for alerts scheduler := gocron.NewScheduler(time.UTC) - job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, maxAge, maxItems) + job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, ctx, maxAge, maxItems) if err != nil { return nil, fmt.Errorf("while starting FlushAlerts scheduler: %w", err) } @@ -100,14 +101,14 @@ func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Sched } } - baJob, err := scheduler.Every(flushInterval).Do(c.FlushAgentsAndBouncers, config.AgentsGC, config.BouncersGC) + baJob, err := scheduler.Every(flushInterval).Do(c.FlushAgentsAndBouncers, ctx, config.AgentsGC, config.BouncersGC) if err != nil { return nil, fmt.Errorf("while starting FlushAgentsAndBouncers scheduler: %w", err) } baJob.SingletonMode() - metricsJob, err := scheduler.Every(flushInterval).Do(c.flushMetrics, config.MetricsMaxAge) + metricsJob, err := scheduler.Every(flushInterval).Do(c.flushMetrics, ctx, config.MetricsMaxAge) if err != nil { return nil, fmt.Errorf("while starting flushMetrics scheduler: %w", err) } @@ -120,7 +121,7 @@ func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Sched } // flushMetrics deletes metrics older than maxAge, regardless if they have been pushed to CAPI or not -func (c *Client) flushMetrics(maxAge *time.Duration) { +func (c *Client) flushMetrics(ctx context.Context, maxAge *time.Duration) { if maxAge == nil { maxAge = ptr.Of(defaultMetricsMaxAge) } @@ -129,7 +130,7 @@ func (c *Client) flushMetrics(maxAge *time.Duration) { deleted, err := c.Ent.Metric.Delete().Where( metric.ReceivedAtLTE(time.Now().UTC().Add(-*maxAge)), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { c.Log.Errorf("while flushing metrics: %s", err) return @@ -140,10 +141,10 @@ func (c *Client) flushMetrics(maxAge *time.Duration) { } } -func (c *Client) FlushOrphans() { +func (c *Client) FlushOrphans(ctx context.Context) { /* While it has only been linked to some very corner-case bug : https://github.com/crowdsecurity/crowdsec/issues/778 */ /* We want to take care of orphaned events for which the parent alert/decision has been deleted */ - eventsCount, err := c.Ent.Event.Delete().Where(event.Not(event.HasOwner())).Exec(c.CTX) + eventsCount, err := c.Ent.Event.Delete().Where(event.Not(event.HasOwner())).Exec(ctx) if err != nil { c.Log.Warningf("error while deleting orphan events: %s", err) return @@ -154,7 +155,7 @@ func (c *Client) FlushOrphans() { } eventsCount, err = c.Ent.Decision.Delete().Where( - decision.Not(decision.HasOwner())).Where(decision.UntilLTE(time.Now().UTC())).Exec(c.CTX) + decision.Not(decision.HasOwner())).Where(decision.UntilLTE(time.Now().UTC())).Exec(ctx) if err != nil { c.Log.Warningf("error while deleting orphan decisions: %s", err) return @@ -165,7 +166,7 @@ func (c *Client) FlushOrphans() { } } -func (c *Client) flushBouncers(authType string, duration *time.Duration) { +func (c *Client) flushBouncers(ctx context.Context, authType string, duration *time.Duration) { if duration == nil { return } @@ -174,7 +175,7 @@ func (c *Client) flushBouncers(authType string, duration *time.Duration) { bouncer.LastPullLTE(time.Now().UTC().Add(-*duration)), ).Where( bouncer.AuthTypeEQ(authType), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { c.Log.Errorf("while auto-deleting expired bouncers (%s): %s", authType, err) return @@ -185,7 +186,7 @@ func (c *Client) flushBouncers(authType string, duration *time.Duration) { } } -func (c *Client) flushAgents(authType string, duration *time.Duration) { +func (c *Client) flushAgents(ctx context.Context, authType string, duration *time.Duration) { if duration == nil { return } @@ -194,7 +195,7 @@ func (c *Client) flushAgents(authType string, duration *time.Duration) { machine.LastHeartbeatLTE(time.Now().UTC().Add(-*duration)), machine.Not(machine.HasAlerts()), machine.AuthTypeEQ(authType), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { c.Log.Errorf("while auto-deleting expired machines (%s): %s", authType, err) return @@ -205,23 +206,23 @@ func (c *Client) flushAgents(authType string, duration *time.Duration) { } } -func (c *Client) FlushAgentsAndBouncers(agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error { +func (c *Client) FlushAgentsAndBouncers(ctx context.Context, agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error { log.Debug("starting FlushAgentsAndBouncers") if agentsCfg != nil { - c.flushAgents(types.TlsAuthType, agentsCfg.CertDuration) - c.flushAgents(types.PasswordAuthType, agentsCfg.LoginPasswordDuration) + c.flushAgents(ctx, types.TlsAuthType, agentsCfg.CertDuration) + c.flushAgents(ctx, types.PasswordAuthType, agentsCfg.LoginPasswordDuration) } if bouncersCfg != nil { - c.flushBouncers(types.TlsAuthType, bouncersCfg.CertDuration) - c.flushBouncers(types.ApiKeyAuthType, bouncersCfg.ApiDuration) + c.flushBouncers(ctx, types.TlsAuthType, bouncersCfg.CertDuration) + c.flushBouncers(ctx, types.ApiKeyAuthType, bouncersCfg.ApiDuration) } return nil } -func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { +func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) error { var ( deletedByAge int deletedByNbItem int @@ -235,7 +236,7 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { } c.Log.Debug("Flushing orphan alerts") - c.FlushOrphans() + c.FlushOrphans(ctx) c.Log.Debug("Done flushing orphan alerts") totalAlerts, err = c.TotalAlerts() @@ -287,7 +288,7 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { if maxid > 0 { // This may lead to orphan alerts (at least on MySQL), but the next time the flush job will run, they will be deleted - deletedByNbItem, err = c.Ent.Alert.Delete().Where(alert.IDLT(maxid)).Exec(c.CTX) + deletedByNbItem, err = c.Ent.Alert.Delete().Where(alert.IDLT(maxid)).Exec(ctx) if err != nil { c.Log.Errorf("FlushAlerts: Could not delete alerts: %s", err) return fmt.Errorf("could not delete alerts: %w", err) From f8ac86115ea3743321c30bfff7bc2f251d86e4c1 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 13 Sep 2024 13:28:40 +0200 Subject: [PATCH 07/28] context propagation: bouncer list --- cmd/crowdsec-cli/clibouncer/bouncers.go | 14 ++++++++------ cmd/crowdsec-cli/clisupport/support.go | 6 +++--- pkg/apiserver/apic_metrics.go | 16 ++++++++++------ pkg/apiserver/apic_test.go | 8 +++++--- pkg/database/bouncers.go | 4 ++-- 5 files changed, 28 insertions(+), 20 deletions(-) diff --git a/cmd/crowdsec-cli/clibouncer/bouncers.go b/cmd/crowdsec-cli/clibouncer/bouncers.go index 0c0fc8851c9..89e91b63911 100644 --- a/cmd/crowdsec-cli/clibouncer/bouncers.go +++ b/cmd/crowdsec-cli/clibouncer/bouncers.go @@ -1,6 +1,7 @@ package clibouncer import ( + "context" "encoding/csv" "encoding/json" "errors" @@ -159,11 +160,11 @@ func (cli *cliBouncers) listCSV(out io.Writer, bouncers ent.Bouncers) error { return nil } -func (cli *cliBouncers) List(out io.Writer, db *database.Client) error { +func (cli *cliBouncers) List(ctx context.Context, out io.Writer, db *database.Client) error { // XXX: must use the provided db object, the one in the struct might be nil // (calling List directly skips the PersistentPreRunE) - bouncers, err := db.ListBouncers() + bouncers, err := db.ListBouncers(ctx) if err != nil { return fmt.Errorf("unable to list bouncers: %w", err) } @@ -199,8 +200,8 @@ func (cli *cliBouncers) newListCmd() *cobra.Command { Example: `cscli bouncers list`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.List(color.Output, cli.db) + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.List(cmd.Context(), color.Output, cli.db) }, } @@ -271,6 +272,7 @@ func (cli *cliBouncers) validBouncerID(cmd *cobra.Command, args []string, toComp var err error cfg := cli.cfg() + ctx := cmd.Context() // need to load config and db because PersistentPreRunE is not called for completions @@ -279,13 +281,13 @@ func (cli *cliBouncers) validBouncerID(cmd *cobra.Command, args []string, toComp return nil, cobra.ShellCompDirectiveNoFileComp } - cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig) + cli.db, err = require.DBClient(ctx, cfg.DbConfig) if err != nil { cobra.CompError("unable to list bouncers " + err.Error()) return nil, cobra.ShellCompDirectiveNoFileComp } - bouncers, err := cli.db.ListBouncers() + bouncers, err := cli.db.ListBouncers(ctx) if err != nil { cobra.CompError("unable to list bouncers " + err.Error()) return nil, cobra.ShellCompDirectiveNoFileComp diff --git a/cmd/crowdsec-cli/clisupport/support.go b/cmd/crowdsec-cli/clisupport/support.go index e9837b03fe7..7e41518805a 100644 --- a/cmd/crowdsec-cli/clisupport/support.go +++ b/cmd/crowdsec-cli/clisupport/support.go @@ -189,7 +189,7 @@ func (cli *cliSupport) dumpHubItems(zw *zip.Writer, hub *cwhub.Hub) error { return nil } -func (cli *cliSupport) dumpBouncers(zw *zip.Writer, db *database.Client) error { +func (cli *cliSupport) dumpBouncers(ctx context.Context, zw *zip.Writer, db *database.Client) error { log.Info("Collecting bouncers") if db == nil { @@ -199,7 +199,7 @@ func (cli *cliSupport) dumpBouncers(zw *zip.Writer, db *database.Client) error { out := new(bytes.Buffer) cb := clibouncer.New(cli.cfg) - if err := cb.List(out, db); err != nil { + if err := cb.List(ctx, out, db); err != nil { return err } @@ -525,7 +525,7 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error { log.Warnf("could not collect hub information: %s", err) } - if err = cli.dumpBouncers(zipWriter, db); err != nil { + if err = cli.dumpBouncers(ctx, zipWriter, db); err != nil { log.Warnf("could not collect bouncers information: %s", err) } diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index 5c6a550a6a0..91a0a8273f7 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -23,7 +23,7 @@ type dbPayload struct { Metrics []*models.DetailedMetrics `json:"metrics"` } -func (a *apic) GetUsageMetrics() (*models.AllMetrics, []int, error) { +func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int, error) { allMetrics := &models.AllMetrics{} metricsIds := make([]int, 0) @@ -32,7 +32,7 @@ func (a *apic) GetUsageMetrics() (*models.AllMetrics, []int, error) { return nil, nil, err } - bouncers, err := a.dbClient.ListBouncers() + bouncers, err := a.dbClient.ListBouncers(ctx) if err != nil { return nil, nil, err } @@ -185,7 +185,7 @@ func (a *apic) MarkUsageMetricsAsSent(ids []int) error { return a.dbClient.MarkUsageMetricsAsSent(ids) } -func (a *apic) GetMetrics() (*models.Metrics, error) { +func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) { machines, err := a.dbClient.ListMachines() if err != nil { return nil, err @@ -202,7 +202,7 @@ func (a *apic) GetMetrics() (*models.Metrics, error) { } } - bouncers, err := a.dbClient.ListBouncers() + bouncers, err := a.dbClient.ListBouncers(ctx) if err != nil { return nil, err } @@ -254,6 +254,8 @@ func (a *apic) fetchMachineIDs() ([]string, error) { func (a *apic) SendMetrics(stop chan (bool)) { defer trace.CatchPanic("lapi/metricsToAPIC") + ctx := context.TODO() + // verify the list of machines every interval const checkInt = 20 * time.Second @@ -311,7 +313,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { case <-metTicker.C: metTicker.Stop() - metrics, err := a.GetMetrics() + metrics, err := a.GetMetrics(ctx) if err != nil { log.Errorf("unable to get metrics (%s)", err) } @@ -340,6 +342,8 @@ func (a *apic) SendMetrics(stop chan (bool)) { func (a *apic) SendUsageMetrics() { defer trace.CatchPanic("lapi/usageMetricsToAPIC") + ctx := context.TODO() + firstRun := true log.Debugf("Start sending usage metrics to CrowdSec Central API (interval: %s once, then %s)", a.usageMetricsIntervalFirst, a.usageMetricsInterval) @@ -358,7 +362,7 @@ func (a *apic) SendUsageMetrics() { ticker.Reset(a.usageMetricsInterval) } - metrics, metricsId, err := a.GetUsageMetrics() + metrics, metricsId, err := a.GetUsageMetrics(ctx) if err != nil { log.Errorf("unable to get usage metrics: %s", err) continue diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 105d295dd0d..182bf18532f 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -292,9 +292,11 @@ func TestAPICHandleDeletedDecisions(t *testing.T) { } func TestAPICGetMetrics(t *testing.T) { + ctx := context.Background() + cleanUp := func(api *apic) { - api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background()) - api.dbClient.Ent.Machine.Delete().ExecX(context.Background()) + api.dbClient.Ent.Bouncer.Delete().ExecX(ctx) + api.dbClient.Ent.Machine.Delete().ExecX(ctx) } tests := []struct { name string @@ -377,7 +379,7 @@ func TestAPICGetMetrics(t *testing.T) { ExecX(context.Background()) } - foundMetrics, err := apiClient.GetMetrics() + foundMetrics, err := apiClient.GetMetrics(ctx) require.NoError(t, err) assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers) diff --git a/pkg/database/bouncers.go b/pkg/database/bouncers.go index a7378bbb203..6ff308ff786 100644 --- a/pkg/database/bouncers.go +++ b/pkg/database/bouncers.go @@ -59,8 +59,8 @@ func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { return result, nil } -func (c *Client) ListBouncers() ([]*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().All(c.CTX) +func (c *Client) ListBouncers(ctx context.Context) ([]*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "listing bouncers: %s", err) } From aba6dfd59f533ca06f0059e859ef437a81decbe9 Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 10:35:06 +0200 Subject: [PATCH 08/28] context propagation: pkg/database/config --- cmd/crowdsec-cli/clipapi/papi.go | 2 +- pkg/apiserver/apic.go | 28 ++++++++++++++-------------- pkg/apiserver/apic_test.go | 24 +++++++++++++++--------- pkg/apiserver/apiserver.go | 18 ++++++++++-------- pkg/apiserver/papi.go | 8 ++++---- pkg/apiserver/papi_cmd.go | 7 +++++-- pkg/database/config.go | 12 ++++++------ 7 files changed, 55 insertions(+), 44 deletions(-) diff --git a/cmd/crowdsec-cli/clipapi/papi.go b/cmd/crowdsec-cli/clipapi/papi.go index c0f08157f31..b8101a0fb34 100644 --- a/cmd/crowdsec-cli/clipapi/papi.go +++ b/cmd/crowdsec-cli/clipapi/papi.go @@ -74,7 +74,7 @@ func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Clie return fmt.Errorf("unable to get PAPI permissions: %w", err) } - lastTimestampStr, err := db.GetConfigItem(apiserver.PapiPullKey) + lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey) if err != nil { lastTimestampStr = ptr.Of("never") } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 3ed2e12ea54..b5384c6cc5c 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -614,7 +614,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio // we receive a list of decisions and links for blocklist and we need to create a list of alerts : // one alert for "community blocklist" // one alert per list we're subscribed to -func (a *apic) PullTop(forcePull bool) error { +func (a *apic) PullTop(ctx context.Context, forcePull bool) error { var err error // A mutex with TryLock would be a bit simpler @@ -655,7 +655,7 @@ func (a *apic) PullTop(forcePull bool) error { log.Infof("Starting community-blocklist update") - data, _, err := a.apiClient.Decisions.GetStreamV3(context.Background(), apiclient.DecisionsStreamOpts{Startup: a.startup}) + data, _, err := a.apiClient.Decisions.GetStreamV3(ctx, apiclient.DecisionsStreamOpts{Startup: a.startup}) if err != nil { return fmt.Errorf("get stream: %w", err) } @@ -700,7 +700,7 @@ func (a *apic) PullTop(forcePull bool) error { } // update blocklists - if err := a.UpdateBlocklists(data.Links, addCounters, forcePull); err != nil { + if err := a.UpdateBlocklists(ctx, data.Links, addCounters, forcePull); err != nil { return fmt.Errorf("while updating blocklists: %w", err) } @@ -708,9 +708,9 @@ func (a *apic) PullTop(forcePull bool) error { } // we receive a link to a blocklist, we pull the content of the blocklist and we create one alert -func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error { +func (a *apic) PullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, forcePull bool) error { addCounters, _ := makeAddAndDeleteCounters() - if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{ + if err := a.UpdateBlocklists(ctx, &modelscapi.GetDecisionsStreamResponseLinks{ Blocklists: []*modelscapi.BlocklistLink{blocklist}, }, addCounters, forcePull); err != nil { return fmt.Errorf("while pulling blocklist: %w", err) @@ -820,7 +820,7 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo return false, nil } -func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error { +func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error { if blocklist.Scope == nil { log.Warningf("blocklist has no scope") return nil @@ -848,13 +848,13 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap ) if !forcePull { - lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName) + lastPullTimestamp, err = a.dbClient.GetConfigItem(ctx, blocklistConfigItemName) if err != nil { return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) } } - decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp) + decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(ctx, blocklist, lastPullTimestamp) if err != nil { return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err) } @@ -869,7 +869,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap return nil } - err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) + err = a.dbClient.SetConfigItem(ctx, blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) if err != nil { return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) } @@ -892,7 +892,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap return nil } -func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error { +func (a *apic) UpdateBlocklists(ctx context.Context, links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error { if links == nil { return nil } @@ -908,7 +908,7 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink } for _, blocklist := range links.Blocklists { - if err := a.updateBlocklist(defaultClient, blocklist, addCounters, forcePull); err != nil { + if err := a.updateBlocklist(ctx, defaultClient, blocklist, addCounters, forcePull); err != nil { return err } } @@ -931,7 +931,7 @@ func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int } } -func (a *apic) Pull() error { +func (a *apic) Pull(ctx context.Context) error { defer trace.CatchPanic("lapi/pullFromAPIC") toldOnce := false @@ -955,7 +955,7 @@ func (a *apic) Pull() error { time.Sleep(1 * time.Second) } - if err := a.PullTop(false); err != nil { + if err := a.PullTop(ctx, false); err != nil { log.Errorf("capi pull top: %s", err) } @@ -967,7 +967,7 @@ func (a *apic) Pull() error { case <-ticker.C: ticker.Reset(a.pullInterval) - if err := a.PullTop(false); err != nil { + if err := a.PullTop(ctx, false); err != nil { log.Errorf("capi pull top: %s", err) continue } diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 182bf18532f..97943b495e5 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -550,6 +550,7 @@ func TestFillAlertsWithDecisions(t *testing.T) { } func TestAPICWhitelists(t *testing.T) { + ctx := context.Background() api := getAPIC(t) // one whitelist on IP, one on CIDR api.whitelists = &csconfig.CapiWhitelist{} @@ -685,7 +686,7 @@ func TestAPICWhitelists(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) assertTotalDecisionCount(t, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing @@ -736,6 +737,7 @@ func TestAPICWhitelists(t *testing.T) { } func TestAPICPullTop(t *testing.T) { + ctx := context.Background() api := getAPIC(t) api.dbClient.Ent.Decision.Create(). SetOrigin(types.CAPIOrigin). @@ -826,7 +828,7 @@ func TestAPICPullTop(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) assertTotalDecisionCount(t, api.dbClient, 5) @@ -860,6 +862,7 @@ func TestAPICPullTop(t *testing.T) { } func TestAPICPullTopBLCacheFirstCall(t *testing.T) { + ctx := context.Background() // no decision in db, no last modified parameter. api := getAPIC(t) @@ -913,11 +916,11 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) blocklistConfigItemName := "blocklist:blocklist1:last_pull" - lastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + lastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName) require.NoError(t, err) assert.NotEqual(t, "", *lastPullTimestamp) @@ -927,14 +930,15 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { return httpmock.NewStringResponse(304, ""), nil }) - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) - secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + secondLastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName) require.NoError(t, err) assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp) } func TestAPICPullTopBLCacheForceCall(t *testing.T) { + ctx := context.Background() api := getAPIC(t) httpmock.Activate() @@ -1005,11 +1009,12 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) } func TestAPICPullBlocklistCall(t *testing.T) { + ctx := context.Background() api := getAPIC(t) httpmock.Activate() @@ -1032,7 +1037,7 @@ func TestAPICPullBlocklistCall(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullBlocklist(&modelscapi.BlocklistLink{ + err = api.PullBlocklist(ctx, &modelscapi.BlocklistLink{ URL: ptr.Of("http://api.crowdsec.net/blocklist1"), Name: ptr.Of("blocklist1"), Scope: ptr.Of("Ip"), @@ -1134,6 +1139,7 @@ func TestAPICPush(t *testing.T) { } func TestAPICPull(t *testing.T) { + ctx := context.Background() api := getAPIC(t) tests := []struct { name string @@ -1204,7 +1210,7 @@ func TestAPICPull(t *testing.T) { go func() { logrus.SetOutput(&buf) - if err := api.Pull(); err != nil { + if err := api.Pull(ctx); err != nil { panic(err) } }() diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 95d18ccb028..6b5d6803be9 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -310,8 +310,8 @@ func (s *APIServer) apicPush() error { return nil } -func (s *APIServer) apicPull() error { - if err := s.apic.Pull(); err != nil { +func (s *APIServer) apicPull(ctx context.Context) error { + if err := s.apic.Pull(ctx); err != nil { log.Errorf("capi pull: %s", err) return err } @@ -319,8 +319,8 @@ func (s *APIServer) apicPull() error { return nil } -func (s *APIServer) papiPull() error { - if err := s.papi.Pull(); err != nil { +func (s *APIServer) papiPull(ctx context.Context) error { + if err := s.papi.Pull(ctx); err != nil { log.Errorf("papi pull: %s", err) return err } @@ -337,16 +337,16 @@ func (s *APIServer) papiSync() error { return nil } -func (s *APIServer) initAPIC() { +func (s *APIServer) initAPIC(ctx context.Context) { s.apic.pushTomb.Go(s.apicPush) - s.apic.pullTomb.Go(s.apicPull) + s.apic.pullTomb.Go(func() error { return s.apicPull(ctx) }) // csConfig.API.Server.ConsoleConfig.ShareCustomScenarios if s.apic.apiClient.IsEnrolled() { if s.consoleConfig.IsPAPIEnabled() { if s.papi.URL != "" { log.Info("Starting PAPI decision receiver") - s.papi.pullTomb.Go(s.papiPull) + s.papi.pullTomb.Go(func() error { return s.papiPull(ctx) }) s.papi.syncTomb.Go(s.papiSync) } else { log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. Run cscli console enable console_management to add it.") @@ -381,8 +381,10 @@ func (s *APIServer) Run(apiReady chan bool) error { TLSConfig: tlsCfg, } + ctx := context.TODO() + if s.apic != nil { - s.initAPIC() + s.initAPIC(ctx) } s.httpServerTomb.Go(func() error { diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index 89ad93930a1..7dd6b346aa9 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -230,13 +230,13 @@ func (p *Papi) PullOnce(since time.Time, sync bool) error { } // PullPAPI is the long polling client for real-time decisions from PAPI -func (p *Papi) Pull() error { +func (p *Papi) Pull(ctx context.Context) error { defer trace.CatchPanic("lapi/PullPAPI") p.Logger.Infof("Starting Polling API Pull") lastTimestamp := time.Time{} - lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey) + lastTimestampStr, err := p.DBClient.GetConfigItem(ctx, PapiPullKey) if err != nil { p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err) } @@ -248,7 +248,7 @@ func (p *Papi) Pull() error { return fmt.Errorf("failed to serialize last timestamp: %w", err) } - if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { + if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil { p.Logger.Errorf("error setting papi pull last key: %s", err) } else { p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime)) @@ -277,7 +277,7 @@ func (p *Papi) Pull() error { continue } - if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { + if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil { return fmt.Errorf("failed to update last timestamp: %w", err) } diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index a1137161698..943eb4139de 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "encoding/json" "fmt" "time" @@ -215,17 +216,19 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) } + ctx := context.TODO() + if forcePullMsg.Blocklist == nil { p.Logger.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists") - err = p.apic.PullTop(true) + err = p.apic.PullTop(ctx, true) if err != nil { return fmt.Errorf("failed to force pull operation: %w", err) } } else { p.Logger.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name) - err = p.apic.PullBlocklist(&modelscapi.BlocklistLink{ + err = p.apic.PullBlocklist(ctx, &modelscapi.BlocklistLink{ Name: &forcePullMsg.Blocklist.Name, URL: &forcePullMsg.Blocklist.Url, Remediation: &forcePullMsg.Blocklist.Remediation, diff --git a/pkg/database/config.go b/pkg/database/config.go index 8c3578ad596..7c341d3ecda 100644 --- a/pkg/database/config.go +++ b/pkg/database/config.go @@ -1,14 +1,15 @@ package database import ( + "context" "github.com/pkg/errors" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" ) -func (c *Client) GetConfigItem(key string) (*string, error) { - result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(c.CTX) +func (c *Client) GetConfigItem(ctx context.Context, key string) (*string, error) { + result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(ctx) if err != nil && ent.IsNotFound(err) { return nil, nil } @@ -19,11 +20,10 @@ func (c *Client) GetConfigItem(key string) (*string, error) { return &result.Value, nil } -func (c *Client) SetConfigItem(key string, value string) error { - - nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(c.CTX) +func (c *Client) SetConfigItem(ctx context.Context, key string, value string) error { + nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(ctx) if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { //not found, create - err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(c.CTX) + err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(ctx) if err != nil { return errors.Wrapf(QueryFail, "insert config item: %s", err) } From 37c01098c7a581132ee20a83e7f8bad8eb977155 Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 11:10:35 +0200 Subject: [PATCH 09/28] context propagation: pkg/database/metrics --- pkg/apiserver/apic_metrics.go | 10 +++++----- pkg/apiserver/usage_metrics_test.go | 8 ++++++-- pkg/database/metrics.go | 12 ++++++------ 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index 91a0a8273f7..e5821e4c1e2 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -38,7 +38,7 @@ func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int, } for _, bouncer := range bouncers { - dbMetrics, err := a.dbClient.GetBouncerUsageMetricsByName(bouncer.Name) + dbMetrics, err := a.dbClient.GetBouncerUsageMetricsByName(ctx, bouncer.Name) if err != nil { log.Errorf("unable to get bouncer usage metrics: %s", err) continue @@ -81,7 +81,7 @@ func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int, } for _, lp := range lps { - dbMetrics, err := a.dbClient.GetLPUsageMetricsByMachineID(lp.MachineId) + dbMetrics, err := a.dbClient.GetLPUsageMetricsByMachineID(ctx, lp.MachineId) if err != nil { log.Errorf("unable to get LP usage metrics: %s", err) continue @@ -181,8 +181,8 @@ func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int, return allMetrics, metricsIds, nil } -func (a *apic) MarkUsageMetricsAsSent(ids []int) error { - return a.dbClient.MarkUsageMetricsAsSent(ids) +func (a *apic) MarkUsageMetricsAsSent(ctx context.Context, ids []int) error { + return a.dbClient.MarkUsageMetricsAsSent(ctx, ids) } func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) { @@ -379,7 +379,7 @@ func (a *apic) SendUsageMetrics() { } } - err = a.MarkUsageMetricsAsSent(metricsId) + err = a.MarkUsageMetricsAsSent(ctx, metricsId) if err != nil { log.Errorf("unable to mark usage metrics as sent: %s", err) continue diff --git a/pkg/apiserver/usage_metrics_test.go b/pkg/apiserver/usage_metrics_test.go index 41dd0ccdc2c..019de5fb970 100644 --- a/pkg/apiserver/usage_metrics_test.go +++ b/pkg/apiserver/usage_metrics_test.go @@ -13,6 +13,8 @@ import ( ) func TestLPMetrics(t *testing.T) { + ctx := context.Background() + tests := []struct { name string body string @@ -198,7 +200,7 @@ func TestLPMetrics(t *testing.T) { assert.Contains(t, w.Body.String(), tt.expectedResponse) machine, _ := dbClient.QueryMachineByID("test") - metrics, _ := dbClient.GetLPUsageMetricsByMachineID("test") + metrics, _ := dbClient.GetLPUsageMetricsByMachineID(ctx, "test") assert.Len(t, metrics, tt.expectedMetricsCount) assert.Equal(t, tt.expectedOSName, machine.Osname) @@ -214,6 +216,8 @@ func TestLPMetrics(t *testing.T) { } func TestRCMetrics(t *testing.T) { + ctx := context.Background() + tests := []struct { name string body string @@ -368,7 +372,7 @@ func TestRCMetrics(t *testing.T) { assert.Contains(t, w.Body.String(), tt.expectedResponse) bouncer, _ := dbClient.SelectBouncerByName("test") - metrics, _ := dbClient.GetBouncerUsageMetricsByName("test") + metrics, _ := dbClient.GetBouncerUsageMetricsByName(ctx, "test") assert.Len(t, metrics, tt.expectedMetricsCount) assert.Equal(t, tt.expectedOSName, bouncer.Osname) diff --git a/pkg/database/metrics.go b/pkg/database/metrics.go index 1619fcc923b..99ba90c80b8 100644 --- a/pkg/database/metrics.go +++ b/pkg/database/metrics.go @@ -25,14 +25,14 @@ func (c *Client) CreateMetric(ctx context.Context, generatedType metric.Generate return metric, nil } -func (c *Client) GetLPUsageMetricsByMachineID(machineId string) ([]*ent.Metric, error) { +func (c *Client) GetLPUsageMetricsByMachineID(ctx context.Context, machineId string) ([]*ent.Metric, error) { metrics, err := c.Ent.Metric.Query(). Where( metric.GeneratedTypeEQ(metric.GeneratedTypeLP), metric.GeneratedByEQ(machineId), metric.PushedAtIsNil(), ). - All(c.CTX) + All(ctx) if err != nil { c.Log.Warningf("GetLPUsageMetricsByOrigin: %s", err) return nil, fmt.Errorf("getting LP usage metrics by origin %s: %w", machineId, err) @@ -41,14 +41,14 @@ func (c *Client) GetLPUsageMetricsByMachineID(machineId string) ([]*ent.Metric, return metrics, nil } -func (c *Client) GetBouncerUsageMetricsByName(bouncerName string) ([]*ent.Metric, error) { +func (c *Client) GetBouncerUsageMetricsByName(ctx context.Context, bouncerName string) ([]*ent.Metric, error) { metrics, err := c.Ent.Metric.Query(). Where( metric.GeneratedTypeEQ(metric.GeneratedTypeRC), metric.GeneratedByEQ(bouncerName), metric.PushedAtIsNil(), ). - All(c.CTX) + All(ctx) if err != nil { c.Log.Warningf("GetBouncerUsageMetricsByName: %s", err) return nil, fmt.Errorf("getting bouncer usage metrics by name %s: %w", bouncerName, err) @@ -57,11 +57,11 @@ func (c *Client) GetBouncerUsageMetricsByName(bouncerName string) ([]*ent.Metric return metrics, nil } -func (c *Client) MarkUsageMetricsAsSent(ids []int) error { +func (c *Client) MarkUsageMetricsAsSent(ctx context.Context, ids []int) error { _, err := c.Ent.Metric.Update(). Where(metric.IDIn(ids...)). SetPushedAt(time.Now().UTC()). - Save(c.CTX) + Save(ctx) if err != nil { c.Log.Warningf("MarkUsageMetricsAsSent: %s", err) return fmt.Errorf("marking usage metrics as sent: %w", err) From 1d54af39e965f5f96956d92ef1abbfe53236ab82 Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 11:25:15 +0200 Subject: [PATCH 10/28] CreateMachine(ctx...) --- cmd/crowdsec-cli/climachine/machines.go | 9 +++++---- pkg/apiserver/controllers/v1/machines.go | 4 +++- pkg/apiserver/middlewares/v1/jwt.go | 6 ++++-- pkg/database/machines.go | 8 ++++---- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/cmd/crowdsec-cli/climachine/machines.go b/cmd/crowdsec-cli/climachine/machines.go index 30948f43056..c96601afee7 100644 --- a/cmd/crowdsec-cli/climachine/machines.go +++ b/cmd/crowdsec-cli/climachine/machines.go @@ -1,6 +1,7 @@ package climachine import ( + "context" "encoding/csv" "encoding/json" "errors" @@ -278,8 +279,8 @@ func (cli *cliMachines) newAddCmd() *cobra.Command { cscli machines add MyTestMachine --auto cscli machines add MyTestMachine --password MyPassword cscli machines add -f- --auto > /tmp/mycreds.yaml`, - RunE: func(_ *cobra.Command, args []string) error { - return cli.add(args, string(password), dumpFile, apiURL, interactive, autoAdd, force) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.add(cmd.Context(), args, string(password), dumpFile, apiURL, interactive, autoAdd, force) }, } @@ -294,7 +295,7 @@ cscli machines add -f- --auto > /tmp/mycreds.yaml`, return cmd } -func (cli *cliMachines) add(args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error { +func (cli *cliMachines) add(ctx context.Context, args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error { var ( err error machineID string @@ -353,7 +354,7 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri password := strfmt.Password(machinePassword) - _, err = cli.db.CreateMachine(&machineID, &password, "", true, force, types.PasswordAuthType) + _, err = cli.db.CreateMachine(ctx, &machineID, &password, "", true, force, types.PasswordAuthType) if err != nil { return fmt.Errorf("unable to create machine: %w", err) } diff --git a/pkg/apiserver/controllers/v1/machines.go b/pkg/apiserver/controllers/v1/machines.go index 0030f7d3b39..ff59e389cb1 100644 --- a/pkg/apiserver/controllers/v1/machines.go +++ b/pkg/apiserver/controllers/v1/machines.go @@ -46,6 +46,8 @@ func (c *Controller) shouldAutoRegister(token string, gctx *gin.Context) (bool, } func (c *Controller) CreateMachine(gctx *gin.Context) { + ctx := gctx.Request.Context() + var input models.WatcherRegistrationRequest if err := gctx.ShouldBindJSON(&input); err != nil { @@ -66,7 +68,7 @@ func (c *Controller) CreateMachine(gctx *gin.Context) { return } - if _, err := c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil { + if _, err := c.DBClient.CreateMachine(ctx, input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil { c.HandleDBErrors(gctx, err) return } diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index 64406deff3e..491dc9768cd 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -55,6 +55,8 @@ type authInput struct { } func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { + // XXX: should we pass ctx instead + ctx := c.Request.Context() ret := authInput{} if j.TlsAuth == nil { @@ -76,7 +78,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). Where(machine.MachineId(ret.machineID)). - First(j.DbClient.CTX) + First(ctx) if ent.IsNotFound(err) { // Machine was not found, let's create it logger.Infof("machine %s not found, create it", ret.machineID) @@ -91,7 +93,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { password := strfmt.Password(pwd) - ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType) + ret.clientMachine, err = j.DbClient.CreateMachine(ctx, &ret.machineID, &password, "", true, true, types.TlsAuthType) if err != nil { return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err) } diff --git a/pkg/database/machines.go b/pkg/database/machines.go index 27d737e625e..1f14f356ad0 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -72,7 +72,7 @@ func (c *Client) MachineUpdateBaseMetrics(ctx context.Context, machineID string, return nil } -func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { +func (c *Client) CreateMachine(ctx context.Context, machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) if err != nil { c.Log.Warningf("CreateMachine: %s", err) @@ -82,14 +82,14 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA machineExist, err := c.Ent.Machine. Query(). Where(machine.MachineIdEQ(*machineID)). - Select(machine.FieldMachineId).Strings(c.CTX) + Select(machine.FieldMachineId).Strings(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) } if len(machineExist) > 0 { if force { - _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(c.CTX) + _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(ctx) if err != nil { c.Log.Warningf("CreateMachine : %s", err) return nil, errors.Wrapf(UpdateFail, "machine '%s'", *machineID) @@ -113,7 +113,7 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA SetIpAddress(ipAddress). SetIsValidated(isValidated). SetAuthType(authType). - Save(c.CTX) + Save(ctx) if err != nil { c.Log.Warningf("CreateMachine : %s", err) return nil, errors.Wrapf(InsertFail, "creating machine '%s'", *machineID) From a2358e7ca5650cb627ef916562696814cf317328 Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 11:37:54 +0200 Subject: [PATCH 11/28] QueryMachineByID(ctx...) --- cmd/crowdsec-cli/climachine/machines.go | 6 ++++-- pkg/apiserver/usage_metrics_test.go | 2 +- pkg/database/alerts.go | 4 +++- pkg/database/machines.go | 6 +++--- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/cmd/crowdsec-cli/climachine/machines.go b/cmd/crowdsec-cli/climachine/machines.go index c96601afee7..11269171122 100644 --- a/cmd/crowdsec-cli/climachine/machines.go +++ b/cmd/crowdsec-cli/climachine/machines.go @@ -691,9 +691,11 @@ func (cli *cliMachines) newInspectCmd() *cobra.Command { Args: cobra.ExactArgs(1), DisableAutoGenTag: true, ValidArgsFunction: cli.validMachineID, - RunE: func(_ *cobra.Command, args []string) error { + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() machineID := args[0] - machine, err := cli.db.QueryMachineByID(machineID) + + machine, err := cli.db.QueryMachineByID(ctx, machineID) if err != nil { return fmt.Errorf("unable to read machine data '%s': %w", machineID, err) } diff --git a/pkg/apiserver/usage_metrics_test.go b/pkg/apiserver/usage_metrics_test.go index 019de5fb970..b335738ea16 100644 --- a/pkg/apiserver/usage_metrics_test.go +++ b/pkg/apiserver/usage_metrics_test.go @@ -199,7 +199,7 @@ func TestLPMetrics(t *testing.T) { assert.Equal(t, tt.expectedStatusCode, w.Code) assert.Contains(t, w.Body.String(), tt.expectedResponse) - machine, _ := dbClient.QueryMachineByID("test") + machine, _ := dbClient.QueryMachineByID(ctx, "test") metrics, _ := dbClient.GetLPUsageMetricsByMachineID(ctx, "test") assert.Len(t, metrics, tt.expectedMetricsCount) diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index 3e3e480c7d6..19504983cbe 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -687,8 +687,10 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str err error ) + ctx := context.TODO() + if machineID != "" { - owner, err = c.QueryMachineByID(machineID) + owner, err = c.QueryMachineByID(ctx, machineID) if err != nil { if !errors.Is(err, UserNotExists) { return nil, fmt.Errorf("machine '%s': %w", machineID, err) diff --git a/pkg/database/machines.go b/pkg/database/machines.go index 1f14f356ad0..6412761f390 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -95,7 +95,7 @@ func (c *Client) CreateMachine(ctx context.Context, machineID *string, password return nil, errors.Wrapf(UpdateFail, "machine '%s'", *machineID) } - machine, err := c.QueryMachineByID(*machineID) + machine, err := c.QueryMachineByID(ctx, *machineID) if err != nil { return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) } @@ -122,11 +122,11 @@ func (c *Client) CreateMachine(ctx context.Context, machineID *string, password return machine, nil } -func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) { +func (c *Client) QueryMachineByID(ctx context.Context, machineID string) (*ent.Machine, error) { machine, err := c.Ent.Machine. Query(). Where(machine.MachineIdEQ(machineID)). - Only(c.CTX) + Only(ctx) if err != nil { c.Log.Warningf("QueryMachineByID : %s", err) return &ent.Machine{}, errors.Wrapf(UserNotExists, "user '%s'", machineID) From ed0646ece9e03e424bea82fef97a99a1626ea1c2 Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 11:45:29 +0200 Subject: [PATCH 12/28] ListMachines(ctx) --- cmd/crowdsec-cli/climachine/machines.go | 13 +++++++------ cmd/crowdsec-cli/clisupport/support.go | 6 +++--- pkg/apiserver/apic.go | 4 +++- pkg/apiserver/apic_metrics.go | 10 +++++----- pkg/apiserver/apiserver_test.go | 2 +- pkg/database/machines.go | 4 ++-- 6 files changed, 21 insertions(+), 18 deletions(-) diff --git a/cmd/crowdsec-cli/climachine/machines.go b/cmd/crowdsec-cli/climachine/machines.go index 11269171122..a659fc04a62 100644 --- a/cmd/crowdsec-cli/climachine/machines.go +++ b/cmd/crowdsec-cli/climachine/machines.go @@ -211,11 +211,11 @@ func (cli *cliMachines) listCSV(out io.Writer, machines ent.Machines) error { return nil } -func (cli *cliMachines) List(out io.Writer, db *database.Client) error { +func (cli *cliMachines) List(ctx context.Context, out io.Writer, db *database.Client) error { // XXX: must use the provided db object, the one in the struct might be nil // (calling List directly skips the PersistentPreRunE) - machines, err := db.ListMachines() + machines, err := db.ListMachines(ctx) if err != nil { return fmt.Errorf("unable to list machines: %w", err) } @@ -252,8 +252,8 @@ func (cli *cliMachines) newListCmd() *cobra.Command { Example: `cscli machines list`, Args: cobra.NoArgs, DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.List(color.Output, cli.db) + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.List(cmd.Context(), color.Output, cli.db) }, } @@ -400,6 +400,7 @@ func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComp var err error cfg := cli.cfg() + ctx := cmd.Context() // need to load config and db because PersistentPreRunE is not called for completions @@ -408,13 +409,13 @@ func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComp return nil, cobra.ShellCompDirectiveNoFileComp } - cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig) + cli.db, err = require.DBClient(ctx, cfg.DbConfig) if err != nil { cobra.CompError("unable to list machines " + err.Error()) return nil, cobra.ShellCompDirectiveNoFileComp } - machines, err := cli.db.ListMachines() + machines, err := cli.db.ListMachines(ctx) if err != nil { cobra.CompError("unable to list machines " + err.Error()) return nil, cobra.ShellCompDirectiveNoFileComp diff --git a/cmd/crowdsec-cli/clisupport/support.go b/cmd/crowdsec-cli/clisupport/support.go index 7e41518805a..4474f5c8f11 100644 --- a/cmd/crowdsec-cli/clisupport/support.go +++ b/cmd/crowdsec-cli/clisupport/support.go @@ -210,7 +210,7 @@ func (cli *cliSupport) dumpBouncers(ctx context.Context, zw *zip.Writer, db *dat return nil } -func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error { +func (cli *cliSupport) dumpAgents(ctx context.Context, zw *zip.Writer, db *database.Client) error { log.Info("Collecting agents") if db == nil { @@ -220,7 +220,7 @@ func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error { out := new(bytes.Buffer) cm := climachine.New(cli.cfg) - if err := cm.List(out, db); err != nil { + if err := cm.List(ctx, out, db); err != nil { return err } @@ -529,7 +529,7 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error { log.Warnf("could not collect bouncers information: %s", err) } - if err = cli.dumpAgents(zipWriter, db); err != nil { + if err = cli.dumpAgents(ctx, zipWriter, db); err != nil { log.Warnf("could not collect agents information: %s", err) } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index b5384c6cc5c..c79d5f88e3f 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -85,7 +85,9 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration { func (a *apic) FetchScenariosListFromDB() ([]string, error) { scenarios := make([]string, 0) - machines, err := a.dbClient.ListMachines() + ctx := context.TODO() + + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, fmt.Errorf("while listing machines: %w", err) } diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index e5821e4c1e2..16b2328dbe9 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -27,7 +27,7 @@ func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int, allMetrics := &models.AllMetrics{} metricsIds := make([]int, 0) - lps, err := a.dbClient.ListMachines() + lps, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, nil, err } @@ -186,7 +186,7 @@ func (a *apic) MarkUsageMetricsAsSent(ctx context.Context, ids []int) error { } func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) { - machines, err := a.dbClient.ListMachines() + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, err } @@ -230,8 +230,8 @@ func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) { }, nil } -func (a *apic) fetchMachineIDs() ([]string, error) { - machines, err := a.dbClient.ListMachines() +func (a *apic) fetchMachineIDs(ctx context.Context) ([]string, error) { + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, err } @@ -277,7 +277,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { machineIDs := []string{} reloadMachineIDs := func() { - ids, err := a.fetchMachineIDs() + ids, err := a.fetchMachineIDs(ctx) if err != nil { log.Debugf("unable to get machines (%s), will retry", err) diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 89c75f35d21..1f2f7fdf972 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -197,7 +197,7 @@ func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) dbClient, err := database.NewClient(ctx, config) require.NoError(t, err) - machines, err := dbClient.ListMachines() + machines, err := dbClient.ListMachines(ctx) require.NoError(t, err) for _, machine := range machines { diff --git a/pkg/database/machines.go b/pkg/database/machines.go index 6412761f390..8ec4f5f5872 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -135,8 +135,8 @@ func (c *Client) QueryMachineByID(ctx context.Context, machineID string) (*ent.M return machine, nil } -func (c *Client) ListMachines() ([]*ent.Machine, error) { - machines, err := c.Ent.Machine.Query().All(c.CTX) +func (c *Client) ListMachines(ctx context.Context) ([]*ent.Machine, error) { + machines, err := c.Ent.Machine.Query().All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "listing machines: %s", err) } From 2654d2c143b1c53c535b401aa147a2522bcfaafb Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 12:08:07 +0200 Subject: [PATCH 13/28] ValidateMachine(ctx...) --- cmd/crowdsec-cli/climachine/machines.go | 8 ++++---- pkg/apiserver/apiserver_test.go | 4 ++-- pkg/database/machines.go | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cmd/crowdsec-cli/climachine/machines.go b/cmd/crowdsec-cli/climachine/machines.go index a659fc04a62..bc57eb073c4 100644 --- a/cmd/crowdsec-cli/climachine/machines.go +++ b/cmd/crowdsec-cli/climachine/machines.go @@ -555,8 +555,8 @@ cscli machines prune --not-validated-only --force`, return cmd } -func (cli *cliMachines) validate(machineID string) error { - if err := cli.db.ValidateMachine(machineID); err != nil { +func (cli *cliMachines) validate(ctx context.Context, machineID string) error { + if err := cli.db.ValidateMachine(ctx, machineID); err != nil { return fmt.Errorf("unable to validate machine '%s': %w", machineID, err) } @@ -573,8 +573,8 @@ func (cli *cliMachines) newValidateCmd() *cobra.Command { Example: `cscli machines validate "machine_name"`, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - return cli.validate(args[0]) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.validate(cmd.Context(), args[0]) }, } diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 1f2f7fdf972..9e12d27cb36 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -182,12 +182,12 @@ func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) { } func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) { - ctx := context.Background() + ctx := context.TODO() dbClient, err := database.NewClient(ctx, config) require.NoError(t, err) - err = dbClient.ValidateMachine(machineID) + err = dbClient.ValidateMachine(ctx, machineID) require.NoError(t, err) } diff --git a/pkg/database/machines.go b/pkg/database/machines.go index 8ec4f5f5872..c93a3c2fb50 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -144,8 +144,8 @@ func (c *Client) ListMachines(ctx context.Context) ([]*ent.Machine, error) { return machines, nil } -func (c *Client) ValidateMachine(machineID string) error { - rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(c.CTX) +func (c *Client) ValidateMachine(ctx context.Context, machineID string) error { + rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(ctx) if err != nil { return errors.Wrapf(UpdateFail, "validating machine: %s", err) } From 9d9fc23a6d94510f30f4818275f878d61414ca6f Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 12:10:28 +0200 Subject: [PATCH 14/28] QueryPendingMachine(ctx) --- cmd/crowdsec-cli/climachine/machines.go | 8 ++++---- pkg/database/machines.go | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cmd/crowdsec-cli/climachine/machines.go b/cmd/crowdsec-cli/climachine/machines.go index bc57eb073c4..09e8d9563a0 100644 --- a/cmd/crowdsec-cli/climachine/machines.go +++ b/cmd/crowdsec-cli/climachine/machines.go @@ -473,7 +473,7 @@ func (cli *cliMachines) newDeleteCmd() *cobra.Command { return cmd } -func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force bool) error { +func (cli *cliMachines) prune(ctx context.Context, duration time.Duration, notValidOnly bool, force bool) error { if duration < 2*time.Minute && !notValidOnly { if yes, err := ask.YesNo( "The duration you provided is less than 2 minutes. "+ @@ -486,7 +486,7 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b } machines := []*ent.Machine{} - if pending, err := cli.db.QueryPendingMachine(); err == nil { + if pending, err := cli.db.QueryPendingMachine(ctx); err == nil { machines = append(machines, pending...) } @@ -542,8 +542,8 @@ cscli machines prune --duration 1h cscli machines prune --not-validated-only --force`, Args: cobra.NoArgs, DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.prune(duration, notValidOnly, force) + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.prune(cmd.Context(), duration, notValidOnly, force) }, } diff --git a/pkg/database/machines.go b/pkg/database/machines.go index c93a3c2fb50..1318e199add 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -157,8 +157,8 @@ func (c *Client) ValidateMachine(ctx context.Context, machineID string) error { return nil } -func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) { - machines, err := c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(c.CTX) +func (c *Client) QueryPendingMachine(ctx context.Context) ([]*ent.Machine, error) { + machines, err := c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(ctx) if err != nil { c.Log.Warningf("QueryPendingMachine : %s", err) return nil, errors.Wrapf(QueryFail, "querying pending machines: %s", err) From 45873cff90e8092c3aae4d10c8960352c25235e4 Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 12:11:48 +0200 Subject: [PATCH 15/28] DeleteWatcher(ctx...) --- cmd/crowdsec-cli/climachine/machines.go | 8 ++++---- pkg/database/machines.go | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cmd/crowdsec-cli/climachine/machines.go b/cmd/crowdsec-cli/climachine/machines.go index 09e8d9563a0..d7057b9c7ea 100644 --- a/cmd/crowdsec-cli/climachine/machines.go +++ b/cmd/crowdsec-cli/climachine/machines.go @@ -432,9 +432,9 @@ func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComp return ret, cobra.ShellCompDirectiveNoFileComp } -func (cli *cliMachines) delete(machines []string, ignoreMissing bool) error { +func (cli *cliMachines) delete(ctx context.Context, machines []string, ignoreMissing bool) error { for _, machineID := range machines { - if err := cli.db.DeleteWatcher(machineID); err != nil { + if err := cli.db.DeleteWatcher(ctx, machineID); err != nil { var notFoundErr *database.MachineNotFoundError if ignoreMissing && errors.As(err, ¬FoundErr) { return nil @@ -462,8 +462,8 @@ func (cli *cliMachines) newDeleteCmd() *cobra.Command { Aliases: []string{"remove"}, DisableAutoGenTag: true, ValidArgsFunction: cli.validMachineID, - RunE: func(_ *cobra.Command, args []string) error { - return cli.delete(args, ignoreMissing) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.delete(cmd.Context(), args, ignoreMissing) }, } diff --git a/pkg/database/machines.go b/pkg/database/machines.go index 1318e199add..2fdad66643d 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -167,11 +167,11 @@ func (c *Client) QueryPendingMachine(ctx context.Context) ([]*ent.Machine, error return machines, nil } -func (c *Client) DeleteWatcher(name string) error { +func (c *Client) DeleteWatcher(ctx context.Context, name string) error { nbDeleted, err := c.Ent.Machine. Delete(). Where(machine.MachineIdEQ(name)). - Exec(c.CTX) + Exec(ctx) if err != nil { return err } From e261ae645d6551d1c9918ca6046a7115b0acdffe Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 12:14:09 +0200 Subject: [PATCH 16/28] BulkDeleteWatchers(ctx...) --- cmd/crowdsec-cli/climachine/machines.go | 2 +- pkg/database/machines.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/crowdsec-cli/climachine/machines.go b/cmd/crowdsec-cli/climachine/machines.go index d7057b9c7ea..7f1fdda56a7 100644 --- a/cmd/crowdsec-cli/climachine/machines.go +++ b/cmd/crowdsec-cli/climachine/machines.go @@ -514,7 +514,7 @@ func (cli *cliMachines) prune(ctx context.Context, duration time.Duration, notVa } } - deleted, err := cli.db.BulkDeleteWatchers(machines) + deleted, err := cli.db.BulkDeleteWatchers(ctx, machines) if err != nil { return fmt.Errorf("unable to prune machines: %w", err) } diff --git a/pkg/database/machines.go b/pkg/database/machines.go index 2fdad66643d..2325b858285 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -183,13 +183,13 @@ func (c *Client) DeleteWatcher(ctx context.Context, name string) error { return nil } -func (c *Client) BulkDeleteWatchers(machines []*ent.Machine) (int, error) { +func (c *Client) BulkDeleteWatchers(ctx context.Context, machines []*ent.Machine) (int, error) { ids := make([]int, len(machines)) for i, b := range machines { ids[i] = b.ID } - nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(c.CTX) + nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(ctx) if err != nil { return nbDeleted, err } From 94f3abe2e5c53de7c3a32f31f903dbabdc457707 Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 12:21:07 +0200 Subject: [PATCH 17/28] rest of machines.go --- cmd/crowdsec-cli/climachine/machines.go | 2 +- pkg/apiserver/controllers/v1/heartbeat.go | 4 +++- pkg/apiserver/middlewares/v1/jwt.go | 10 ++++++---- pkg/database/machines.go | 24 +++++++++++------------ 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/cmd/crowdsec-cli/climachine/machines.go b/cmd/crowdsec-cli/climachine/machines.go index 7f1fdda56a7..1fbedcf57fd 100644 --- a/cmd/crowdsec-cli/climachine/machines.go +++ b/cmd/crowdsec-cli/climachine/machines.go @@ -491,7 +491,7 @@ func (cli *cliMachines) prune(ctx context.Context, duration time.Duration, notVa } if !notValidOnly { - if pending, err := cli.db.QueryMachinesInactiveSince(time.Now().UTC().Add(-duration)); err == nil { + if pending, err := cli.db.QueryMachinesInactiveSince(ctx, time.Now().UTC().Add(-duration)); err == nil { machines = append(machines, pending...) } } diff --git a/pkg/apiserver/controllers/v1/heartbeat.go b/pkg/apiserver/controllers/v1/heartbeat.go index e1231eaa9ec..799b736ccfe 100644 --- a/pkg/apiserver/controllers/v1/heartbeat.go +++ b/pkg/apiserver/controllers/v1/heartbeat.go @@ -9,7 +9,9 @@ import ( func (c *Controller) HeartBeat(gctx *gin.Context) { machineID, _ := getMachineIDFromContext(gctx) - if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil { + ctx := gctx.Request.Context() + + if err := c.DBClient.UpdateMachineLastHeartBeat(ctx, machineID); err != nil { c.HandleDBErrors(gctx, err) return } diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index 491dc9768cd..b7cccd1aa39 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -177,6 +177,8 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { auth *authInput ) + ctx := c.Request.Context() + if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { auth, err = j.authTLS(c) if err != nil { @@ -200,7 +202,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { } } - err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineScenarios(ctx, scenarios, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication @@ -210,7 +212,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { clientIP := c.ClientIP() if auth.clientMachine.IpAddress == "" { - err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication @@ -220,7 +222,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { if auth.clientMachine.IpAddress != clientIP && auth.clientMachine.IpAddress != "" { log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, clientIP, auth.clientMachine.IpAddress) - err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err) return nil, jwt.ErrFailedAuthentication @@ -233,7 +235,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { return nil, jwt.ErrFailedAuthentication } - if err := j.DbClient.UpdateMachineVersion(useragent[1], auth.clientMachine.ID); err != nil { + if err := j.DbClient.UpdateMachineVersion(ctx, useragent[1], auth.clientMachine.ID); err != nil { log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err) log.Errorf("bad user agent from : %s", clientIP) diff --git a/pkg/database/machines.go b/pkg/database/machines.go index 2325b858285..d8c02825312 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -197,8 +197,8 @@ func (c *Client) BulkDeleteWatchers(ctx context.Context, machines []*ent.Machine return nbDeleted, nil } -func (c *Client) UpdateMachineLastHeartBeat(machineID string) error { - _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(c.CTX) +func (c *Client) UpdateMachineLastHeartBeat(ctx context.Context, machineID string) error { + _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(ctx) if err != nil { return errors.Wrapf(UpdateFail, "updating machine last_heartbeat: %s", err) } @@ -206,11 +206,11 @@ func (c *Client) UpdateMachineLastHeartBeat(machineID string) error { return nil } -func (c *Client) UpdateMachineScenarios(scenarios string, id int) error { +func (c *Client) UpdateMachineScenarios(ctx context.Context, scenarios string, id int) error { _, err := c.Ent.Machine.UpdateOneID(id). SetUpdatedAt(time.Now().UTC()). SetScenarios(scenarios). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine in database: %w", err) } @@ -218,10 +218,10 @@ func (c *Client) UpdateMachineScenarios(scenarios string, id int) error { return nil } -func (c *Client) UpdateMachineIP(ipAddr string, id int) error { +func (c *Client) UpdateMachineIP(ctx context.Context, ipAddr string, id int) error { _, err := c.Ent.Machine.UpdateOneID(id). SetIpAddress(ipAddr). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine IP in database: %w", err) } @@ -229,10 +229,10 @@ func (c *Client) UpdateMachineIP(ipAddr string, id int) error { return nil } -func (c *Client) UpdateMachineVersion(ipAddr string, id int) error { +func (c *Client) UpdateMachineVersion(ctx context.Context, ipAddr string, id int) error { _, err := c.Ent.Machine.UpdateOneID(id). SetVersion(ipAddr). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine version in database: %w", err) } @@ -240,8 +240,8 @@ func (c *Client) UpdateMachineVersion(ipAddr string, id int) error { return nil } -func (c *Client) IsMachineRegistered(machineID string) (bool, error) { - exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(c.CTX) +func (c *Client) IsMachineRegistered(ctx context.Context, machineID string) (bool, error) { + exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(ctx) if err != nil { return false, err } @@ -257,11 +257,11 @@ func (c *Client) IsMachineRegistered(machineID string) (bool, error) { return false, nil } -func (c *Client) QueryMachinesInactiveSince(t time.Time) ([]*ent.Machine, error) { +func (c *Client) QueryMachinesInactiveSince(ctx context.Context, t time.Time) ([]*ent.Machine, error) { return c.Ent.Machine.Query().Where( machine.Or( machine.And(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)), machine.And(machine.LastHeartbeatIsNil(), machine.CreatedAtLT(t)), ), - ).All(c.CTX) + ).All(ctx) } From b3109defc9859b20fd7a8048e6ce0043a09eeaf9 Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 13:04:58 +0200 Subject: [PATCH 18/28] context propagation: pkg/database/bouncers --- cmd/crowdsec-cli/clibouncer/bouncers.go | 26 ++++++++-------- pkg/apiserver/apiserver_test.go | 2 +- pkg/apiserver/controllers/v1/decisions.go | 8 +++-- pkg/apiserver/middlewares/v1/api_key.go | 18 ++++++++---- pkg/apiserver/usage_metrics_test.go | 2 +- pkg/database/bouncers.go | 36 +++++++++++------------ 6 files changed, 51 insertions(+), 41 deletions(-) diff --git a/cmd/crowdsec-cli/clibouncer/bouncers.go b/cmd/crowdsec-cli/clibouncer/bouncers.go index 89e91b63911..226fbb7e922 100644 --- a/cmd/crowdsec-cli/clibouncer/bouncers.go +++ b/cmd/crowdsec-cli/clibouncer/bouncers.go @@ -208,7 +208,7 @@ func (cli *cliBouncers) newListCmd() *cobra.Command { return cmd } -func (cli *cliBouncers) add(bouncerName string, key string) error { +func (cli *cliBouncers) add(ctx context.Context, bouncerName string, key string) error { var err error keyLength := 32 @@ -220,7 +220,7 @@ func (cli *cliBouncers) add(bouncerName string, key string) error { } } - _, err = cli.db.CreateBouncer(bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType) + _, err = cli.db.CreateBouncer(ctx, bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType) if err != nil { return fmt.Errorf("unable to create bouncer: %w", err) } @@ -254,8 +254,8 @@ func (cli *cliBouncers) newAddCmd() *cobra.Command { cscli bouncers add MyBouncerName --key `, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - return cli.add(args[0], key) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.add(cmd.Context(), args[0], key) }, } @@ -304,9 +304,9 @@ func (cli *cliBouncers) validBouncerID(cmd *cobra.Command, args []string, toComp return ret, cobra.ShellCompDirectiveNoFileComp } -func (cli *cliBouncers) delete(bouncers []string, ignoreMissing bool) error { +func (cli *cliBouncers) delete(ctx context.Context, bouncers []string, ignoreMissing bool) error { for _, bouncerID := range bouncers { - if err := cli.db.DeleteBouncer(bouncerID); err != nil { + if err := cli.db.DeleteBouncer(ctx, bouncerID); err != nil { var notFoundErr *database.BouncerNotFoundError if ignoreMissing && errors.As(err, ¬FoundErr) { return nil @@ -332,8 +332,8 @@ func (cli *cliBouncers) newDeleteCmd() *cobra.Command { Aliases: []string{"remove"}, DisableAutoGenTag: true, ValidArgsFunction: cli.validBouncerID, - RunE: func(_ *cobra.Command, args []string) error { - return cli.delete(args, ignoreMissing) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.delete(cmd.Context(), args, ignoreMissing) }, } @@ -343,7 +343,7 @@ func (cli *cliBouncers) newDeleteCmd() *cobra.Command { return cmd } -func (cli *cliBouncers) prune(duration time.Duration, force bool) error { +func (cli *cliBouncers) prune(ctx context.Context, duration time.Duration, force bool) error { if duration < 2*time.Minute { if yes, err := ask.YesNo( "The duration you provided is less than 2 minutes. "+ @@ -355,7 +355,7 @@ func (cli *cliBouncers) prune(duration time.Duration, force bool) error { } } - bouncers, err := cli.db.QueryBouncersInactiveSince(time.Now().UTC().Add(-duration)) + bouncers, err := cli.db.QueryBouncersInactiveSince(ctx, time.Now().UTC().Add(-duration)) if err != nil { return fmt.Errorf("unable to query bouncers: %w", err) } @@ -378,7 +378,7 @@ func (cli *cliBouncers) prune(duration time.Duration, force bool) error { } } - deleted, err := cli.db.BulkDeleteBouncers(bouncers) + deleted, err := cli.db.BulkDeleteBouncers(ctx, bouncers) if err != nil { return fmt.Errorf("unable to prune bouncers: %w", err) } @@ -403,8 +403,8 @@ func (cli *cliBouncers) newPruneCmd() *cobra.Command { DisableAutoGenTag: true, Example: `cscli bouncers prune -d 45m cscli bouncers prune -d 45m --force`, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.prune(duration, force) + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.prune(cmd.Context(), duration, force) }, } diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 9e12d27cb36..2e349a2344b 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -297,7 +297,7 @@ func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string { apiKey, err := middlewares.GenerateAPIKey(keyLength) require.NoError(t, err) - _, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) + _, err = dbClient.CreateBouncer(ctx, "test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) require.NoError(t, err) return apiKey diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index 54e9b0290cc..139280ab497 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -43,6 +43,8 @@ func (c *Controller) GetDecision(gctx *gin.Context) { data []*ent.Decision ) + ctx := gctx.Request.Context() + bouncerInfo, err := getBouncerFromContext(gctx) if err != nil { gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) @@ -73,7 +75,7 @@ func (c *Controller) GetDecision(gctx *gin.Context) { } if bouncerInfo.LastPull == nil || time.Now().UTC().Sub(*bouncerInfo.LastPull) >= time.Minute { - if err := c.DBClient.UpdateBouncerLastPull(time.Now().UTC(), bouncerInfo.ID); err != nil { + if err := c.DBClient.UpdateBouncerLastPull(ctx, time.Now().UTC(), bouncerInfo.ID); err != nil { log.Errorf("failed to update bouncer last pull: %v", err) } } @@ -370,6 +372,8 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en func (c *Controller) StreamDecision(gctx *gin.Context) { var err error + ctx := gctx.Request.Context() + streamStartTime := time.Now().UTC() bouncerInfo, err := getBouncerFromContext(gctx) @@ -400,7 +404,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { if err == nil { //Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions - if err := c.DBClient.UpdateBouncerLastPull(streamStartTime, bouncerInfo.ID); err != nil { + if err := c.DBClient.UpdateBouncerLastPull(ctx, streamStartTime, bouncerInfo.ID); err != nil { log.Errorf("unable to update bouncer '%s' pull: %v", bouncerInfo.Name, err) } } diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index e822666db0f..d438c9b15a4 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -64,6 +64,8 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { return nil } + ctx := c.Request.Context() + extractedCN, err := a.TlsAuth.ValidateCert(c) if err != nil { logger.Warn(err) @@ -73,7 +75,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger = logger.WithField("cn", extractedCN) bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) - bouncer, err := a.DbClient.SelectBouncerByName(bouncerName) + bouncer, err := a.DbClient.SelectBouncerByName(ctx, bouncerName) // This is likely not the proper way, but isNotFound does not seem to work if err != nil && strings.Contains(err.Error(), "bouncer not found") { @@ -87,7 +89,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Infof("Creating bouncer %s", bouncerName) - bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) + bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) if err != nil { logger.Errorf("while creating bouncer db entry: %s", err) return nil @@ -112,9 +114,11 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer { return nil } + ctx := c.Request.Context() + hashStr := HashSHA512(val[0]) - bouncer, err := a.DbClient.SelectBouncer(hashStr) + bouncer, err := a.DbClient.SelectBouncer(ctx, hashStr) if err != nil { logger.Errorf("while fetching bouncer info: %s", err) return nil @@ -132,6 +136,8 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { return func(c *gin.Context) { var bouncer *ent.Bouncer + ctx := c.Request.Context() + clientIP := c.ClientIP() logger := log.WithField("ip", clientIP) @@ -153,7 +159,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { logger = logger.WithField("name", bouncer.Name) if bouncer.IPAddress == "" { - if err := a.DbClient.UpdateBouncerIP(clientIP, bouncer.ID); err != nil { + if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() @@ -166,7 +172,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { if bouncer.IPAddress != clientIP && bouncer.IPAddress != "" && c.Request.Method != http.MethodHead { log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, clientIP, bouncer.IPAddress) - if err := a.DbClient.UpdateBouncerIP(clientIP, bouncer.ID); err != nil { + if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() @@ -182,7 +188,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { } if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] { - if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil { + if err := a.DbClient.UpdateBouncerTypeAndVersion(ctx, useragent[0], useragent[1], bouncer.ID); err != nil { logger.Errorf("failed to update bouncer version and type: %s", err) c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) c.Abort() diff --git a/pkg/apiserver/usage_metrics_test.go b/pkg/apiserver/usage_metrics_test.go index b335738ea16..9c86cc63414 100644 --- a/pkg/apiserver/usage_metrics_test.go +++ b/pkg/apiserver/usage_metrics_test.go @@ -371,7 +371,7 @@ func TestRCMetrics(t *testing.T) { assert.Equal(t, tt.expectedStatusCode, w.Code) assert.Contains(t, w.Body.String(), tt.expectedResponse) - bouncer, _ := dbClient.SelectBouncerByName("test") + bouncer, _ := dbClient.SelectBouncerByName(ctx, "test") metrics, _ := dbClient.GetBouncerUsageMetricsByName(ctx, "test") assert.Len(t, metrics, tt.expectedMetricsCount) diff --git a/pkg/database/bouncers.go b/pkg/database/bouncers.go index 6ff308ff786..04ef830ae72 100644 --- a/pkg/database/bouncers.go +++ b/pkg/database/bouncers.go @@ -41,8 +41,8 @@ func (c *Client) BouncerUpdateBaseMetrics(ctx context.Context, bouncerName strin return nil } -func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(c.CTX) +func (c *Client) SelectBouncer(ctx context.Context, apiKeyHash string) (*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(ctx) if err != nil { return nil, err } @@ -50,8 +50,8 @@ func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) { return result, nil } -func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(c.CTX) +func (c *Client) SelectBouncerByName(ctx context.Context, bouncerName string) (*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(ctx) if err != nil { return nil, err } @@ -68,14 +68,14 @@ func (c *Client) ListBouncers(ctx context.Context) ([]*ent.Bouncer, error) { return result, nil } -func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) { +func (c *Client) CreateBouncer(ctx context.Context, name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) { bouncer, err := c.Ent.Bouncer. Create(). SetName(name). SetAPIKey(apiKey). SetRevoked(false). SetAuthType(authType). - Save(c.CTX) + Save(ctx) if err != nil { if ent.IsConstraintError(err) { return nil, fmt.Errorf("bouncer %s already exists", name) @@ -87,11 +87,11 @@ func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string, authTy return bouncer, nil } -func (c *Client) DeleteBouncer(name string) error { +func (c *Client) DeleteBouncer(ctx context.Context, name string) error { nbDeleted, err := c.Ent.Bouncer. Delete(). Where(bouncer.NameEQ(name)). - Exec(c.CTX) + Exec(ctx) if err != nil { return err } @@ -103,13 +103,13 @@ func (c *Client) DeleteBouncer(name string) error { return nil } -func (c *Client) BulkDeleteBouncers(bouncers []*ent.Bouncer) (int, error) { +func (c *Client) BulkDeleteBouncers(ctx context.Context, bouncers []*ent.Bouncer) (int, error) { ids := make([]int, len(bouncers)) for i, b := range bouncers { ids[i] = b.ID } - nbDeleted, err := c.Ent.Bouncer.Delete().Where(bouncer.IDIn(ids...)).Exec(c.CTX) + nbDeleted, err := c.Ent.Bouncer.Delete().Where(bouncer.IDIn(ids...)).Exec(ctx) if err != nil { return nbDeleted, fmt.Errorf("unable to delete bouncers: %w", err) } @@ -117,10 +117,10 @@ func (c *Client) BulkDeleteBouncers(bouncers []*ent.Bouncer) (int, error) { return nbDeleted, nil } -func (c *Client) UpdateBouncerLastPull(lastPull time.Time, id int) error { +func (c *Client) UpdateBouncerLastPull(ctx context.Context, lastPull time.Time, id int) error { _, err := c.Ent.Bouncer.UpdateOneID(id). SetLastPull(lastPull). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine last pull in database: %w", err) } @@ -128,8 +128,8 @@ func (c *Client) UpdateBouncerLastPull(lastPull time.Time, id int) error { return nil } -func (c *Client) UpdateBouncerIP(ipAddr string, id int) error { - _, err := c.Ent.Bouncer.UpdateOneID(id).SetIPAddress(ipAddr).Save(c.CTX) +func (c *Client) UpdateBouncerIP(ctx context.Context, ipAddr string, id int) error { + _, err := c.Ent.Bouncer.UpdateOneID(id).SetIPAddress(ipAddr).Save(ctx) if err != nil { return fmt.Errorf("unable to update bouncer ip address in database: %w", err) } @@ -137,8 +137,8 @@ func (c *Client) UpdateBouncerIP(ipAddr string, id int) error { return nil } -func (c *Client) UpdateBouncerTypeAndVersion(bType string, version string, id int) error { - _, err := c.Ent.Bouncer.UpdateOneID(id).SetVersion(version).SetType(bType).Save(c.CTX) +func (c *Client) UpdateBouncerTypeAndVersion(ctx context.Context, bType string, version string, id int) error { + _, err := c.Ent.Bouncer.UpdateOneID(id).SetVersion(version).SetType(bType).Save(ctx) if err != nil { return fmt.Errorf("unable to update bouncer type and version in database: %w", err) } @@ -146,7 +146,7 @@ func (c *Client) UpdateBouncerTypeAndVersion(bType string, version string, id in return nil } -func (c *Client) QueryBouncersInactiveSince(t time.Time) ([]*ent.Bouncer, error) { +func (c *Client) QueryBouncersInactiveSince(ctx context.Context, t time.Time) ([]*ent.Bouncer, error) { return c.Ent.Bouncer.Query().Where( // poor man's coalesce bouncer.Or( @@ -156,5 +156,5 @@ func (c *Client) QueryBouncersInactiveSince(t time.Time) ([]*ent.Bouncer, error) bouncer.CreatedAtLT(t), ), ), - ).All(c.CTX) + ).All(ctx) } From 0c4b2fbe3d837eaf1aab6fa00e43149e04df06ac Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 13:11:54 +0200 Subject: [PATCH 19/28] context propagation: pkg/database/lock --- pkg/apiserver/apic.go | 4 ++-- pkg/database/lock.go | 23 ++++++++++++----------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index c79d5f88e3f..7b1b8981171 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -640,7 +640,7 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error { log.Debug("Acquiring lock for pullCAPI") - err = a.dbClient.AcquirePullCAPILock() + err = a.dbClient.AcquirePullCAPILock(ctx) if a.dbClient.IsLocked(err) { log.Info("PullCAPI is already running, skipping") return nil @@ -650,7 +650,7 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error { defer func() { log.Debug("Releasing lock for pullCAPI") - if err := a.dbClient.ReleasePullCAPILock(); err != nil { + if err := a.dbClient.ReleasePullCAPILock(ctx); err != nil { log.Errorf("while releasing lock: %v", err) } }() diff --git a/pkg/database/lock.go b/pkg/database/lock.go index d25b71870f0..e928c10da43 100644 --- a/pkg/database/lock.go +++ b/pkg/database/lock.go @@ -1,6 +1,7 @@ package database import ( + "context" "time" "github.com/pkg/errors" @@ -16,12 +17,12 @@ const ( CapiPullLockName = "pullCAPI" ) -func (c *Client) AcquireLock(name string) error { +func (c *Client) AcquireLock(ctx context.Context, name string) error { log.Debugf("acquiring lock %s", name) _, err := c.Ent.Lock.Create(). SetName(name). SetCreatedAt(types.UtcNow()). - Save(c.CTX) + Save(ctx) if ent.IsConstraintError(err) { return err } @@ -31,21 +32,21 @@ func (c *Client) AcquireLock(name string) error { return nil } -func (c *Client) ReleaseLock(name string) error { +func (c *Client) ReleaseLock(ctx context.Context, name string) error { log.Debugf("releasing lock %s", name) - _, err := c.Ent.Lock.Delete().Where(lock.NameEQ(name)).Exec(c.CTX) + _, err := c.Ent.Lock.Delete().Where(lock.NameEQ(name)).Exec(ctx) if err != nil { return errors.Wrapf(DeleteFail, "delete lock: %s", err) } return nil } -func (c *Client) ReleaseLockWithTimeout(name string, timeout int) error { +func (c *Client) ReleaseLockWithTimeout(ctx context.Context, name string, timeout int) error { log.Debugf("releasing lock %s with timeout of %d minutes", name, timeout) _, err := c.Ent.Lock.Delete().Where( lock.NameEQ(name), lock.CreatedAtLT(time.Now().UTC().Add(-time.Duration(timeout)*time.Minute)), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { return errors.Wrapf(DeleteFail, "delete lock: %s", err) @@ -57,21 +58,21 @@ func (c *Client) IsLocked(err error) bool { return ent.IsConstraintError(err) } -func (c *Client) AcquirePullCAPILock() error { +func (c *Client) AcquirePullCAPILock(ctx context.Context) error { /*delete orphan "old" lock if present*/ - err := c.ReleaseLockWithTimeout(CapiPullLockName, CAPIPullLockTimeout) + err := c.ReleaseLockWithTimeout(ctx, CapiPullLockName, CAPIPullLockTimeout) if err != nil { log.Errorf("unable to release pullCAPI lock: %s", err) } - return c.AcquireLock(CapiPullLockName) + return c.AcquireLock(ctx, CapiPullLockName) } -func (c *Client) ReleasePullCAPILock() error { +func (c *Client) ReleasePullCAPILock(ctx context.Context) error { log.Debugf("deleting lock %s", CapiPullLockName) _, err := c.Ent.Lock.Delete().Where( lock.NameEQ(CapiPullLockName), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { return errors.Wrapf(DeleteFail, "delete lock: %s", err) } From 0b5eaeffd5626d6571134077ca025fae5474e83c Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 13:18:27 +0200 Subject: [PATCH 20/28] QueryAllDecisionsWithFilters(ctx...), QueryExpiredDecisionsWithFilters(ctx...) --- pkg/apiserver/controllers/v1/decisions.go | 13 +++++++++---- pkg/database/decisions.go | 8 ++++---- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index 139280ab497..4821fd98871 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -1,6 +1,7 @@ package v1 import ( + "context" "encoding/json" "fmt" "net/http" @@ -136,12 +137,14 @@ func (c *Controller) DeleteDecisions(gctx *gin.Context) { gctx.JSON(http.StatusOK, deleteDecisionResp) } -func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFunc func(map[string][]string) ([]*ent.Decision, error)) error { +func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFunc func(context.Context, map[string][]string) ([]*ent.Decision, error)) error { // respBuffer := bytes.NewBuffer([]byte{}) limit := 30000 //FIXME : make it configurable needComma := false lastId := 0 + ctx := gctx.Request.Context() + limitStr := fmt.Sprintf("%d", limit) filters["limit"] = []string{limitStr} for { @@ -150,7 +153,7 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun filters["id_gt"] = []string{lastIdStr} } - data, err := dbFunc(filters) + data, err := dbFunc(ctx, filters) if err != nil { return err } @@ -305,13 +308,15 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en var data []*ent.Decision var err error + ctx := gctx.Request.Context() + ret := make(map[string][]*models.Decision, 0) ret["new"] = []*models.Decision{} ret["deleted"] = []*models.Decision{} if val, ok := gctx.Request.URL.Query()["startup"]; ok { if val[0] == "true" { - data, err = c.DBClient.QueryAllDecisionsWithFilters(filters) + data, err = c.DBClient.QueryAllDecisionsWithFilters(ctx, filters) if err != nil { log.Errorf("failed querying decisions: %v", err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) @@ -322,7 +327,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en ret["new"] = FormatDecisions(data) // getting expired decisions - data, err = c.DBClient.QueryExpiredDecisionsWithFilters(filters) + data, err = c.DBClient.QueryExpiredDecisionsWithFilters(ctx, filters) if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) diff --git a/pkg/database/decisions.go b/pkg/database/decisions.go index 5fd4757c883..a574e86ab76 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -121,7 +121,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] return query, nil } -func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryAllDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilGT(time.Now().UTC()), ) @@ -138,7 +138,7 @@ func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*e query = query.Order(ent.Asc(decision.FieldID)) - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryAllDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions with filters") @@ -147,7 +147,7 @@ func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*e return data, nil } -func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryExpiredDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilLT(time.Now().UTC()), ) @@ -165,7 +165,7 @@ func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ( return []*ent.Decision{}, errors.Wrap(QueryFail, "get expired decisions with filters") } - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions") From f071d609dd64c09142269bb844a77ecb86a0a88b Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 13:26:55 +0200 Subject: [PATCH 21/28] more Query...Decision...(ctx..) --- pkg/apiserver/controllers/v1/decisions.go | 12 +++++++----- pkg/database/decisions.go | 12 ++++++------ 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index 4821fd98871..979fa330c4b 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -53,7 +53,7 @@ func (c *Controller) GetDecision(gctx *gin.Context) { return } - data, err = c.DBClient.QueryDecisionWithFilter(gctx.Request.URL.Query()) + data, err = c.DBClient.QueryDecisionWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) @@ -191,12 +191,14 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun return nil } -func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPull *time.Time, dbFunc func(*time.Time, map[string][]string) ([]*ent.Decision, error)) error { +func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPull *time.Time, dbFunc func(context.Context, *time.Time, map[string][]string) ([]*ent.Decision, error)) error { //respBuffer := bytes.NewBuffer([]byte{}) limit := 30000 //FIXME : make it configurable needComma := false lastId := 0 + ctx := gctx.Request.Context() + limitStr := fmt.Sprintf("%d", limit) filters["limit"] = []string{limitStr} for { @@ -205,7 +207,7 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul filters["id_gt"] = []string{lastIdStr} } - data, err := dbFunc(lastPull, filters) + data, err := dbFunc(ctx, lastPull, filters) if err != nil { return err } @@ -344,7 +346,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en } // getting new decisions - data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(bouncerInfo.LastPull, filters) + data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(ctx, bouncerInfo.LastPull, filters) if err != nil { log.Errorf("unable to query new decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) @@ -360,7 +362,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en } // getting expired decisions - data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(&since, filters) // do we want to give exactly lastPull time ? + data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(ctx, &since, filters) // do we want to give exactly lastPull time ? if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) diff --git a/pkg/database/decisions.go b/pkg/database/decisions.go index a574e86ab76..ebf6cffdec2 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -196,7 +196,7 @@ func (c *Client) QueryDecisionCountByScenario(ctx context.Context) ([]*Decisions return r, nil } -func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryDecisionWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Decision, error) { var data []*ent.Decision var err error @@ -218,7 +218,7 @@ func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Dec decision.FieldValue, decision.FieldScope, decision.FieldOrigin, - ).Scan(c.CTX, &data) + ).Scan(ctx, &data) if err != nil { c.Log.Warningf("QueryDecisionWithFilter : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "query decision failed") @@ -255,7 +255,7 @@ func longestDecisionForScopeTypeValue(s *sql.Selector) { ) } -func (c *Client) QueryExpiredDecisionsSinceWithFilters(since *time.Time, filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryExpiredDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilLT(time.Now().UTC()), ) @@ -277,7 +277,7 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since *time.Time, filters query = query.Order(ent.Asc(decision.FieldID)) - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters") @@ -286,7 +286,7 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since *time.Time, filters return data, nil } -func (c *Client) QueryNewDecisionsSinceWithFilters(since *time.Time, filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryNewDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilGT(time.Now().UTC()), ) @@ -308,7 +308,7 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(since *time.Time, filters map query = query.Order(ent.Asc(decision.FieldID)) - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryNewDecisionsSinceWithFilters : %s", err) return []*ent.Decision{}, errors.Wrapf(QueryFail, "new decisions since '%s'", since.String()) From 19f4bbf1bdad5efc0951571499e03a97f01264a8 Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 14:00:27 +0200 Subject: [PATCH 22/28] alerts wip --- pkg/apiserver/papi_cmd.go | 4 +++- pkg/database/alerts.go | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index 943eb4139de..04fd55b55c2 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -95,6 +95,8 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error { } func AlertCmd(message *Message, p *Papi, sync bool) error { + ctx := context.TODO() + switch message.Header.OperationCmd { case "add": data, err := json.Marshal(message.Data) @@ -153,7 +155,7 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { } // use a different method: alert and/or decision might already be partially present in the database - _, err = p.DBClient.CreateOrUpdateAlert("", alert) + _, err = p.DBClient.CreateOrUpdateAlert(ctx, "", alert) if err != nil { log.Errorf("Failed to create alerts in DB: %s", err) } else { diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index d2760a209f9..78fd8be7964 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -35,12 +35,12 @@ const ( // CreateOrUpdateAlert is specific to PAPI : It checks if alert already exists, otherwise inserts it // if alert already exists, it checks it associated decisions already exists // if some associated decisions are missing (ie. previous insert ended up in error) it inserts them -func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) (string, error) { +func (c *Client) CreateOrUpdateAlert(ctx context.Context, machineID string, alertItem *models.Alert) (string, error) { if alertItem.UUID == "" { return "", errors.New("alert UUID is empty") } - alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(c.CTX) + alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(ctx) if err != nil && !ent.IsNotFound(err) { return "", fmt.Errorf("unable to query alerts for uuid %s: %w", alertItem.UUID, err) @@ -165,7 +165,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) for _, builderChunk := range builderChunks { - decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(c.CTX) + decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(ctx) if err != nil { return "", fmt.Errorf("creating alert decisions: %w", err) } @@ -178,7 +178,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) decisionChunks := slicetools.Chunks(decisions, c.decisionBulkSize) for _, decisionChunk := range decisionChunks { - err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(c.CTX) + err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(ctx) if err != nil { return "", fmt.Errorf("updating alert %s: %w", alertItem.UUID, err) } From bc12438fe57e68822f549ea6ae586b6c6ac5320d Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 14:04:21 +0200 Subject: [PATCH 23/28] alerts wip/2 --- pkg/apiserver/apic.go | 4 +++- pkg/database/alerts.go | 10 +++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 7b1b8981171..e79224c417e 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -766,6 +766,8 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decis } func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error { + ctx := context.TODO() + for _, alert := range alertsFromCapi { setAlertScenario(alert, addCounters, deleteCounters) log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions)) @@ -774,7 +776,7 @@ func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist") } - alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alert) + alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(ctx, alert) if err != nil { return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err) } diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index 78fd8be7964..bff50690007 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -191,7 +191,7 @@ func (c *Client) CreateOrUpdateAlert(ctx context.Context, machineID string, aler // it takes care of creating the new alert with the associated decisions, and it will as well deleted the "older" overlapping decisions: // 1st pull, you get decisions [1,2,3]. it inserts [1,2,3] // 2nd pull, you get decisions [1,2,3,4]. it inserts [1,2,3,4] and will try to delete [1,2,3,4] with a different alert ID and same origin -func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, int, error) { +func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models.Alert) (int, int, int, error) { if alertItem == nil { return 0, 0, 0, errors.New("nil alert") } @@ -244,7 +244,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in SetScenarioHash(*alertItem.ScenarioHash). SetRemediation(true) // it's from CAPI, we always have decisions - alertRef, err := alertB.Save(c.CTX) + alertRef, err := alertB.Save(ctx) if err != nil { return 0, 0, 0, errors.Wrapf(BulkError, "error creating alert : %s", err) } @@ -253,7 +253,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in return alertRef.ID, 0, 0, nil } - txClient, err := c.Ent.Tx(c.CTX) + txClient, err := c.Ent.Tx(ctx) if err != nil { return 0, 0, 0, errors.Wrapf(BulkError, "error creating transaction : %s", err) } @@ -347,7 +347,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in decision.OriginEQ(DecOrigin), decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))), decision.ValueIn(deleteChunk...), - )).Exec(c.CTX) + )).Exec(ctx) if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { @@ -363,7 +363,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) for _, builderChunk := range builderChunks { - insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(c.CTX) + insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(ctx) if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { From 6fe957ce4e603bc2e58420677b99a92f6f782cba Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 14:06:34 +0200 Subject: [PATCH 24/28] alerts wip/3 --- pkg/database/alerts.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index bff50690007..ece5e31fade 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -391,7 +391,7 @@ func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models return alertRef.ID, inserted, deleted, nil } -func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) { +func (c *Client) createDecisionChunk(ctx context.Context, simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) { decisionCreate := []*ent.DecisionCreate{} for _, decisionItem := range decisions { @@ -436,7 +436,7 @@ func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decis return nil, nil } - ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(c.CTX) + ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(ctx) if err != nil { return nil, err } @@ -444,7 +444,7 @@ func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decis return ret, nil } -func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) { +func (c *Client) createAlertChunk(ctx context.Context, machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) { alertBuilders := []*ent.AlertCreate{} alertDecisions := [][]*ent.Decision{} @@ -540,7 +540,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ c.Log.Warningf("dropped 'serialized' field (machine %s / scenario %s)", machineID, *alertItem.Scenario) } - events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(c.CTX) + events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(ctx) if err != nil { return nil, errors.Wrapf(BulkError, "creating alert events: %s", err) } @@ -568,7 +568,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ SetValue(value) } - metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(c.CTX) + metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(ctx) if err != nil { c.Log.Warningf("error creating alert meta: %s", err) } @@ -578,7 +578,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ decisionChunks := slicetools.Chunks(alertItem.Decisions, c.decisionBulkSize) for _, decisionChunk := range decisionChunks { - decisionRet, err := c.createDecisionChunk(*alertItem.Simulated, stopAtTime, decisionChunk) + decisionRet, err := c.createDecisionChunk(ctx, *alertItem.Simulated, stopAtTime, decisionChunk) if err != nil { return nil, fmt.Errorf("creating alert decisions: %w", err) } @@ -636,7 +636,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ return nil, nil } - alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(c.CTX) + alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(ctx) if err != nil { return nil, errors.Wrapf(BulkError, "bulk creating alert : %s", err) } @@ -653,7 +653,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ for retry < maxLockRetries { // so much for the happy path... but sqlite3 errors work differently - _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(c.CTX) + _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(ctx) if err == nil { break } @@ -708,7 +708,7 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str alertIDs := []string{} for _, alertChunk := range alertChunks { - ids, err := c.createAlertChunk(machineID, owner, alertChunk) + ids, err := c.createAlertChunk(ctx, machineID, owner, alertChunk) if err != nil { return nil, fmt.Errorf("machine '%s': %w", machineID, err) } @@ -717,7 +717,7 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str } if owner != nil { - err = owner.Update().SetLastPush(time.Now().UTC()).Exec(c.CTX) + err = owner.Update().SetLastPush(time.Now().UTC()).Exec(ctx) if err != nil { return nil, fmt.Errorf("machine '%s': %w", machineID, err) } From 974ac598aa635c2198aeacc1e740696dec054787 Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 14:10:15 +0200 Subject: [PATCH 25/28] alerts wip/4 --- pkg/apiserver/controllers/v1/alerts.go | 4 +++- pkg/database/alerts.go | 20 ++++++++++---------- pkg/database/flush.go | 4 ++-- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index 84b3094865c..fd81614cec2 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -261,7 +261,9 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { // FindAlerts: returns alerts from the database based on the specified filter func (c *Controller) FindAlerts(gctx *gin.Context) { - result, err := c.DBClient.QueryAlertWithFilter(gctx.Request.URL.Query()) + ctx := gctx.Request.Context() + + result, err := c.DBClient.QueryAlertWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) return diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index ece5e31fade..aa5f5d044b0 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -996,11 +996,11 @@ func (c *Client) AlertsCountPerScenario(ctx context.Context, filters map[string] return counts, nil } -func (c *Client) TotalAlerts() (int, error) { - return c.Ent.Alert.Query().Count(c.CTX) +func (c *Client) TotalAlerts(ctx context.Context) (int, error) { + return c.Ent.Alert.Query().Count(ctx) } -func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, error) { +func (c *Client) QueryAlertWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Alert, error) { sort := "DESC" // we sort by desc by default if val, ok := filter["sort"]; ok { @@ -1047,7 +1047,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, WithOwner() if limit == 0 { - limit, err = alerts.Count(c.CTX) + limit, err = alerts.Count(ctx) if err != nil { return nil, fmt.Errorf("unable to count nb alerts: %w", err) } @@ -1059,7 +1059,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, alerts = alerts.Order(ent.Desc(alert.FieldCreatedAt), ent.Desc(alert.FieldID)) } - result, err := alerts.Limit(paginationSize).Offset(offset).All(c.CTX) + result, err := alerts.Limit(paginationSize).Offset(offset).All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "pagination size: %d, offset: %d: %s", paginationSize, offset, err) } @@ -1088,35 +1088,35 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, return ret, nil } -func (c *Client) DeleteAlertGraphBatch(alertItems []*ent.Alert) (int, error) { +func (c *Client) DeleteAlertGraphBatch(ctx context.Context, alertItems []*ent.Alert) (int, error) { idList := make([]int, 0) for _, alert := range alertItems { idList = append(idList, alert.ID) } _, err := c.Ent.Event.Delete(). - Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch events") } _, err = c.Ent.Meta.Delete(). - Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch meta") } _, err = c.Ent.Decision.Delete(). - Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch decisions") } deleted, err := c.Ent.Alert.Delete(). - Where(alert.IDIn(idList...)).Exec(c.CTX) + Where(alert.IDIn(idList...)).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return deleted, errors.Wrapf(DeleteFail, "alert graph delete batch") diff --git a/pkg/database/flush.go b/pkg/database/flush.go index 46c8edfa308..2ec88ce34b8 100644 --- a/pkg/database/flush.go +++ b/pkg/database/flush.go @@ -239,7 +239,7 @@ func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) e c.FlushOrphans(ctx) c.Log.Debug("Done flushing orphan alerts") - totalAlerts, err = c.TotalAlerts() + totalAlerts, err = c.TotalAlerts(ctx) if err != nil { c.Log.Warningf("FlushAlerts (max items count): %s", err) return fmt.Errorf("unable to get alerts count: %w", err) @@ -268,7 +268,7 @@ func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) e // This gives us the oldest alert that we want to keep // We then delete all the alerts with an id lower than this one // We can do this because the id is auto-increment, and the database won't reuse the same id twice - lastAlert, err := c.QueryAlertWithFilter(map[string][]string{ + lastAlert, err := c.QueryAlertWithFilter(ctx, map[string][]string{ "sort": {"DESC"}, "limit": {"1"}, // we do not care about fetching the edges, we just want the id From af621ab98aea4436969756a75ac228393defe69e Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 14:14:05 +0200 Subject: [PATCH 26/28] alerts wip/5 --- pkg/apiserver/controllers/v1/alerts.go | 11 ++++++++--- pkg/database/alerts.go | 24 ++++++++++++------------ pkg/database/flush.go | 2 +- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index fd81614cec2..19d37879f22 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -281,6 +281,7 @@ func (c *Controller) FindAlerts(gctx *gin.Context) { // FindAlertByID returns the alert associated with the ID func (c *Controller) FindAlertByID(gctx *gin.Context) { + ctx := gctx.Request.Context() alertIDStr := gctx.Param("alert_id") alertID, err := strconv.Atoi(alertIDStr) @@ -289,7 +290,7 @@ func (c *Controller) FindAlertByID(gctx *gin.Context) { return } - result, err := c.DBClient.GetAlertByID(alertID) + result, err := c.DBClient.GetAlertByID(ctx, alertID) if err != nil { c.HandleDBErrors(gctx, err) return @@ -309,6 +310,8 @@ func (c *Controller) FindAlertByID(gctx *gin.Context) { func (c *Controller) DeleteAlertByID(gctx *gin.Context) { var err error + ctx := gctx.Request.Context() + incomingIP := gctx.ClientIP() if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) { gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) @@ -323,7 +326,7 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) { return } - err = c.DBClient.DeleteAlertByID(decisionID) + err = c.DBClient.DeleteAlertByID(ctx, decisionID) if err != nil { c.HandleDBErrors(gctx, err) return @@ -336,13 +339,15 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) { // DeleteAlerts deletes alerts from the database based on the specified filter func (c *Controller) DeleteAlerts(gctx *gin.Context) { + ctx := gctx.Request.Context() + incomingIP := gctx.ClientIP() if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) { gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) return } - nbDeleted, err := c.DBClient.DeleteAlertWithFilter(gctx.Request.URL.Query()) + nbDeleted, err := c.DBClient.DeleteAlertWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) return diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index aa5f5d044b0..406b64221d7 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -1127,10 +1127,10 @@ func (c *Client) DeleteAlertGraphBatch(ctx context.Context, alertItems []*ent.Al return deleted, nil } -func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { +func (c *Client) DeleteAlertGraph(ctx context.Context, alertItem *ent.Alert) error { // delete the associated events _, err := c.Ent.Event.Delete(). - Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "event with alert ID '%d'", alertItem.ID) @@ -1138,7 +1138,7 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { // delete the associated meta _, err = c.Ent.Meta.Delete(). - Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "meta with alert ID '%d'", alertItem.ID) @@ -1146,14 +1146,14 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { // delete the associated decisions _, err = c.Ent.Decision.Delete(). - Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "decision with alert ID '%d'", alertItem.ID) } // delete the alert - err = c.Ent.Alert.DeleteOne(alertItem).Exec(c.CTX) + err = c.Ent.Alert.DeleteOne(alertItem).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "alert with ID '%d'", alertItem.ID) @@ -1162,26 +1162,26 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { return nil } -func (c *Client) DeleteAlertByID(id int) error { - alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(c.CTX) +func (c *Client) DeleteAlertByID(ctx context.Context, id int) error { + alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(ctx) if err != nil { return err } - return c.DeleteAlertGraph(alertItem) + return c.DeleteAlertGraph(ctx, alertItem) } -func (c *Client) DeleteAlertWithFilter(filter map[string][]string) (int, error) { +func (c *Client) DeleteAlertWithFilter(ctx context.Context, filter map[string][]string) (int, error) { preds, err := AlertPredicatesFromFilter(filter) if err != nil { return 0, err } - return c.Ent.Alert.Delete().Where(preds...).Exec(c.CTX) + return c.Ent.Alert.Delete().Where(preds...).Exec(ctx) } -func (c *Client) GetAlertByID(alertID int) (*ent.Alert, error) { - alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(c.CTX) +func (c *Client) GetAlertByID(ctx context.Context, alertID int) (*ent.Alert, error) { + alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(ctx) if err != nil { /*record not found, 404*/ if ent.IsNotFound(err) { diff --git a/pkg/database/flush.go b/pkg/database/flush.go index 2ec88ce34b8..8f646ddc961 100644 --- a/pkg/database/flush.go +++ b/pkg/database/flush.go @@ -252,7 +252,7 @@ func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) e "created_before": {MaxAge}, } - nbDeleted, err := c.DeleteAlertWithFilter(filter) + nbDeleted, err := c.DeleteAlertWithFilter(ctx, filter) if err != nil { c.Log.Warningf("FlushAlerts (max age): %s", err) return fmt.Errorf("unable to flush alerts with filter until=%s: %w", MaxAge, err) From 7cac084909472a7cf6636cf94b79e47cb22ae3f9 Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 14:26:03 +0200 Subject: [PATCH 27/28] rest of decisions --- pkg/apiserver/apic.go | 7 ++- pkg/apiserver/controllers/v1/decisions.go | 8 ++- pkg/apiserver/papi_cmd.go | 8 ++- pkg/database/decisions.go | 59 +++++++++-------------- pkg/exprhelpers/helpers.go | 17 +++++-- 5 files changed, 53 insertions(+), 46 deletions(-) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index e79224c417e..ab1cdfea917 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -426,6 +426,7 @@ func (a *apic) CAPIPullIsOld() (bool, error) { } func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, deleteCounters map[string]map[string]int) (int, error) { + ctx := context.TODO() nbDeleted := 0 for _, decision := range deletedDecisions { @@ -438,7 +439,7 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet filter["scopes"] = []string{*decision.Scope} } - dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(filter) + dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { return 0, fmt.Errorf("expiring decisions error: %w", err) } @@ -458,6 +459,8 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) { var nbDeleted int + ctx := context.TODO() + for _, decisions := range deletedDecisions { scope := decisions.Scope @@ -470,7 +473,7 @@ func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisi filter["scopes"] = []string{*scope} } - dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(filter) + dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { return 0, fmt.Errorf("expiring decisions error: %w", err) } diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index 979fa330c4b..09c40450642 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -94,7 +94,9 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) { return } - nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(decisionID) + ctx := gctx.Request.Context() + + nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(ctx, decisionID) if err != nil { c.HandleDBErrors(gctx, err) @@ -116,7 +118,9 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) { } func (c *Controller) DeleteDecisions(gctx *gin.Context) { - nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(gctx.Request.URL.Query()) + ctx := gctx.Request.Context() + + nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index 04fd55b55c2..78f5dc9b0fe 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -43,6 +43,8 @@ type listUnsubscribe struct { } func DecisionCmd(message *Message, p *Papi, sync bool) error { + ctx := context.TODO() + switch message.Header.OperationCmd { case "delete": data, err := json.Marshal(message.Data) @@ -65,7 +67,7 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error { filter := make(map[string][]string) filter["uuid"] = UUIDs - _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(filter) + _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { return fmt.Errorf("unable to expire decisions %+v: %w", UUIDs, err) } @@ -170,6 +172,8 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { } func ManagementCmd(message *Message, p *Papi, sync bool) error { + ctx := context.TODO() + if sync { p.Logger.Infof("Ignoring management command from PAPI in sync mode") return nil @@ -197,7 +201,7 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { filter["origin"] = []string{types.ListOrigin} filter["scenario"] = []string{unsubscribeMsg.Name} - _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(filter) + _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { return fmt.Errorf("unable to expire decisions for list %s : %w", unsubscribeMsg.Name, err) } diff --git a/pkg/database/decisions.go b/pkg/database/decisions.go index ebf6cffdec2..8547990c25f 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -317,20 +317,7 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(ctx context.Context, since *t return data, nil } -func (c *Client) DeleteDecisionById(decisionID int) ([]*ent.Decision, error) { - toDelete, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX) - if err != nil { - c.Log.Warningf("DeleteDecisionById : %s", err) - return nil, errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionID) - } - - count, err := c.DeleteDecisions(toDelete) - c.Log.Debugf("deleted %d decisions", count) - - return toDelete, err -} - -func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) { +func (c *Client) DeleteDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int @@ -433,13 +420,13 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) } - toDelete, err := decisions.All(c.CTX) + toDelete, err := decisions.All(ctx) if err != nil { c.Log.Warningf("DeleteDecisionsWithFilter : %s", err) return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter") } - count, err := c.DeleteDecisions(toDelete) + count, err := c.DeleteDecisions(ctx, toDelete) if err != nil { c.Log.Warningf("While deleting decisions : %s", err) return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter") @@ -449,7 +436,7 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, } // ExpireDecisionsWithFilter updates the expiration time to now() for the decisions matching the filter, and returns the updated items -func (c *Client) ExpireDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) { +func (c *Client) ExpireDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int @@ -558,13 +545,13 @@ func (c *Client) ExpireDecisionsWithFilter(filter map[string][]string) (string, return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) } - DecisionsToDelete, err := decisions.All(c.CTX) + DecisionsToDelete, err := decisions.All(ctx) if err != nil { c.Log.Warningf("ExpireDecisionsWithFilter : %s", err) return "0", nil, errors.Wrap(DeleteFail, "expire decisions with provided filter") } - count, err := c.ExpireDecisions(DecisionsToDelete) + count, err := c.ExpireDecisions(ctx, DecisionsToDelete) if err != nil { return "0", nil, errors.Wrapf(DeleteFail, "expire decisions with provided filter : %s", err) } @@ -583,13 +570,13 @@ func decisionIDs(decisions []*ent.Decision) []int { // ExpireDecisions sets the expiration of a list of decisions to now() // It returns the number of impacted decisions for the CAPI/PAPI -func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) { +func (c *Client) ExpireDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) { if len(decisions) <= decisionDeleteBulkSize { ids := decisionIDs(decisions) rows, err := c.Ent.Decision.Update().Where( decision.IDIn(ids...), - ).SetUntil(time.Now().UTC()).Save(c.CTX) + ).SetUntil(time.Now().UTC()).Save(ctx) if err != nil { return 0, fmt.Errorf("expire decisions with provided filter: %w", err) } @@ -602,7 +589,7 @@ func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) { total := 0 for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) { - rows, err := c.ExpireDecisions(chunk) + rows, err := c.ExpireDecisions(ctx, chunk) if err != nil { return total, err } @@ -615,13 +602,13 @@ func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) { // DeleteDecisions removes a list of decisions from the database // It returns the number of impacted decisions for the CAPI/PAPI -func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) { +func (c *Client) DeleteDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) { if len(decisions) < decisionDeleteBulkSize { ids := decisionIDs(decisions) rows, err := c.Ent.Decision.Delete().Where( decision.IDIn(ids...), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { return 0, fmt.Errorf("hard delete decisions with provided filter: %w", err) } @@ -634,7 +621,7 @@ func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) { tot := 0 for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) { - rows, err := c.DeleteDecisions(chunk) + rows, err := c.DeleteDecisions(ctx, chunk) if err != nil { return tot, err } @@ -646,8 +633,8 @@ func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) { } // ExpireDecision set the expiration of a decision to now() -func (c *Client) ExpireDecisionByID(decisionID int) (int, []*ent.Decision, error) { - toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX) +func (c *Client) ExpireDecisionByID(ctx context.Context, decisionID int) (int, []*ent.Decision, error) { + toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(ctx) // XXX: do we want 500 or 404 here? if err != nil || len(toUpdate) == 0 { @@ -659,12 +646,12 @@ func (c *Client) ExpireDecisionByID(decisionID int) (int, []*ent.Decision, error return 0, nil, ItemNotFound } - count, err := c.ExpireDecisions(toUpdate) + count, err := c.ExpireDecisions(ctx, toUpdate) return count, toUpdate, err } -func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { +func (c *Client) CountDecisionsByValue(ctx context.Context, decisionValue string) (int, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz, count int @@ -682,7 +669,7 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") } - count, err = decisions.Count(c.CTX) + count, err = decisions.Count(ctx) if err != nil { return 0, errors.Wrapf(err, "fail to count decisions") } @@ -690,7 +677,7 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { return count, nil } -func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error) { +func (c *Client) CountActiveDecisionsByValue(ctx context.Context, decisionValue string) (int, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz, count int @@ -710,7 +697,7 @@ func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error) decisions = decisions.Where(decision.UntilGT(time.Now().UTC())) - count, err = decisions.Count(c.CTX) + count, err = decisions.Count(ctx) if err != nil { return 0, fmt.Errorf("fail to count decisions: %w", err) } @@ -718,7 +705,7 @@ func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error) return count, nil } -func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.Duration, error) { +func (c *Client) GetActiveDecisionsTimeLeftByValue(ctx context.Context, decisionValue string) (time.Duration, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int @@ -740,7 +727,7 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.D decisions = decisions.Order(ent.Desc(decision.FieldUntil)) - decision, err := decisions.First(c.CTX) + decision, err := decisions.First(ctx) if err != nil && !ent.IsNotFound(err) { return 0, fmt.Errorf("fail to get decision: %w", err) } @@ -752,7 +739,7 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.D return decision.Until.Sub(time.Now().UTC()), nil } -func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Time) (int, error) { +func (c *Client) CountDecisionsSinceByValue(ctx context.Context, decisionValue string, since time.Time) (int, error) { ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(decisionValue) if err != nil { return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err) @@ -768,7 +755,7 @@ func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Tim return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") } - count, err := decisions.Count(c.CTX) + count, err := decisions.Count(ctx) if err != nil { return 0, errors.Wrapf(err, "fail to count decisions") } diff --git a/pkg/exprhelpers/helpers.go b/pkg/exprhelpers/helpers.go index 2ca7d0be79a..6b7eb0840e9 100644 --- a/pkg/exprhelpers/helpers.go +++ b/pkg/exprhelpers/helpers.go @@ -2,6 +2,7 @@ package exprhelpers import ( "bufio" + "context" "encoding/base64" "errors" "fmt" @@ -592,7 +593,10 @@ func GetDecisionsCount(params ...any) (any, error) { return 0, nil } - count, err := dbClient.CountDecisionsByValue(value) + + ctx := context.TODO() + + count, err := dbClient.CountDecisionsByValue(ctx, value) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility @@ -613,8 +617,11 @@ func GetDecisionsSinceCount(params ...any) (any, error) { log.Errorf("Failed to parse since parameter '%s' : %s", since, err) return 0, nil } + + ctx := context.TODO() sinceTime := time.Now().UTC().Add(-sinceDuration) - count, err := dbClient.CountDecisionsSinceByValue(value, sinceTime) + + count, err := dbClient.CountDecisionsSinceByValue(ctx, value, sinceTime) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility @@ -628,7 +635,8 @@ func GetActiveDecisionsCount(params ...any) (any, error) { log.Error("No database config to call GetActiveDecisionsCount()") return 0, nil } - count, err := dbClient.CountActiveDecisionsByValue(value) + ctx := context.TODO() + count, err := dbClient.CountActiveDecisionsByValue(ctx, value) if err != nil { log.Errorf("Failed to get active decisions count from value '%s'", value) return 0, err @@ -642,7 +650,8 @@ func GetActiveDecisionsTimeLeft(params ...any) (any, error) { log.Error("No database config to call GetActiveDecisionsTimeLeft()") return 0, nil } - timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(value) + ctx := context.TODO() + timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(ctx, value) if err != nil { log.Errorf("Failed to get active decisions time left from value '%s'", value) return 0, err From b757a5bdb26c947eddaec19096eddc7a848f883d Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 14:33:06 +0200 Subject: [PATCH 28/28] drop CTX from dbclient --- pkg/apiserver/apic.go | 6 +++--- pkg/apiserver/apic_test.go | 10 ++++++---- pkg/apiserver/middlewares/v1/jwt.go | 4 +++- pkg/database/database.go | 2 -- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index ab1cdfea917..9b56fef6549 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -406,13 +406,13 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { } } -func (a *apic) CAPIPullIsOld() (bool, error) { +func (a *apic) CAPIPullIsOld(ctx context.Context) (bool, error) { /*only pull community blocklist if it's older than 1h30 */ alerts := a.dbClient.Ent.Alert.Query() alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID))) alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert - count, err := alerts.Count(a.dbClient.CTX) + count, err := alerts.Count(ctx) if err != nil { return false, fmt.Errorf("while looking for CAPI alert: %w", err) } @@ -634,7 +634,7 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error { } if !forcePull { - if lastPullIsOld, err := a.CAPIPullIsOld(); err != nil { + if lastPullIsOld, err := a.CAPIPullIsOld(ctx); err != nil { return err } else if !lastPullIsOld { return nil diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 97943b495e5..3bb158acf35 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -113,7 +113,9 @@ func assertTotalAlertCount(t *testing.T, dbClient *database.Client, count int) { func TestAPICCAPIPullIsOld(t *testing.T) { api := getAPIC(t) - isOld, err := api.CAPIPullIsOld() + ctx := context.Background() + + isOld, err := api.CAPIPullIsOld(ctx) require.NoError(t, err) assert.True(t, isOld) @@ -124,7 +126,7 @@ func TestAPICCAPIPullIsOld(t *testing.T) { SetScope("Country"). SetValue("Blah"). SetOrigin(types.CAPIOrigin). - SaveX(context.Background()) + SaveX(ctx) api.dbClient.Ent.Alert.Create(). SetCreatedAt(time.Now()). @@ -132,9 +134,9 @@ func TestAPICCAPIPullIsOld(t *testing.T) { AddDecisions( decision, ). - SaveX(context.Background()) + SaveX(ctx) - isOld, err = api.CAPIPullIsOld() + isOld, err = api.CAPIPullIsOld(ctx) require.NoError(t, err) assert.False(t, isOld) diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index b7cccd1aa39..5877c945106 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -129,6 +129,8 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { err error ) + ctx := c.Request.Context() + ret := authInput{} if err = c.ShouldBindJSON(&loginInput); err != nil { @@ -145,7 +147,7 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). Where(machine.MachineId(ret.machineID)). - First(j.DbClient.CTX) + First(ctx) if err != nil { log.Infof("Error machine login for %s : %+v ", ret.machineID, err) return nil, err diff --git a/pkg/database/database.go b/pkg/database/database.go index e513459199f..bb41dd3b645 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -21,7 +21,6 @@ import ( type Client struct { Ent *ent.Client - CTX context.Context Log *log.Logger CanFlush bool Type string @@ -106,7 +105,6 @@ func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, erro return &Client{ Ent: client, - CTX: ctx, Log: clog, CanFlush: true, Type: config.Type,