Skip to content

Commit

Permalink
context propagation: papi, loki (crowdsecurity#3308)
Browse files Browse the repository at this point in the history
* context propagation: AuthenticatedLAPIClient()

* context propagation: papi

* context propagation: loki
  • Loading branch information
mmetc authored Nov 15, 2024
1 parent b96a7a5 commit a4497da
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 21 deletions.
2 changes: 1 addition & 1 deletion cmd/crowdsec-cli/clipapi/papi.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (cli *cliPapi) sync(ctx context.Context, out io.Writer, db *database.Client

t.Go(papi.SyncDecisions)

err = papi.PullOnce(time.Time{}, true)
err = papi.PullOnce(ctx, time.Time{}, true)
if err != nil {
return fmt.Errorf("unable to sync decisions: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/crowdsec/crowdsec.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H
})
bucketWg.Wait()

apiClient, err := AuthenticatedLAPIClient(*cConfig.API.Client.Credentials, hub)
apiClient, err := AuthenticatedLAPIClient(context.TODO(), *cConfig.API.Client.Credentials, hub)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/crowdsec/lapiclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/models"
)

func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) {
func AuthenticatedLAPIClient(ctx context.Context, credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) {
apiURL, err := url.Parse(credentials.URL)
if err != nil {
return nil, fmt.Errorf("parsing api url ('%s'): %w", credentials.URL, err)
Expand Down Expand Up @@ -44,7 +44,7 @@ func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub.
return nil, fmt.Errorf("new client api: %w", err)
}

authResp, _, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
authResp, _, err := client.Auth.AuthenticateWatcher(ctx, models.WatcherAuthRequest{
MachineID: &credentials.Login,
Password: &password,
Scenarios: itemsForAPI,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (lc *LokiClient) queryRange(ctx context.Context, uri string, c chan *LokiQu
case <-lc.t.Dying():
return lc.t.Err()
case <-ticker.C:
resp, err := lc.Get(uri)
resp, err := lc.Get(ctx, uri)
if err != nil {
if ok := lc.shouldRetry(); !ok {
return fmt.Errorf("error querying range: %w", err)
Expand Down Expand Up @@ -215,7 +215,7 @@ func (lc *LokiClient) Ready(ctx context.Context) error {
return lc.t.Err()
case <-tick.C:
lc.Logger.Debug("Checking if Loki is ready")
resp, err := lc.Get(url)
resp, err := lc.Get(ctx, url)
if err != nil {
lc.Logger.Warnf("Error checking if Loki is ready: %s", err)
continue
Expand Down Expand Up @@ -300,8 +300,8 @@ func (lc *LokiClient) QueryRange(ctx context.Context, infinite bool) chan *LokiQ
}

// Create a wrapper for http.Get to be able to set headers and auth
func (lc *LokiClient) Get(url string) (*http.Response, error) {
request, err := http.NewRequest(http.MethodGet, url, nil)
func (lc *LokiClient) Get(ctx context.Context, url string) (*http.Response, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/apiserver/papi.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ func reverse(s []longpollclient.Event) []longpollclient.Event {
return a
}

func (p *Papi) PullOnce(since time.Time, sync bool) error {
events, err := p.Client.PullOnce(since)
func (p *Papi) PullOnce(ctx context.Context, since time.Time, sync bool) error {
events, err := p.Client.PullOnce(ctx, since)
if err != nil {
return err
}
Expand Down Expand Up @@ -261,7 +261,7 @@ func (p *Papi) Pull(ctx context.Context) error {

p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp)

for event := range p.Client.Start(lastTimestamp) {
for event := range p.Client.Start(ctx, lastTimestamp) {
logger := p.Logger.WithField("request-id", event.RequestId)
// update last timestamp in database
newTime := time.Now().UTC()
Expand Down
21 changes: 11 additions & 10 deletions pkg/longpollclient/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package longpollclient

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -50,7 +51,7 @@ var errUnauthorized = errors.New("user is not authorized to use PAPI")

const timeoutMessage = "no events before timeout"

func (c *LongPollClient) doQuery() (*http.Response, error) {
func (c *LongPollClient) doQuery(ctx context.Context) (*http.Response, error) {
logger := c.logger.WithField("method", "doQuery")
query := c.url.Query()
query.Set("since_time", fmt.Sprintf("%d", c.since))
Expand All @@ -59,7 +60,7 @@ func (c *LongPollClient) doQuery() (*http.Response, error) {

logger.Debugf("Query parameters: %s", c.url.RawQuery)

req, err := http.NewRequest(http.MethodGet, c.url.String(), nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url.String(), nil)
if err != nil {
logger.Errorf("failed to create request: %s", err)
return nil, err
Expand All @@ -73,10 +74,10 @@ func (c *LongPollClient) doQuery() (*http.Response, error) {
return resp, nil
}

func (c *LongPollClient) poll() error {
func (c *LongPollClient) poll(ctx context.Context) error {
logger := c.logger.WithField("method", "poll")

resp, err := c.doQuery()
resp, err := c.doQuery(ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -146,15 +147,15 @@ func (c *LongPollClient) poll() error {
}
}

func (c *LongPollClient) pollEvents() error {
func (c *LongPollClient) pollEvents(ctx context.Context) error {
for {
select {
case <-c.t.Dying():
c.logger.Debug("dying")
return nil
default:
c.logger.Debug("Polling PAPI")
err := c.poll()
err := c.poll(ctx)
if err != nil {
c.logger.Errorf("failed to poll: %s", err)
if errors.Is(err, errUnauthorized) {
Expand All @@ -168,12 +169,12 @@ func (c *LongPollClient) pollEvents() error {
}
}

func (c *LongPollClient) Start(since time.Time) chan Event {
func (c *LongPollClient) Start(ctx context.Context, since time.Time) chan Event {
c.logger.Infof("starting polling client")
c.c = make(chan Event)
c.since = since.Unix() * 1000
c.timeout = "45"
c.t.Go(c.pollEvents)
c.t.Go(func() error {return c.pollEvents(ctx)})
return c.c
}

Expand All @@ -182,11 +183,11 @@ func (c *LongPollClient) Stop() error {
return nil
}

func (c *LongPollClient) PullOnce(since time.Time) ([]Event, error) {
func (c *LongPollClient) PullOnce(ctx context.Context, since time.Time) ([]Event, error) {
c.logger.Debug("Pulling PAPI once")
c.since = since.Unix() * 1000
c.timeout = "1"
resp, err := c.doQuery()
resp, err := c.doQuery(ctx)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit a4497da

Please sign in to comment.