Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

context propagation: pkg/database/config #127

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
6 changes: 4 additions & 2 deletions cmd/crowdsec-cli/clialert/alerts.go
Original file line number Diff line number Diff line change
Expand Up @@ -575,15 +575,17 @@ func (cli *cliAlerts) newFlushCmd() *cobra.Command {
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, _ []string) error {
cfg := cli.cfg()
ctx := cmd.Context()

if err := require.LAPI(cfg); err != nil {
return err
}
db, err := require.DBClient(cmd.Context(), cfg.DbConfig)
db, err := require.DBClient(ctx, cfg.DbConfig)
if err != nil {
return err
}
log.Info("Flushing alerts. !! This may take a long time !!")
err = db.FlushAlerts(maxAge, maxItems)
err = db.FlushAlerts(ctx, maxAge, maxItems)
if err != nil {
return fmt.Errorf("unable to flush alerts: %w", err)
}
Expand Down
14 changes: 8 additions & 6 deletions cmd/crowdsec-cli/clibouncer/bouncers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clibouncer

import (
"context"
"encoding/csv"
"encoding/json"
"errors"
Expand Down Expand Up @@ -159,11 +160,11 @@ func (cli *cliBouncers) listCSV(out io.Writer, bouncers ent.Bouncers) error {
return nil
}

func (cli *cliBouncers) List(out io.Writer, db *database.Client) error {
func (cli *cliBouncers) List(ctx context.Context, out io.Writer, db *database.Client) error {
// XXX: must use the provided db object, the one in the struct might be nil
// (calling List directly skips the PersistentPreRunE)

bouncers, err := db.ListBouncers()
bouncers, err := db.ListBouncers(ctx)
if err != nil {
return fmt.Errorf("unable to list bouncers: %w", err)
}
Expand Down Expand Up @@ -199,8 +200,8 @@ func (cli *cliBouncers) newListCmd() *cobra.Command {
Example: `cscli bouncers list`,
Args: cobra.ExactArgs(0),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
return cli.List(color.Output, cli.db)
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.List(cmd.Context(), color.Output, cli.db)
},
}

Expand Down Expand Up @@ -271,6 +272,7 @@ func (cli *cliBouncers) validBouncerID(cmd *cobra.Command, args []string, toComp
var err error

cfg := cli.cfg()
ctx := cmd.Context()

// need to load config and db because PersistentPreRunE is not called for completions

Expand All @@ -279,13 +281,13 @@ func (cli *cliBouncers) validBouncerID(cmd *cobra.Command, args []string, toComp
return nil, cobra.ShellCompDirectiveNoFileComp
}

cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig)
cli.db, err = require.DBClient(ctx, cfg.DbConfig)
if err != nil {
cobra.CompError("unable to list bouncers " + err.Error())
return nil, cobra.ShellCompDirectiveNoFileComp
}

bouncers, err := cli.db.ListBouncers()
bouncers, err := cli.db.ListBouncers(ctx)
if err != nil {
cobra.CompError("unable to list bouncers " + err.Error())
return nil, cobra.ShellCompDirectiveNoFileComp
Expand Down
14 changes: 8 additions & 6 deletions cmd/crowdsec-cli/clipapi/papi.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (cli *cliPapi) NewCommand() *cobra.Command {
func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Client) error {
cfg := cli.cfg()

apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists)
apic, err := apiserver.NewAPIC(ctx, cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists)
if err != nil {
return fmt.Errorf("unable to initialize API client: %w", err)
}
Expand All @@ -74,7 +74,7 @@ func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Clie
return fmt.Errorf("unable to get PAPI permissions: %w", err)
}

lastTimestampStr, err := db.GetConfigItem(apiserver.PapiPullKey)
lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey)
if err != nil {
lastTimestampStr = ptr.Of("never")
}
Expand Down Expand Up @@ -118,11 +118,11 @@ func (cli *cliPapi) newStatusCmd() *cobra.Command {
return cmd
}

func (cli *cliPapi) sync(out io.Writer, db *database.Client) error {
func (cli *cliPapi) sync(ctx context.Context, out io.Writer, db *database.Client) error {
cfg := cli.cfg()
t := tomb.Tomb{}

apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists)
apic, err := apiserver.NewAPIC(ctx, cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists)
if err != nil {
return fmt.Errorf("unable to initialize API client: %w", err)
}
Expand Down Expand Up @@ -159,12 +159,14 @@ func (cli *cliPapi) newSyncCmd() *cobra.Command {
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, _ []string) error {
cfg := cli.cfg()
db, err := require.DBClient(cmd.Context(), cfg.DbConfig)
ctx := cmd.Context()

db, err := require.DBClient(ctx, cfg.DbConfig)
if err != nil {
return err
}

return cli.sync(color.Output, db)
return cli.sync(ctx, color.Output, db)
},
}

Expand Down
6 changes: 3 additions & 3 deletions cmd/crowdsec-cli/clisupport/support.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func (cli *cliSupport) dumpHubItems(zw *zip.Writer, hub *cwhub.Hub) error {
return nil
}

func (cli *cliSupport) dumpBouncers(zw *zip.Writer, db *database.Client) error {
func (cli *cliSupport) dumpBouncers(ctx context.Context, zw *zip.Writer, db *database.Client) error {
log.Info("Collecting bouncers")

if db == nil {
Expand All @@ -199,7 +199,7 @@ func (cli *cliSupport) dumpBouncers(zw *zip.Writer, db *database.Client) error {
out := new(bytes.Buffer)
cb := clibouncer.New(cli.cfg)

if err := cb.List(out, db); err != nil {
if err := cb.List(ctx, out, db); err != nil {
return err
}

Expand Down Expand Up @@ -525,7 +525,7 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error {
log.Warnf("could not collect hub information: %s", err)
}

if err = cli.dumpBouncers(zipWriter, db); err != nil {
if err = cli.dumpBouncers(ctx, zipWriter, db); err != nil {
log.Warnf("could not collect bouncers information: %s", err)
}

Expand Down
16 changes: 11 additions & 5 deletions pkg/acquisition/modules/loki/loki_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func TestConfigureDSN(t *testing.T) {
}
}

func feedLoki(logger *log.Entry, n int, title string) error {
func feedLoki(ctx context.Context, logger *log.Entry, n int, title string) error {
streams := LogStreams{
Streams: []LogStream{
{
Expand All @@ -286,7 +286,7 @@ func feedLoki(logger *log.Entry, n int, title string) error {
return err
}

req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:3100/loki/api/v1/push", bytes.NewBuffer(buff))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://127.0.0.1:3100/loki/api/v1/push", bytes.NewBuffer(buff))
if err != nil {
return err
}
Expand Down Expand Up @@ -349,7 +349,9 @@ since: 1h
t.Fatalf("Unexpected error : %s", err)
}

err = feedLoki(subLogger, 20, title)
ctx := context.Background()

err = feedLoki(ctx, subLogger, 20, title)
if err != nil {
t.Fatalf("Unexpected error : %s", err)
}
Expand Down Expand Up @@ -421,6 +423,8 @@ query: >
},
}

ctx := context.Background()

for _, ts := range tests {
t.Run(ts.name, func(t *testing.T) {
logger := log.New()
Expand Down Expand Up @@ -472,7 +476,7 @@ query: >
}
})

err = feedLoki(subLogger, ts.expectedLines, title)
err = feedLoki(ctx, subLogger, ts.expectedLines, title)
if err != nil {
t.Fatalf("Unexpected error : %s", err)
}
Expand Down Expand Up @@ -525,7 +529,9 @@ query: >

time.Sleep(time.Second * 2)

err = feedLoki(subLogger, 1, title)
ctx := context.Background()

err = feedLoki(ctx, subLogger, 1, title)
if err != nil {
t.Fatalf("Unexpected error : %s", err)
}
Expand Down
41 changes: 22 additions & 19 deletions pkg/acquisition/modules/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ type S3Source struct {
readerChan chan S3Object
t *tomb.Tomb
out chan types.Event
ctx aws.Context
cancel context.CancelFunc
}

Expand Down Expand Up @@ -184,7 +183,7 @@ func (s *S3Source) newSQSClient() error {
return nil
}

func (s *S3Source) readManager() {
func (s *S3Source) readManager(ctx context.Context) {
logger := s.logger.WithField("method", "readManager")
for {
select {
Expand All @@ -194,21 +193,21 @@ func (s *S3Source) readManager() {
return
case s3Object := <-s.readerChan:
logger.Debugf("Reading file %s/%s", s3Object.Bucket, s3Object.Key)
err := s.readFile(s3Object.Bucket, s3Object.Key)
err := s.readFile(ctx, s3Object.Bucket, s3Object.Key)
if err != nil {
logger.Errorf("Error while reading file: %s", err)
}
}
}
}

func (s *S3Source) getBucketContent() ([]*s3.Object, error) {
func (s *S3Source) getBucketContent(ctx context.Context) ([]*s3.Object, error) {
logger := s.logger.WithField("method", "getBucketContent")
logger.Debugf("Getting bucket content for %s", s.Config.BucketName)
bucketObjects := make([]*s3.Object, 0)
var continuationToken *string
for {
out, err := s.s3Client.ListObjectsV2WithContext(s.ctx, &s3.ListObjectsV2Input{
out, err := s.s3Client.ListObjectsV2WithContext(ctx, &s3.ListObjectsV2Input{
Bucket: aws.String(s.Config.BucketName),
Prefix: aws.String(s.Config.Prefix),
ContinuationToken: continuationToken,
Expand All @@ -229,7 +228,7 @@ func (s *S3Source) getBucketContent() ([]*s3.Object, error) {
return bucketObjects, nil
}

func (s *S3Source) listPoll() error {
func (s *S3Source) listPoll(ctx context.Context) error {
logger := s.logger.WithField("method", "listPoll")
ticker := time.NewTicker(time.Duration(s.Config.PollingInterval) * time.Second)
lastObjectDate := time.Now()
Expand All @@ -243,7 +242,7 @@ func (s *S3Source) listPoll() error {
return nil
case <-ticker.C:
newObject := false
bucketObjects, err := s.getBucketContent()
bucketObjects, err := s.getBucketContent(ctx)
if err != nil {
logger.Errorf("Error while getting bucket content: %s", err)
continue
Expand Down Expand Up @@ -325,7 +324,7 @@ func (s *S3Source) extractBucketAndPrefix(message *string) (string, string, erro
}
}

func (s *S3Source) sqsPoll() error {
func (s *S3Source) sqsPoll(ctx context.Context) error {
logger := s.logger.WithField("method", "sqsPoll")
for {
select {
Expand All @@ -335,7 +334,7 @@ func (s *S3Source) sqsPoll() error {
return nil
default:
logger.Trace("Polling SQS queue")
out, err := s.sqsClient.ReceiveMessageWithContext(s.ctx, &sqs.ReceiveMessageInput{
out, err := s.sqsClient.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{
QueueUrl: aws.String(s.Config.SQSName),
MaxNumberOfMessages: aws.Int64(10),
WaitTimeSeconds: aws.Int64(20), //Probably no need to make it configurable ?
Expand Down Expand Up @@ -378,7 +377,7 @@ func (s *S3Source) sqsPoll() error {
}
}

func (s *S3Source) readFile(bucket string, key string) error {
func (s *S3Source) readFile(ctx context.Context, bucket string, key string) error {
//TODO: Handle SSE-C
var scanner *bufio.Scanner

Expand All @@ -388,7 +387,7 @@ func (s *S3Source) readFile(bucket string, key string) error {
"key": key,
})

output, err := s.s3Client.GetObjectWithContext(s.ctx, &s3.GetObjectInput{
output, err := s.s3Client.GetObjectWithContext(ctx, &s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
})
Expand Down Expand Up @@ -645,24 +644,26 @@ func (s *S3Source) GetName() string {
}

func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error {
var ctx context.Context

s.logger.Infof("starting acquisition of %s/%s/%s", s.Config.BucketName, s.Config.Prefix, s.Config.Key)
s.out = out
s.ctx, s.cancel = context.WithCancel(context.Background())
ctx, s.cancel = context.WithCancel(context.Background())
s.Config.UseTimeMachine = true
s.t = t
if s.Config.Key != "" {
err := s.readFile(s.Config.BucketName, s.Config.Key)
err := s.readFile(ctx, s.Config.BucketName, s.Config.Key)
if err != nil {
return err
}
} else {
//No key, get everything in the bucket based on the prefix
objects, err := s.getBucketContent()
objects, err := s.getBucketContent(ctx)
if err != nil {
return err
}
for _, object := range objects {
err := s.readFile(s.Config.BucketName, *object.Key)
err := s.readFile(ctx, s.Config.BucketName, *object.Key)
if err != nil {
return err
}
Expand All @@ -673,26 +674,28 @@ func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error
}

func (s *S3Source) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error {
var ctx context.Context

s.t = t
s.out = out
s.readerChan = make(chan S3Object, 100) //FIXME: does this needs to be buffered?
s.ctx, s.cancel = context.WithCancel(context.Background())
ctx, s.cancel = context.WithCancel(context.Background())
s.logger.Infof("starting acquisition of %s/%s", s.Config.BucketName, s.Config.Prefix)
t.Go(func() error {
s.readManager()
s.readManager(ctx)
return nil
})
if s.Config.PollingMethod == PollMethodSQS {
t.Go(func() error {
err := s.sqsPoll()
err := s.sqsPoll(ctx)
if err != nil {
return err
}
return nil
})
} else {
t.Go(func() error {
err := s.listPoll()
err := s.listPoll(ctx)
if err != nil {
return err
}
Expand Down
Loading
Loading