diff --git a/cmd/crowdsec-cli/climachine/machines.go b/cmd/crowdsec-cli/climachine/machines.go index 7f1fdda56a7..1fbedcf57fd 100644 --- a/cmd/crowdsec-cli/climachine/machines.go +++ b/cmd/crowdsec-cli/climachine/machines.go @@ -491,7 +491,7 @@ func (cli *cliMachines) prune(ctx context.Context, duration time.Duration, notVa } if !notValidOnly { - if pending, err := cli.db.QueryMachinesInactiveSince(time.Now().UTC().Add(-duration)); err == nil { + if pending, err := cli.db.QueryMachinesInactiveSince(ctx, time.Now().UTC().Add(-duration)); err == nil { machines = append(machines, pending...) } } diff --git a/pkg/apiserver/controllers/v1/heartbeat.go b/pkg/apiserver/controllers/v1/heartbeat.go index e1231eaa9ec..799b736ccfe 100644 --- a/pkg/apiserver/controllers/v1/heartbeat.go +++ b/pkg/apiserver/controllers/v1/heartbeat.go @@ -9,7 +9,9 @@ import ( func (c *Controller) HeartBeat(gctx *gin.Context) { machineID, _ := getMachineIDFromContext(gctx) - if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil { + ctx := gctx.Request.Context() + + if err := c.DBClient.UpdateMachineLastHeartBeat(ctx, machineID); err != nil { c.HandleDBErrors(gctx, err) return } diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index 491dc9768cd..b7cccd1aa39 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -177,6 +177,8 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { auth *authInput ) + ctx := c.Request.Context() + if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { auth, err = j.authTLS(c) if err != nil { @@ -200,7 +202,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { } } - err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineScenarios(ctx, scenarios, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication @@ -210,7 +212,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { clientIP := c.ClientIP() if auth.clientMachine.IpAddress == "" { - err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication @@ -220,7 +222,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { if auth.clientMachine.IpAddress != clientIP && auth.clientMachine.IpAddress != "" { log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, clientIP, auth.clientMachine.IpAddress) - err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err) return nil, jwt.ErrFailedAuthentication @@ -233,7 +235,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { return nil, jwt.ErrFailedAuthentication } - if err := j.DbClient.UpdateMachineVersion(useragent[1], auth.clientMachine.ID); err != nil { + if err := j.DbClient.UpdateMachineVersion(ctx, useragent[1], auth.clientMachine.ID); err != nil { log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err) log.Errorf("bad user agent from : %s", clientIP) diff --git a/pkg/database/machines.go b/pkg/database/machines.go index 2325b858285..d8c02825312 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -197,8 +197,8 @@ func (c *Client) BulkDeleteWatchers(ctx context.Context, machines []*ent.Machine return nbDeleted, nil } -func (c *Client) UpdateMachineLastHeartBeat(machineID string) error { - _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(c.CTX) +func (c *Client) UpdateMachineLastHeartBeat(ctx context.Context, machineID string) error { + _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(ctx) if err != nil { return errors.Wrapf(UpdateFail, "updating machine last_heartbeat: %s", err) } @@ -206,11 +206,11 @@ func (c *Client) UpdateMachineLastHeartBeat(machineID string) error { return nil } -func (c *Client) UpdateMachineScenarios(scenarios string, id int) error { +func (c *Client) UpdateMachineScenarios(ctx context.Context, scenarios string, id int) error { _, err := c.Ent.Machine.UpdateOneID(id). SetUpdatedAt(time.Now().UTC()). SetScenarios(scenarios). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine in database: %w", err) } @@ -218,10 +218,10 @@ func (c *Client) UpdateMachineScenarios(scenarios string, id int) error { return nil } -func (c *Client) UpdateMachineIP(ipAddr string, id int) error { +func (c *Client) UpdateMachineIP(ctx context.Context, ipAddr string, id int) error { _, err := c.Ent.Machine.UpdateOneID(id). SetIpAddress(ipAddr). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine IP in database: %w", err) } @@ -229,10 +229,10 @@ func (c *Client) UpdateMachineIP(ipAddr string, id int) error { return nil } -func (c *Client) UpdateMachineVersion(ipAddr string, id int) error { +func (c *Client) UpdateMachineVersion(ctx context.Context, ipAddr string, id int) error { _, err := c.Ent.Machine.UpdateOneID(id). SetVersion(ipAddr). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine version in database: %w", err) } @@ -240,8 +240,8 @@ func (c *Client) UpdateMachineVersion(ipAddr string, id int) error { return nil } -func (c *Client) IsMachineRegistered(machineID string) (bool, error) { - exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(c.CTX) +func (c *Client) IsMachineRegistered(ctx context.Context, machineID string) (bool, error) { + exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(ctx) if err != nil { return false, err } @@ -257,11 +257,11 @@ func (c *Client) IsMachineRegistered(machineID string) (bool, error) { return false, nil } -func (c *Client) QueryMachinesInactiveSince(t time.Time) ([]*ent.Machine, error) { +func (c *Client) QueryMachinesInactiveSince(ctx context.Context, t time.Time) ([]*ent.Machine, error) { return c.Ent.Machine.Query().Where( machine.Or( machine.And(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)), machine.And(machine.LastHeartbeatIsNil(), machine.CreatedAtLT(t)), ), - ).All(c.CTX) + ).All(ctx) }