diff --git a/cmd/crowdsec-cli/clipapi/papi.go b/cmd/crowdsec-cli/clipapi/papi.go index 461215c3a39..7ac2455d28f 100644 --- a/cmd/crowdsec-cli/clipapi/papi.go +++ b/cmd/crowdsec-cli/clipapi/papi.go @@ -136,7 +136,7 @@ func (cli *cliPapi) sync(ctx context.Context, out io.Writer, db *database.Client t.Go(papi.SyncDecisions) - err = papi.PullOnce(time.Time{}, true) + err = papi.PullOnce(ctx, time.Time{}, true) if err != nil { return fmt.Errorf("unable to sync decisions: %w", err) } diff --git a/cmd/crowdsec/crowdsec.go b/cmd/crowdsec/crowdsec.go index c44d71d2093..db93992605d 100644 --- a/cmd/crowdsec/crowdsec.go +++ b/cmd/crowdsec/crowdsec.go @@ -116,7 +116,7 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H }) bucketWg.Wait() - apiClient, err := AuthenticatedLAPIClient(*cConfig.API.Client.Credentials, hub) + apiClient, err := AuthenticatedLAPIClient(context.TODO(), *cConfig.API.Client.Credentials, hub) if err != nil { return err } diff --git a/cmd/crowdsec/lapiclient.go b/cmd/crowdsec/lapiclient.go index eed517f9df9..6656ba6b4c2 100644 --- a/cmd/crowdsec/lapiclient.go +++ b/cmd/crowdsec/lapiclient.go @@ -14,7 +14,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" ) -func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) { +func AuthenticatedLAPIClient(ctx context.Context, credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) { apiURL, err := url.Parse(credentials.URL) if err != nil { return nil, fmt.Errorf("parsing api url ('%s'): %w", credentials.URL, err) @@ -44,7 +44,7 @@ func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub. return nil, fmt.Errorf("new client api: %w", err) } - authResp, _, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + authResp, _, err := client.Auth.AuthenticateWatcher(ctx, models.WatcherAuthRequest{ MachineID: &credentials.Login, Password: &password, Scenarios: itemsForAPI, diff --git a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go index 846e833abea..fce199c5708 100644 --- a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go +++ b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go @@ -119,7 +119,7 @@ func (lc *LokiClient) queryRange(ctx context.Context, uri string, c chan *LokiQu case <-lc.t.Dying(): return lc.t.Err() case <-ticker.C: - resp, err := lc.Get(uri) + resp, err := lc.Get(ctx, uri) if err != nil { if ok := lc.shouldRetry(); !ok { return fmt.Errorf("error querying range: %w", err) @@ -215,7 +215,7 @@ func (lc *LokiClient) Ready(ctx context.Context) error { return lc.t.Err() case <-tick.C: lc.Logger.Debug("Checking if Loki is ready") - resp, err := lc.Get(url) + resp, err := lc.Get(ctx, url) if err != nil { lc.Logger.Warnf("Error checking if Loki is ready: %s", err) continue @@ -300,8 +300,8 @@ func (lc *LokiClient) QueryRange(ctx context.Context, infinite bool) chan *LokiQ } // Create a wrapper for http.Get to be able to set headers and auth -func (lc *LokiClient) Get(url string) (*http.Response, error) { - request, err := http.NewRequest(http.MethodGet, url, nil) +func (lc *LokiClient) Get(ctx context.Context, url string) (*http.Response, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index 7dd6b346aa9..83ba13843b9 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -205,8 +205,8 @@ func reverse(s []longpollclient.Event) []longpollclient.Event { return a } -func (p *Papi) PullOnce(since time.Time, sync bool) error { - events, err := p.Client.PullOnce(since) +func (p *Papi) PullOnce(ctx context.Context, since time.Time, sync bool) error { + events, err := p.Client.PullOnce(ctx, since) if err != nil { return err } @@ -261,7 +261,7 @@ func (p *Papi) Pull(ctx context.Context) error { p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp) - for event := range p.Client.Start(lastTimestamp) { + for event := range p.Client.Start(ctx, lastTimestamp) { logger := p.Logger.WithField("request-id", event.RequestId) // update last timestamp in database newTime := time.Now().UTC() diff --git a/pkg/longpollclient/client.go b/pkg/longpollclient/client.go index 5a7af0bfa63..5c395185b20 100644 --- a/pkg/longpollclient/client.go +++ b/pkg/longpollclient/client.go @@ -1,6 +1,7 @@ package longpollclient import ( + "context" "encoding/json" "errors" "fmt" @@ -50,7 +51,7 @@ var errUnauthorized = errors.New("user is not authorized to use PAPI") const timeoutMessage = "no events before timeout" -func (c *LongPollClient) doQuery() (*http.Response, error) { +func (c *LongPollClient) doQuery(ctx context.Context) (*http.Response, error) { logger := c.logger.WithField("method", "doQuery") query := c.url.Query() query.Set("since_time", fmt.Sprintf("%d", c.since)) @@ -59,7 +60,7 @@ func (c *LongPollClient) doQuery() (*http.Response, error) { logger.Debugf("Query parameters: %s", c.url.RawQuery) - req, err := http.NewRequest(http.MethodGet, c.url.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url.String(), nil) if err != nil { logger.Errorf("failed to create request: %s", err) return nil, err @@ -73,10 +74,10 @@ func (c *LongPollClient) doQuery() (*http.Response, error) { return resp, nil } -func (c *LongPollClient) poll() error { +func (c *LongPollClient) poll(ctx context.Context) error { logger := c.logger.WithField("method", "poll") - resp, err := c.doQuery() + resp, err := c.doQuery(ctx) if err != nil { return err } @@ -146,7 +147,7 @@ func (c *LongPollClient) poll() error { } } -func (c *LongPollClient) pollEvents() error { +func (c *LongPollClient) pollEvents(ctx context.Context) error { for { select { case <-c.t.Dying(): @@ -154,7 +155,7 @@ func (c *LongPollClient) pollEvents() error { return nil default: c.logger.Debug("Polling PAPI") - err := c.poll() + err := c.poll(ctx) if err != nil { c.logger.Errorf("failed to poll: %s", err) if errors.Is(err, errUnauthorized) { @@ -168,12 +169,12 @@ func (c *LongPollClient) pollEvents() error { } } -func (c *LongPollClient) Start(since time.Time) chan Event { +func (c *LongPollClient) Start(ctx context.Context, since time.Time) chan Event { c.logger.Infof("starting polling client") c.c = make(chan Event) c.since = since.Unix() * 1000 c.timeout = "45" - c.t.Go(c.pollEvents) + c.t.Go(func() error {return c.pollEvents(ctx)}) return c.c } @@ -182,11 +183,11 @@ func (c *LongPollClient) Stop() error { return nil } -func (c *LongPollClient) PullOnce(since time.Time) ([]Event, error) { +func (c *LongPollClient) PullOnce(ctx context.Context, since time.Time) ([]Event, error) { c.logger.Debug("Pulling PAPI once") c.since = since.Unix() * 1000 c.timeout = "1" - resp, err := c.doQuery() + resp, err := c.doQuery(ctx) if err != nil { return nil, err }