diff --git a/api/openapi/certs.yml b/api/openapi/certs.yml index 41414fc3bb..b1b10799ed 100644 --- a/api/openapi/certs.yml +++ b/api/openapi/certs.yml @@ -7,6 +7,7 @@ info: paths: /certs: post: + operationId: createCert summary: Creates a certificate for thing description: Creates a certificate for thing tags: @@ -15,15 +16,49 @@ paths: $ref: "#/components/requestBodies/CertReq" responses: '201': - description: Created + $ref: "#/components/responses/CertRes" + '404': + description: Not Found + $ref: "#/components/responses/Error" '400': description: Failed due to malformed JSON. + $ref: "#/components/responses/Error" "401": description: Missing or invalid access token provided. + $ref: "#/components/responses/Error" '500': description: Unexpected server-side error ocurred. - /certs/{certID}: + $ref: "#/components/responses/ServiceError" + get: + operationId: getCerts + summary: Get the requested certificates + description: Get the requested certificates + tags: + - certs + parameters: + - $ref: "#/components/parameters/ThingID" + - $ref: "#/components/parameters/Name" + - $ref: "#/components/parameters/Serial" + - $ref: "#/components/parameters/Limit" + - $ref: "#/components/parameters/Offset" + responses: + '201': + $ref: "#/components/responses/CertRes" + '400': + description: Failed due to malformed query parameters. + $ref: "#/components/responses/Error" + "401": + description: Missing or invalid access token provided. + $ref: "#/components/responses/Error" + '404': + description: Failed to retrieve corresponding certificates. + $ref: "#/components/responses/Error" + '500': + $ref: "#/components/responses/ServiceError" + + /certs/{cert_id}: get: + operationId: getCert summary: Retrieves a certificate description: | Retrieves a certificate for a given cert ID. @@ -44,16 +79,17 @@ paths: '500': $ref: "#/components/responses/ServiceError" delete: - summary: Revokes a certificate + operationId: deleteCert + summary: Delete a certificate description: | - Revokes a certificate for a given cert ID. + Delete a certificate for a given cert ID. tags: - certs parameters: - $ref: "#/components/parameters/CertID" responses: '200': - $ref: "#/components/responses/RevokeRes" + description: Certificate deleted successfully. "401": description: Missing or invalid access token provided. '404': @@ -61,27 +97,134 @@ paths: Failed to revoke corresponding certificate. '500': $ref: "#/components/responses/ServiceError" - /serials/{thingID}: - get: - summary: Retrieves certificates' serial IDs + + /certs/{cert_id}/revoke: + post: + operationId: revokeCert + summary: Revoke a certificate description: | - Retrieves a list of certificates' serial IDs for a given thing ID. + Revoke a certificate for a given cert ID. tags: - certs parameters: - - $ref: "#/components/parameters/ThingID" + - $ref: "#/components/parameters/CertID" responses: '200': - $ref: "#/components/responses/SerialsPageRes" + description: Certificate revoked successfully. '400': description: Failed due to malformed query parameters. "401": description: Missing or invalid access token provided. '404': - description: | - Failed to retrieve corresponding certificates. + description: Corresponding certificate not found. + '500': + $ref: "#/components/responses/ServiceError" + + /certs/{cert_id}/renew: + post: + operationId: renewCert + summary: Renew a certificate + description: | + Renew a certificate for a given cert ID. + tags: + - certs + parameters: + - $ref: "#/components/parameters/CertID" + responses: + '200': + $ref: "#/components/responses/CertRes" + '400': + description: Failed due to malformed query parameters. + "401": + description: Missing or invalid access token provided. + '404': + description: Corresponding certificate not found. + '500': + $ref: "#/components/responses/ServiceError" + + /things/{thing_id}/revoke: + post: + operationId: revokeThingCerts + summary: Revoke certificates' of thing ID + description: | + Revoke a list of certificates' for a given thing ID. + tags: + - certs + parameters: + - $ref: "#/components/parameters/ThingIDPath" + - $ref: "#/components/parameters/Limit" + responses: + '200': + description: All the certificate belongs to thing_id with specified limit are revoked successfully. + '400': + description: Failed due to malformed query parameters. + $ref: "#/components/responses/Error" + "401": + description: Missing or invalid access token provided. + $ref: "#/components/responses/Error" + '404': + description: Failed to retrieve corresponding certificates. + $ref: "#/components/responses/Error" + '500': + $ref: "#/components/responses/ServiceError" + + + /things/{thing_id}/renew: + post: + operationId: renewThingCerts + summary: Renews certificates' of thing ID + description: | + Renews a list of certificates' for a given thing ID. + tags: + - certs + parameters: + - $ref: "#/components/parameters/ThingIDPath" + - $ref: "#/components/parameters/Limit" + responses: + '200': + description: All the certificate belongs to thing_id with specified limit are renewed successfully. + '400': + description: Failed due to malformed query parameters. + $ref: "#/components/responses/Error" + "401": + description: Missing or invalid access token provided. + $ref: "#/components/responses/Error" + '404': + description: Failed to retrieve corresponding certificates. + $ref: "#/components/responses/Error" + '500': + $ref: "#/components/responses/ServiceError" + + /things/{thing_id}: + delete: + operationId: deleteThingCerts + summary: Delete certificates' of a thing ID + description: | + Delete a list of certificates' for a given thing ID. + tags: + - certs + parameters: + - $ref: "#/components/parameters/ThingIDPath" + - $ref: "#/components/parameters/Limit" + responses: + '200': + description: All the certificate belongs to thing_id with specified limit are delete successfully. + '400': + description: Failed due to malformed query parameters. + $ref: "#/components/responses/Error" + "401": + description: Missing or invalid access token provided. + $ref: "#/components/responses/Error" + '404': + description: Failed to retrieve corresponding certificates. + $ref: "#/components/responses/Error" '500': $ref: "#/components/responses/ServiceError" + + + + + /health: get: summary: Retrieves service health check info. @@ -93,54 +236,123 @@ paths: '500': $ref: "#/components/responses/ServiceError" + components: parameters: - ThingID: - name: thingID - description: Thing ID + CertID: + name: cert_id + description: Unique certificate identifier. in: path schema: type: string format: uuid required: true - CertID: - name: certID - description: Serial of certificate + + ThingIDPath: + name: thing_id + description: Unique thing identifier. in: path schema: type: string format: uuid required: true + ThingID: + name: thing_id + description: Unique thing identifier. + in: query + schema: + type: string + format: uuid + required: false + Serial: + name: serial + description: Unique certificate identifier provided by PKI. + in: query + schema: + type: string + required: false + Name: + name: name + description: Name filter. Filtering is performed as a case-insensitive partial match. + in: query + schema: + type: string + format: ulid + required: false + Limit: + name: limit + description: Size of the subset to retrieve. + in: query + schema: + type: integer + default: 10 + maximum: 100 + minimum: 1 + required: false + Offset: + name: offset + description: Number of items to skip during retrieval. + in: query + schema: + type: integer + default: 0 + minimum: 0 + required: false + schemas: + Err: + type: object + properties: + error: + type: string + description: contains details of the error Cert: type: object properties: + id: + type: string + format: uuid + description: Identification UUID of the certificate + name: + type: string + format: uuid + description: Name of the certificate. + owner_id: + type: string + format: uuid + description: ID of the corresponding Mainflux Thing owner. thing_id: type: string format: uuid description: Corresponding Mainflux Thing ID. - client_cert: + serial: + type: string + description: Certificate serial + certificate: type: string - description: Client Certificate. - client_key: + description: Certificate. + private_key: type: string - description: Key for the client_cert. + description: Key for the Certificate. + ca_chain: + type: string + description: CA Chain contains root CA certificate and all the intermediate CA certificates. issuing_ca: type: string description: CA Certificate that is used to issue client certs, usually intermediate. - serial: + ttl: type: string - description: Certificate serial + description: Certificate validity duration. expire: type: string - description: Certificate expiry date - Serial: - type: object - properties: - serial: + format: timestamp + description: Certificate expiry timestamp + revocation: type: string - description: Certificate serial + format: timestamp + description: Certificate revoked timestamp + CertsPage: type: object properties: @@ -159,31 +371,12 @@ components: limit: type: integer description: Maximum number of items to return in one page. - SerialsPage: + RemainingCount: type: object properties: - serials: - type: array - description: Certificate serials IDs. - minItems: 0 - uniqueItems: true - items: - type: string - total: - type: integer - description: Total number of items. - offset: - type: integer - description: Number of items to skip during retrieval. - limit: - type: integer - description: Maximum number of items to return in one page. - Revoke: - type: object - properties: - revocation_time: + remaining: type: string - description: Certificate revocation time + description: remaining certificate left after the operation. requestBodies: CertReq: @@ -196,17 +389,30 @@ components: schema: type: object required: + - name - thing_id - ttl properties: + name: + type: string + example: thing_1_cert_001 + description: Name of the certificate, A thing can have multiple certificate thing_id: type: string format: uuid + description: ID of the thing for which certificate is required ttl: type: string example: "10h" + description: Certificate validity responses: + Error: + description: Invalid request + content: + application/json: + schema: + $ref: "#/components/schemas/Err" ServiceError: description: Unexpected server-side error occurred. CertRes: @@ -221,18 +427,6 @@ components: application/json: schema: $ref: "#/components/schemas/CertsPage" - SerialsPageRes: - description: Serials page. - content: - application/json: - schema: - $ref: "#/components/schemas/SerialsPage" - RevokeRes: - description: Certificate revoked. - content: - application/json: - schema: - $ref: "#/components/schemas/Revoke" HealthRes: description: Service Health Check. content: diff --git a/bootstrap/service.go b/bootstrap/service.go index fc5675b4a6..9ebdf8c5d8 100644 --- a/bootstrap/service.go +++ b/bootstrap/service.go @@ -28,6 +28,8 @@ var ( // ErrBootstrap indicates error in getting bootstrap configuration. ErrBootstrap = errors.New("failed to read bootstrap configuration") + // ErrUpdateCert indicates error in updating the certificates + ErrUpdateCert = errors.New("failed to update cert") errAddBootstrap = errors.New("failed to add bootstrap configuration") errUpdateConnections = errors.New("failed to update connections") @@ -40,7 +42,6 @@ var ( errDisconnectThing = errors.New("failed to disconnect thing") errCheckChannels = errors.New("failed to check if channels exists") errConnectionChannels = errors.New("failed to check channels connections") - errUpdateCert = errors.New("failed to update cert") ) var _ Service = (*bootstrapService)(nil) @@ -191,7 +192,7 @@ func (bs bootstrapService) UpdateCert(ctx context.Context, token, thingID, clien return err } if err := bs.configs.UpdateCert(owner, thingID, clientCert, clientKey, caCert); err != nil { - return errors.Wrap(errUpdateCert, err) + return errors.Wrap(ErrUpdateCert, err) } return nil } diff --git a/certs/api/endpoint.go b/certs/api/endpoint.go index 9e6828f760..e487f747bf 100644 --- a/certs/api/endpoint.go +++ b/certs/api/endpoint.go @@ -16,30 +16,23 @@ func issueCert(svc certs.Service) endpoint.Endpoint { if err := req.validate(); err != nil { return nil, err } - res, err := svc.IssueCert(ctx, req.token, req.ThingID, req.TTL) + res, err := svc.IssueCert(ctx, req.token, req.ThingID, req.Name, req.TTL) if err != nil { return certsRes{}, err } - return certsRes{ - CertSerial: res.Serial, - ThingID: res.ThingID, - ClientCert: res.ClientCert, - ClientKey: res.ClientKey, - Expiration: res.Expire, - created: true, - }, nil + return CertToCertResponse(res, true), nil } } -func listSerials(svc certs.Service) endpoint.Endpoint { +func listCerts(svc certs.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { req := request.(listReq) if err := req.validate(); err != nil { return nil, err } - page, err := svc.ListSerials(ctx, req.token, req.thingID, req.offset, req.limit) + page, err := svc.ListCerts(ctx, req.token, req.certID, req.thingID, req.serial, req.name, req.certStatus, req.offset, req.limit) if err != nil { return certsPageRes{}, err } @@ -53,9 +46,7 @@ func listSerials(svc certs.Service) endpoint.Endpoint { } for _, cert := range page.Certs { - cr := certsRes{ - CertSerial: cert.Serial, - } + cr := CertToCertResponse(cert, true) res.Certs = append(res.Certs, cr) } return res, nil @@ -64,39 +55,99 @@ func listSerials(svc certs.Service) endpoint.Endpoint { func viewCert(svc certs.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(viewReq) + req := request.(viewRevokeRenewRemoveReq) if err := req.validate(); err != nil { return nil, err } - cert, err := svc.ViewCert(ctx, req.token, req.serialID) + cert, err := svc.ViewCert(ctx, req.token, req.certID) if err != nil { return certsPageRes{}, err } - certRes := certsRes{ - CertSerial: cert.Serial, - ThingID: cert.ThingID, - ClientCert: cert.ClientCert, - Expiration: cert.Expire, + return CertToCertResponse(cert, false), nil + } +} + +func revokeCert(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(viewRevokeRenewRemoveReq) + if err := req.validate(); err != nil { + return nil, err } + return emptyCertRes{}, svc.RevokeCert(ctx, req.token, req.certID) + } +} - return certRes, nil +func renewCert(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(viewRevokeRenewRemoveReq) + if err := req.validate(); err != nil { + return nil, err + } + cert, err := svc.RenewCert(ctx, req.token, req.certID) + if err != nil { + return certsPageRes{}, err + } + return CertToCertResponse(cert, false), nil } } -func revokeCert(svc certs.Service) endpoint.Endpoint { +func removeCert(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(viewRevokeRenewRemoveReq) + if err := req.validate(); err != nil { + return nil, err + } + if err := svc.RemoveCert(ctx, req.token, req.certID); err != nil { + return nil, err + } + return emptyCertRes{}, nil + + } +} + +func revokeThingCerts(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(revokeRenewRemoveThingIDReq) + if err := req.validate(); err != nil { + return nil, err + } + c, err := svc.RevokeThingCerts(ctx, req.token, req.thingID, req.limit) + if err != nil { + return nil, err + } + rc := map[string]interface{}{"remaining": c} + return rc, nil + } +} + +func renewThingCerts(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(revokeRenewRemoveThingIDReq) + if err := req.validate(); err != nil { + return nil, err + } + c, err := svc.RenewThingCerts(ctx, req.token, req.thingID, req.limit) + if err != nil { + return nil, err + } + rc := map[string]interface{}{"remaining": c} + return rc, nil + } +} + +func removeThingCerts(svc certs.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(revokeReq) + req := request.(revokeRenewRemoveThingIDReq) if err := req.validate(); err != nil { return nil, err } - res, err := svc.RevokeCert(ctx, req.token, req.certID) + c, err := svc.RemoveThingCerts(ctx, req.token, req.thingID, req.limit) if err != nil { return nil, err } - return revokeCertsRes{ - RevocationTime: res.RevocationTime, - }, nil + rc := map[string]interface{}{"remaining": c} + return rc, nil } } diff --git a/certs/api/logging.go b/certs/api/logging.go index ae7dde256b..e734f79a44 100644 --- a/certs/api/logging.go +++ b/certs/api/logging.go @@ -26,7 +26,7 @@ func NewLoggingMiddleware(svc certs.Service, logger log.Logger) certs.Service { return &loggingMiddleware{logger, svc} } -func (lm *loggingMiddleware) IssueCert(ctx context.Context, token, thingID, ttl string) (c certs.Cert, err error) { +func (lm *loggingMiddleware) IssueCert(ctx context.Context, token, thingID, name, ttl string) (c certs.Cert, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method issue_cert for token: %s and thing: %s took %s to complete", token, thingID, time.Since(begin)) if err != nil { @@ -36,12 +36,12 @@ func (lm *loggingMiddleware) IssueCert(ctx context.Context, token, thingID, ttl lm.logger.Info(fmt.Sprintf("%s without errors.", message)) }(time.Now()) - return lm.svc.IssueCert(ctx, token, thingID, ttl) + return lm.svc.IssueCert(ctx, token, thingID, name, ttl) } -func (lm *loggingMiddleware) ListCerts(ctx context.Context, token, thingID string, offset, limit uint64) (cp certs.Page, err error) { +func (lm *loggingMiddleware) ListCerts(ctx context.Context, token, certID, thingID, serial, name string, status certs.Status, offset, limit uint64) (cp certs.Page, err error) { defer func(begin time.Time) { - message := fmt.Sprintf("Method list_certs for token: %s and thing id: %s took %s to complete", token, thingID, time.Since(begin)) + message := fmt.Sprintf("Method list_certs for token: %s, cert ID: %s thing id: %s serial: %s name: %s took %s to complete", token, certID, thingID, serial, name, time.Since(begin)) if err != nil { lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) return @@ -49,12 +49,12 @@ func (lm *loggingMiddleware) ListCerts(ctx context.Context, token, thingID strin lm.logger.Info(fmt.Sprintf("%s without errors.", message)) }(time.Now()) - return lm.svc.ListCerts(ctx, token, thingID, offset, limit) + return lm.svc.ListCerts(ctx, token, certID, thingID, serial, name, status, offset, limit) } -func (lm *loggingMiddleware) ListSerials(ctx context.Context, token, thingID string, offset, limit uint64) (cp certs.Page, err error) { +func (lm *loggingMiddleware) ViewCert(ctx context.Context, token, certID string) (c certs.Cert, err error) { defer func(begin time.Time) { - message := fmt.Sprintf("Method list_serials for token: %s and thing id: %s took %s to complete", token, thingID, time.Since(begin)) + message := fmt.Sprintf("Method view_cert for token: %s and certificate id: %s took %s to complete", token, certID, time.Since(begin)) if err != nil { lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) return @@ -62,12 +62,12 @@ func (lm *loggingMiddleware) ListSerials(ctx context.Context, token, thingID str lm.logger.Info(fmt.Sprintf("%s without errors.", message)) }(time.Now()) - return lm.svc.ListSerials(ctx, token, thingID, offset, limit) + return lm.svc.ViewCert(ctx, token, certID) } -func (lm *loggingMiddleware) ViewCert(ctx context.Context, token, serialID string) (c certs.Cert, err error) { +func (lm *loggingMiddleware) RevokeCert(ctx context.Context, token, certID string) (err error) { defer func(begin time.Time) { - message := fmt.Sprintf("Method view_cert for token: %s and serial id %s took %s to complete", token, serialID, time.Since(begin)) + message := fmt.Sprintf("Method revoke_cert for token: %s and certificate id: %s took %s to complete", token, certID, time.Since(begin)) if err != nil { lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) return @@ -75,12 +75,25 @@ func (lm *loggingMiddleware) ViewCert(ctx context.Context, token, serialID strin lm.logger.Info(fmt.Sprintf("%s without errors.", message)) }(time.Now()) - return lm.svc.ViewCert(ctx, token, serialID) + return lm.svc.RevokeCert(ctx, token, certID) } -func (lm *loggingMiddleware) RevokeCert(ctx context.Context, token, thingID string) (c certs.Revoke, err error) { +func (lm *loggingMiddleware) RenewCert(ctx context.Context, token, certID string) (c certs.Cert, err error) { defer func(begin time.Time) { - message := fmt.Sprintf("Method revoke_cert for token: %s and thing: %s took %s to complete", token, thingID, time.Since(begin)) + message := fmt.Sprintf("Method renew_certs for token: %s and certificate id: %s took %s to complete", token, certID, time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(fmt.Sprintf("%s without errors.", message)) + }(time.Now()) + + return lm.svc.RenewCert(ctx, token, certID) +} + +func (lm *loggingMiddleware) RemoveCert(ctx context.Context, token, certID string) (err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method renew_certs for token: %s and certificate id: %s took %s to complete", token, certID, time.Since(begin)) if err != nil { lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) return @@ -88,5 +101,44 @@ func (lm *loggingMiddleware) RevokeCert(ctx context.Context, token, thingID stri lm.logger.Info(fmt.Sprintf("%s without errors.", message)) }(time.Now()) - return lm.svc.RevokeCert(ctx, token, thingID) + return lm.svc.RemoveCert(ctx, token, certID) +} + +func (lm *loggingMiddleware) RevokeThingCerts(ctx context.Context, token, thingID string, limit int64) (c uint64, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method revoke_cert for token: %s and thing: %s took %s to complete", token, thingID, time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(fmt.Sprintf("%s without errors. %d remaining certificates to revoke ", message, c)) + }(time.Now()) + + return lm.svc.RevokeThingCerts(ctx, token, thingID, limit) +} + +func (lm *loggingMiddleware) RenewThingCerts(ctx context.Context, token, thingID string, limit int64) (c uint64, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method renew_certs token: %s and thing: %s took %s to complete", token, thingID, time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(fmt.Sprintf("%s without errors. %d remaining certificates to renew ", message, c)) + }(time.Now()) + + return lm.svc.RenewThingCerts(ctx, token, thingID, limit) +} + +func (lm *loggingMiddleware) RemoveThingCerts(ctx context.Context, token, thingID string, limit int64) (c uint64, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method remove_certs for token: %s and thing: %s took %s to complete", token, thingID, time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(fmt.Sprintf("%s without errors. %d remaining certificates to remove ", message, c)) + }(time.Now()) + + return lm.svc.RemoveThingCerts(ctx, token, thingID, limit) } diff --git a/certs/api/metrics.go b/certs/api/metrics.go index 266fc0e0ab..68679a6f1a 100644 --- a/certs/api/metrics.go +++ b/certs/api/metrics.go @@ -30,31 +30,22 @@ func MetricsMiddleware(svc certs.Service, counter metrics.Counter, latency metri } } -func (ms *metricsMiddleware) IssueCert(ctx context.Context, token, thingID, ttl string) (certs.Cert, error) { +func (ms *metricsMiddleware) IssueCert(ctx context.Context, token, thingID, name, ttl string) (certs.Cert, error) { defer func(begin time.Time) { ms.counter.With("method", "issue_cert").Add(1) ms.latency.With("method", "issue_cert").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.IssueCert(ctx, token, thingID, ttl) + return ms.svc.IssueCert(ctx, token, thingID, name, ttl) } -func (ms *metricsMiddleware) ListCerts(ctx context.Context, token, thingID string, offset, limit uint64) (certs.Page, error) { +func (ms *metricsMiddleware) ListCerts(ctx context.Context, token, certID, thingID, serial, name string, status certs.Status, offset, limit uint64) (certs.Page, error) { defer func(begin time.Time) { ms.counter.With("method", "list_certs").Add(1) ms.latency.With("method", "list_certs").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.ListCerts(ctx, token, thingID, offset, limit) -} - -func (ms *metricsMiddleware) ListSerials(ctx context.Context, token, thingID string, offset, limit uint64) (certs.Page, error) { - defer func(begin time.Time) { - ms.counter.With("method", "list_serials").Add(1) - ms.latency.With("method", "list_serials").Observe(time.Since(begin).Seconds()) - }(time.Now()) - - return ms.svc.ListSerials(ctx, token, thingID, offset, limit) + return ms.svc.ListCerts(ctx, token, certID, thingID, serial, name, status, offset, limit) } func (ms *metricsMiddleware) ViewCert(ctx context.Context, token, serialID string) (certs.Cert, error) { @@ -66,11 +57,56 @@ func (ms *metricsMiddleware) ViewCert(ctx context.Context, token, serialID strin return ms.svc.ViewCert(ctx, token, serialID) } -func (ms *metricsMiddleware) RevokeCert(ctx context.Context, token, thingID string) (certs.Revoke, error) { +func (ms *metricsMiddleware) RevokeCert(ctx context.Context, token, certID string) error { defer func(begin time.Time) { ms.counter.With("method", "revoke_cert").Add(1) ms.latency.With("method", "revoke_cert").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.RevokeCert(ctx, token, thingID) + return ms.svc.RevokeCert(ctx, token, certID) +} + +func (ms *metricsMiddleware) RenewCert(ctx context.Context, token, certID string) (certs.Cert, error) { + defer func(begin time.Time) { + ms.counter.With("method", "renew_cert").Add(1) + ms.latency.With("method", "renew_cert").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.RenewCert(ctx, token, certID) +} + +func (ms *metricsMiddleware) RemoveCert(ctx context.Context, token, certID string) error { + defer func(begin time.Time) { + ms.counter.With("method", "remove_cert").Add(1) + ms.latency.With("method", "remove_cert").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.RemoveCert(ctx, token, certID) +} + +func (ms *metricsMiddleware) RevokeThingCerts(ctx context.Context, token, thingID string, limit int64) (uint64, error) { + defer func(begin time.Time) { + ms.counter.With("method", "revoke_thing_cert").Add(1) + ms.latency.With("method", "revoke_thing_cert").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.RevokeThingCerts(ctx, token, thingID, limit) +} + +func (ms *metricsMiddleware) RenewThingCerts(ctx context.Context, token, thingID string, limit int64) (uint64, error) { + defer func(begin time.Time) { + ms.counter.With("method", "renew_cert").Add(1) + ms.latency.With("method", "renew_cert").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.RenewThingCerts(ctx, token, thingID, limit) +} + +func (ms *metricsMiddleware) RemoveThingCerts(ctx context.Context, token, thingID string, limit int64) (uint64, error) { + defer func(begin time.Time) { + ms.counter.With("method", "remove_cert").Add(1) + ms.latency.With("method", "remove_cert").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.RemoveThingCerts(ctx, token, thingID, limit) } diff --git a/certs/api/requests.go b/certs/api/requests.go index cfcd9dd8d7..8575cd039e 100644 --- a/certs/api/requests.go +++ b/certs/api/requests.go @@ -6,6 +6,7 @@ package api import ( "time" + "github.com/mainflux/mainflux/certs" "github.com/mainflux/mainflux/internal/apiutil" ) @@ -13,6 +14,7 @@ const maxLimitSize = 100 type addCertsReq struct { token string + Name string `json:"name"` ThingID string `json:"thing_id"` TTL string `json:"ttl"` } @@ -23,7 +25,7 @@ func (req addCertsReq) validate() error { } if req.ThingID == "" { - return apiutil.ErrMissingID + return apiutil.ErrMissingThingID } if req.TTL == "" { @@ -38,10 +40,15 @@ func (req addCertsReq) validate() error { } type listReq struct { - thingID string - token string - offset uint64 - limit uint64 + certID string + thingID string + serial string + name string + status string + token string + offset uint64 + limit uint64 + certStatus certs.Status } func (req *listReq) validate() error { @@ -51,37 +58,42 @@ func (req *listReq) validate() error { if req.limit > maxLimitSize { return apiutil.ErrLimitSize } + cs, ok := certs.StringToStatus[req.status] + if !ok { + return apiutil.ErrInvalidCertData + } + req.certStatus = cs return nil } -type viewReq struct { - serialID string - token string +type viewRevokeRenewRemoveReq struct { + certID string + token string } -func (req *viewReq) validate() error { +func (req *viewRevokeRenewRemoveReq) validate() error { if req.token == "" { return apiutil.ErrBearerToken } - if req.serialID == "" { + if req.certID == "" { return apiutil.ErrMissingID } return nil } -type revokeReq struct { - token string - certID string +type revokeRenewRemoveThingIDReq struct { + thingID string + token string + limit int64 } -func (req *revokeReq) validate() error { +func (req *revokeRenewRemoveThingIDReq) validate() error { if req.token == "" { return apiutil.ErrBearerToken } - - if req.certID == "" { - return apiutil.ErrMissingID + if req.thingID == "" { + return apiutil.ErrMissingThingID } return nil diff --git a/certs/api/responses.go b/certs/api/responses.go index 213be2b9d2..a5f71af98f 100644 --- a/certs/api/responses.go +++ b/certs/api/responses.go @@ -6,12 +6,14 @@ package api import ( "net/http" "time" + + "github.com/mainflux/mainflux/certs" ) type pageRes struct { Total uint64 `json:"total"` Offset uint64 `json:"offset"` - Limit uint64 `json:"limit"` + Limit int64 `json:"limit"` } type certsPageRes struct { @@ -20,16 +22,19 @@ type certsPageRes struct { } type certsRes struct { - ThingID string `json:"thing_id"` - ClientCert string `json:"client_cert"` - ClientKey string `json:"client_key"` - CertSerial string `json:"cert_serial"` - Expiration time.Time `json:"expiration"` - created bool -} - -type revokeCertsRes struct { - RevocationTime time.Time `json:"revocation_time"` + ID string `json:"id"` + Name string `json:"name"` + OwnerID string `json:"owner_id"` + ThingID string `json:"thing_id"` + Serial string `json:"serial"` + Certificate string `json:"certificate"` + PrivateKey string `json:"private_key"` + CAChain string `json:"ca_chain"` + IssuingCA string `json:"issuing_ca"` + TTL string `json:"ttl"` + Expire time.Time `json:"expire"` + Revocation string `json:"revocation"` + created bool } func (res certsPageRes) Code() int { @@ -60,14 +65,38 @@ func (res certsRes) Empty() bool { return false } -func (res revokeCertsRes) Code() int { +func CertToCertResponse(cert certs.Cert, created bool) certsRes { + rev := "" + if !cert.Revocation.IsZero() { + rev = cert.Revocation.Format(time.RFC3339) + } + return certsRes{ + ID: cert.ID, + Name: cert.Name, + OwnerID: cert.OwnerID, + ThingID: cert.ThingID, + Serial: cert.Serial, + Certificate: cert.Certificate, + PrivateKey: cert.PrivateKey, + CAChain: cert.CAChain, + IssuingCA: cert.IssuingCA, + TTL: cert.TTL, + Expire: cert.Expire, + Revocation: rev, + created: created, + } +} + +type emptyCertRes struct{} + +func (res emptyCertRes) Code() int { return http.StatusOK } -func (res revokeCertsRes) Headers() map[string]string { +func (res emptyCertRes) Headers() map[string]string { return map[string]string{} } -func (res revokeCertsRes) Empty() bool { - return false +func (res emptyCertRes) Empty() bool { + return true } diff --git a/certs/api/transport.go b/certs/api/transport.go index 82b5002795..7e2ad1dfa4 100644 --- a/certs/api/transport.go +++ b/certs/api/transport.go @@ -22,8 +22,15 @@ const ( contentType = "application/json" offsetKey = "offset" limitKey = "limit" - defOffset = 0 - defLimit = 10 + certKey = "cert_id" + thingKey = "thing_id" + nameKey = "name" + serialKey = "serial" + statusKey = "status" + + defStatus = "all" + defOffset = 0 + defLimit = 10 ) // MakeHandler returns a HTTP handler for API endpoints. @@ -41,27 +48,62 @@ func MakeHandler(svc certs.Service, logger logger.Logger) http.Handler { opts..., )) - r.Get("/certs/:certId", kithttp.NewServer( + r.Get("/certs/:certID", kithttp.NewServer( viewCert(svc), - decodeViewCert, + decodeViewRevokeRenewRemoveCerts, encodeResponse, opts..., )) - r.Delete("/certs/:certId", kithttp.NewServer( + r.Post("/certs/:certID/revoke", kithttp.NewServer( revokeCert(svc), - decodeRevokeCerts, + decodeViewRevokeRenewRemoveCerts, + encodeResponse, + opts..., + )) + + r.Post("/certs/:certID/renew", kithttp.NewServer( + renewCert(svc), + decodeViewRevokeRenewRemoveCerts, encodeResponse, opts..., )) - r.Get("/serials/:thingId", kithttp.NewServer( - listSerials(svc), + r.Delete("/certs/:certID", kithttp.NewServer( + removeCert(svc), + decodeViewRevokeRenewRemoveCerts, + encodeResponse, + opts..., + )) + + r.Get("/certs", kithttp.NewServer( + listCerts(svc), decodeListCerts, encodeResponse, opts..., )) + r.Post("/things/:thingID/revoke", kithttp.NewServer( + revokeThingCerts(svc), + decodeRevokeRenewRemoveThing, + encodeResponse, + opts..., + )) + + r.Post("/things/:thingID/renew", kithttp.NewServer( + renewThingCerts(svc), + decodeRevokeRenewRemoveThing, + encodeResponse, + opts..., + )) + + r.Delete("/things/:thingID", kithttp.NewServer( + removeThingCerts(svc), + decodeRevokeRenewRemoveThing, + encodeResponse, + opts..., + )) + r.Handle("/metrics", promhttp.Handler()) r.GetFunc("/health", mainflux.Health("certs")) @@ -96,24 +138,44 @@ func decodeListCerts(_ context.Context, r *http.Request) (interface{}, error) { return nil, err } + certID, err := apiutil.ReadStringQuery(r, certKey, "") + if err != nil { + return nil, err + } + + thingID, err := apiutil.ReadStringQuery(r, thingKey, "") + if err != nil { + return nil, err + } + + serial, err := apiutil.ReadStringQuery(r, serialKey, "") + if err != nil { + return nil, err + } + + name, err := apiutil.ReadStringQuery(r, nameKey, "") + if err != nil { + return nil, err + } + + status, err := apiutil.ReadStringQuery(r, statusKey, defStatus) + if err != nil { + return nil, err + } + req := listReq{ token: apiutil.ExtractBearerToken(r), - thingID: bone.GetValue(r, "thingId"), + certID: certID, + thingID: thingID, + serial: serial, + status: status, + name: name, limit: l, offset: o, } return req, nil } -func decodeViewCert(_ context.Context, r *http.Request) (interface{}, error) { - req := viewReq{ - token: apiutil.ExtractBearerToken(r), - serialID: bone.GetValue(r, "certId"), - } - - return req, nil -} - func decodeCerts(_ context.Context, r *http.Request) (interface{}, error) { if r.Header.Get("Content-Type") != contentType { return nil, errors.ErrUnsupportedContentType @@ -127,10 +189,25 @@ func decodeCerts(_ context.Context, r *http.Request) (interface{}, error) { return req, nil } -func decodeRevokeCerts(_ context.Context, r *http.Request) (interface{}, error) { - req := revokeReq{ +func decodeViewRevokeRenewRemoveCerts(_ context.Context, r *http.Request) (interface{}, error) { + req := viewRevokeRenewRemoveReq{ token: apiutil.ExtractBearerToken(r), - certID: bone.GetValue(r, "certId"), + certID: bone.GetValue(r, "certID"), + } + + return req, nil +} + +func decodeRevokeRenewRemoveThing(_ context.Context, r *http.Request) (interface{}, error) { + l, err := apiutil.ReadIntQuery(r, limitKey, defLimit) + if err != nil { + return nil, err + } + + req := revokeRenewRemoveThingIDReq{ + token: apiutil.ExtractBearerToken(r), + thingID: bone.GetValue(r, "thingID"), + limit: l, } return req, nil @@ -138,24 +215,52 @@ func decodeRevokeCerts(_ context.Context, r *http.Request) (interface{}, error) func encodeError(_ context.Context, err error, w http.ResponseWriter) { switch { - case errors.Contains(err, errors.ErrAuthentication), - err == apiutil.ErrBearerToken: + + case err == apiutil.ErrBearerToken: w.WriteHeader(http.StatusUnauthorized) - case errors.Contains(err, errors.ErrUnsupportedContentType): - w.WriteHeader(http.StatusUnsupportedMediaType) - case errors.Contains(err, errors.ErrMalformedEntity), - err == apiutil.ErrMissingID, + + case err == apiutil.ErrMissingID, err == apiutil.ErrMissingCertData, - err == apiutil.ErrInvalidCertData, + err == apiutil.ErrLimitSize, + err == apiutil.ErrOffsetSize, err == apiutil.ErrLimitSize: w.WriteHeader(http.StatusBadRequest) + + case errors.Contains(err, errors.ErrNotFound): + w.WriteHeader(http.StatusNotFound) + err = errors.ErrNotFound + + case errors.Contains(err, errors.ErrAuthentication): + w.WriteHeader(http.StatusUnauthorized) + err = errors.ErrAuthentication + + case errors.Contains(err, errors.ErrUnsupportedContentType): + w.WriteHeader(http.StatusUnsupportedMediaType) + err = errors.ErrUnsupportedContentType + + case errors.Contains(err, errors.ErrMalformedEntity): + w.WriteHeader(http.StatusBadRequest) + err = errors.ErrMalformedEntity + case errors.Contains(err, errors.ErrConflict): w.WriteHeader(http.StatusConflict) + err = errors.ErrConflict + + case errors.Contains(err, errors.ErrCreateEntity): + w.WriteHeader(http.StatusInternalServerError) + err = errors.ErrCreateEntity + + case errors.Contains(err, errors.ErrViewEntity): + w.WriteHeader(http.StatusInternalServerError) + err = errors.ErrViewEntity + + case errors.Contains(err, errors.ErrUpdateEntity): + w.WriteHeader(http.StatusInternalServerError) + err = errors.ErrUpdateEntity - case errors.Contains(err, errors.ErrCreateEntity), - errors.Contains(err, errors.ErrViewEntity), - errors.Contains(err, errors.ErrRemoveEntity): + case errors.Contains(err, errors.ErrRemoveEntity): w.WriteHeader(http.StatusInternalServerError) + err = errors.ErrRemoveEntity default: w.WriteHeader(http.StatusInternalServerError) diff --git a/certs/certs.go b/certs/certs.go index 661a09b08c..bd3ea560b1 100644 --- a/certs/certs.go +++ b/certs/certs.go @@ -3,30 +3,57 @@ package certs -import "context" +import ( + "context" +) // ConfigsPage contains page related metadata as well as list type Page struct { Total uint64 Offset uint64 - Limit uint64 + Limit int64 Certs []Cert } +type Status int + +const ( + AllCerts Status = iota + ActiveCerts + RevokedCerts +) + +var StringToStatus = map[string]Status{ + "all": AllCerts, + "active": ActiveCerts, + "revoke": RevokedCerts, +} + // Repository specifies a Config persistence API. type Repository interface { // Save saves cert for thing into database - Save(ctx context.Context, cert Cert) (string, error) + Save(ctx context.Context, cert Cert) error + + // Retrieve issued certificates for given owner ID with given any one of the following parameter + // certificate id , certificate name, thing ID and certificate serial + // If all the parameter given, all the condition are added in WHERE CLAUSE with AND condition + // Example to retrieve only certificate with ID Retrieve(ctx, ownerID, certID, "", "", "", AllCerts, 0, 1) + // Example to retrieve by Thing ID Retrieve(ctx, ownerID, "", thingID, "", "", AllCerts, 0, 10) + // Example to retrieve only certificate with serial number Retrieve(ctx, ownerID, "", "", "", serial, AllCerts, 0, 1) + Retrieve(ctx context.Context, ownerID, certID, thingID, serial, name string, status Status, offset uint64, limit int64) (Page, error) + + // RetrieveCount get count of certificate revoked if revokeCount parameter is true and also count of certificate not revoked if revokeCount parameter is false + RetrieveCount(ctx context.Context, ownerID, certID, thingID, serial, name string, status Status) (uint64, error) - // RetrieveAll retrieve issued certificates for given owner ID - RetrieveAll(ctx context.Context, ownerID string, offset, limit uint64) (Page, error) + // Update certificate from DB for a given certificate ID + Update(ctx context.Context, ownerID string, cert Cert) error - // Remove removes certificate from DB for a given thing ID - Remove(ctx context.Context, ownerID, thingID string) error + // Remove removes certificate from DB for a given certificate ID + Remove(ctx context.Context, ownerID, certID string) error - // RetrieveByThing retrieves issued certificates for a given thing ID - RetrieveByThing(ctx context.Context, ownerID, thingID string, offset, limit uint64) (Page, error) + // RetrieveThingCerts retrieves all the certificate for the given thing ID , which doesn't required owner ID, used for thing removed event stream handler + RetrieveThingCerts(ctx context.Context, thingID string) (Page, error) - // RetrieveBySerial retrieves a certificate for a given serial ID - RetrieveBySerial(ctx context.Context, ownerID, serialID string) (Cert, error) + // RemoveThingCerts removes all the certificate for the given thing ID , which doesn't required owner ID, used for thing removed event stream handler + RemoveThingCerts(ctx context.Context, thingID string) error } diff --git a/certs/eventhandlers/things.go b/certs/eventhandlers/things.go new file mode 100644 index 0000000000..21081c6c77 --- /dev/null +++ b/certs/eventhandlers/things.go @@ -0,0 +1,63 @@ +package eventhandlers + +import ( + "context" + + "github.com/mainflux/mainflux/certs" + "github.com/mainflux/mainflux/certs/pki" + thingsEvent "github.com/mainflux/mainflux/internal/clients/events/things" + "github.com/mainflux/mainflux/pkg/errors" +) + +type things struct { + pki pki.Agent + repo certs.Repository +} + +var _ thingsEvent.EventHandler = (*things)(nil) + +func NewThingsEventHandlers(repo certs.Repository, pki pki.Agent) thingsEvent.EventHandler { + return &things{repo: repo, pki: pki} +} + +func (teh *things) ThingCreated(ctx context.Context, cte thingsEvent.CreateThingEvent) error { + return nil +} +func (teh *things) ThingUpdated(ctx context.Context, ute thingsEvent.UpdateThingEvent) error { + return nil +} +func (teh *things) ThingRemoved(ctx context.Context, rte thingsEvent.RemoveThingEvent) error { + cp, err := teh.repo.RetrieveThingCerts(ctx, rte.ID) + if err != nil { + return err + } + + // create async thing event handler with go routine and return error via channels + var retErr error + for _, cert := range cp.Certs { + _, err := teh.pki.Revoke(cert.Serial) + if err != nil { + retErr = errors.Wrap(retErr, err) + } + } + err = teh.repo.RemoveThingCerts(ctx, rte.ID) + if err != nil { + retErr = errors.Wrap(retErr, err) + } + return retErr +} +func (teh *things) ChannelCreated(ctx context.Context, cce thingsEvent.CreateChannelEvent) error { + return nil +} +func (teh *things) ChannelUpdated(ctx context.Context, uce thingsEvent.UpdateChannelEvent) error { + return nil +} +func (teh *things) ChannelRemoved(ctx context.Context, rce thingsEvent.RemoveChannelEvent) error { + return nil +} +func (teh *things) ThingConnected(ctx context.Context, cte thingsEvent.ConnectThingEvent) error { + return nil +} +func (teh *things) ThingDisconnected(ctx context.Context, dte thingsEvent.DisconnectThingEvent) error { + return nil +} diff --git a/certs/mocks/certs.go b/certs/mocks/certs.go index 01f75af0fb..908cc02be5 100644 --- a/certs/mocks/certs.go +++ b/certs/mocks/certs.go @@ -16,28 +16,23 @@ var _ certs.Repository = (*certsRepoMock)(nil) type certsRepoMock struct { mu sync.Mutex counter uint64 - certsBySerial map[string]certs.Cert + certsByID map[string]certs.Cert certsByThingID map[string]map[string][]certs.Cert } // NewCertsRepository creates in-memory certs repository. func NewCertsRepository() certs.Repository { return &certsRepoMock{ - certsBySerial: make(map[string]certs.Cert), + certsByID: make(map[string]certs.Cert), certsByThingID: make(map[string]map[string][]certs.Cert), } } -func (c *certsRepoMock) Save(ctx context.Context, cert certs.Cert) (string, error) { +func (c *certsRepoMock) Save(ctx context.Context, cert certs.Cert) error { c.mu.Lock() defer c.mu.Unlock() - crt := certs.Cert{ - OwnerID: cert.OwnerID, - ThingID: cert.ThingID, - Serial: cert.Serial, - Expire: cert.Expire, - } + crt := cert _, ok := c.certsByThingID[cert.OwnerID][cert.ThingID] switch ok { @@ -49,12 +44,12 @@ func (c *certsRepoMock) Save(ctx context.Context, cert certs.Cert) (string, erro c.certsByThingID[cert.OwnerID][cert.ThingID] = append(c.certsByThingID[cert.OwnerID][cert.ThingID], crt) } - c.certsBySerial[cert.Serial] = crt + c.certsByID[cert.ID] = crt c.counter++ - return cert.Serial, nil + return nil } -func (c *certsRepoMock) RetrieveAll(ctx context.Context, ownerID string, offset, limit uint64) (certs.Page, error) { +func (c *certsRepoMock) Retrieve(ctx context.Context, ownerID, certID, name, thingID, serial string, status certs.Status, offset uint64, limit int64) (certs.Page, error) { c.mu.Lock() defer c.mu.Unlock() if limit <= 0 { @@ -69,69 +64,58 @@ func (c *certsRepoMock) RetrieveAll(ctx context.Context, ownerID string, offset, var crts []certs.Cert for _, tc := range oc { for i, v := range tc { - if uint64(i) >= offset && uint64(i) < offset+limit { + + switch limit >= 0 { + case true: + if uint64(i) >= offset && uint64(i) < offset+uint64(limit) { + crts = append(crts, v) + } + default: crts = append(crts, v) + } + } } + total, err := c.RetrieveCount(ctx, ownerID, certID, name, thingID, serial, status) + if err != nil { + return certs.Page{}, err + } + page := certs.Page{ Certs: crts, - Total: c.counter, + Total: total, Offset: offset, Limit: limit, } return page, nil } -func (c *certsRepoMock) Remove(ctx context.Context, ownerID, serial string) error { +func (c *certsRepoMock) RetrieveCount(ctx context.Context, ownerID, certID, name, thingID, serial string, status certs.Status) (uint64, error) { + return c.counter, nil +} + +func (c *certsRepoMock) Remove(ctx context.Context, ownerID, certID string) error { c.mu.Lock() defer c.mu.Unlock() - crt, ok := c.certsBySerial[serial] + crt, ok := c.certsByID[certID] if !ok { return errors.ErrNotFound } - delete(c.certsBySerial, crt.Serial) + delete(c.certsByID, crt.ID) delete(c.certsByThingID, crt.ThingID) return nil } -func (c *certsRepoMock) RetrieveByThing(ctx context.Context, ownerID, thingID string, offset, limit uint64) (certs.Page, error) { - c.mu.Lock() - defer c.mu.Unlock() - if limit <= 0 { - return certs.Page{}, nil - } - - cs, ok := c.certsByThingID[ownerID][thingID] - if !ok { - return certs.Page{}, errors.ErrNotFound - } - - var crts []certs.Cert - for i, v := range cs { - if uint64(i) >= offset && uint64(i) < offset+limit { - crts = append(crts, v) - } - } - - page := certs.Page{ - Certs: crts, - Total: c.counter, - Offset: offset, - Limit: limit, - } - return page, nil +func (c *certsRepoMock) Update(ctx context.Context, oldSerial string, cert certs.Cert) error { + return nil } -func (c *certsRepoMock) RetrieveBySerial(ctx context.Context, ownerID, serialID string) (certs.Cert, error) { - c.mu.Lock() - defer c.mu.Unlock() - - crt, ok := c.certsBySerial[serialID] - if !ok { - return certs.Cert{}, errors.ErrNotFound - } +func (c *certsRepoMock) RetrieveThingCerts(ctx context.Context, thingID string) (certs.Page, error) { + return certs.Page{}, nil +} - return crt, nil +func (c *certsRepoMock) RemoveThingCerts(ctx context.Context, thingID string) error { + return nil } diff --git a/certs/mocks/pki.go b/certs/mocks/pki.go index f59ca327fe..f4e0c99f52 100644 --- a/certs/mocks/pki.go +++ b/certs/mocks/pki.go @@ -21,8 +21,6 @@ import ( "github.com/mainflux/mainflux/pkg/errors" ) -const keyBits = 2048 - var ( errPrivateKeyEmpty = errors.New("private key is empty") errPrivateKeyUnsupportedType = errors.New("private key type is unsupported") @@ -31,26 +29,26 @@ var ( var _ pki.Agent = (*agent)(nil) type agent struct { - AuthTimeout time.Duration - TLSCert tls.Certificate - X509Cert *x509.Certificate - TTL string - mu sync.Mutex - counter uint64 - certs map[string]pki.Cert + TLSCert tls.Certificate + X509Cert *x509.Certificate + RSABits int + TTL string + mu sync.Mutex + counter uint64 + certs map[string]pki.Cert } -func NewPkiAgent(tlsCert tls.Certificate, caCert *x509.Certificate, ttl string, timeout time.Duration) pki.Agent { +func NewPkiAgent(tlsCert tls.Certificate, caCert *x509.Certificate, keyBits int, ttl string) pki.Agent { return &agent{ - AuthTimeout: timeout, - TLSCert: tlsCert, - X509Cert: caCert, - TTL: ttl, - certs: make(map[string]pki.Cert), + TLSCert: tlsCert, + X509Cert: caCert, + RSABits: keyBits, + TTL: ttl, + certs: make(map[string]pki.Cert), } } -func (a *agent) IssueCert(cn, ttl string) (pki.Cert, error) { +func (a *agent) IssueCert(cn string, ttl string) (pki.Cert, error) { a.mu.Lock() defer a.mu.Unlock() @@ -59,7 +57,7 @@ func (a *agent) IssueCert(cn, ttl string) (pki.Cert, error) { } var priv interface{} - priv, err := rsa.GenerateKey(rand.Reader, keyBits) + priv, err := rsa.GenerateKey(rand.Reader, a.RSABits) if err != nil { return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err) } @@ -131,16 +129,16 @@ func (a *agent) IssueCert(cn, ttl string) (pki.Cert, error) { key := keyOut.String() a.certs[x509cert.SerialNumber.String()] = pki.Cert{ - ClientCert: cert, + Certificate: cert, } a.counter++ return pki.Cert{ - ClientCert: cert, - ClientKey: key, - Serial: x509cert.SerialNumber.String(), - Expire: x509cert.NotAfter.Unix(), - IssuingCA: x509cert.Issuer.String(), + Certificate: cert, + PrivateKey: key, + Serial: x509cert.SerialNumber.String(), + Expire: x509cert.NotAfter, + IssuingCA: x509cert.Issuer.String(), }, nil } diff --git a/certs/pki/vault.go b/certs/pki/vault.go index 9487a5f2bb..e62b415a85 100644 --- a/certs/pki/vault.go +++ b/certs/pki/vault.go @@ -5,7 +5,11 @@ package pki import ( + "crypto/x509" "encoding/json" + "encoding/pem" + "fmt" + "regexp" "time" "github.com/hashicorp/vault/api" @@ -29,23 +33,26 @@ var ( // ErrFailedCertRevocation indicates failed certificate revocation ErrFailedCertRevocation = errors.New("failed to revoke certificate") - errFailedCertDecoding = errors.New("failed to decode response from vault service") + errFailedVaultCertIssue = errors.New("failed to issue vault certificate") + errFailedVaultRead = errors.New("failed to read vault certificate") + errFailedCertDecoding = errors.New("failed to decode response from vault service") + expSerialNotFound = regexp.MustCompile(`Errors:\s*(.*?)\s*certificate with serial\s*(.*?)\s*not found`) ) type Cert struct { - ClientCert string `json:"client_cert" mapstructure:"certificate"` - IssuingCA string `json:"issuing_ca" mapstructure:"issuing_ca"` - CAChain []string `json:"ca_chain" mapstructure:"ca_chain"` - ClientKey string `json:"client_key" mapstructure:"private_key"` - PrivateKeyType string `json:"private_key_type" mapstructure:"private_key_type"` - Serial string `json:"serial" mapstructure:"serial_number"` - Expire int64 `json:"expire" mapstructure:"expiration"` + Certificate string `json:"certificate" mapstructure:"certificate"` + IssuingCA string `json:"issuing_ca" mapstructure:"issuing_ca"` + CAChain []string `json:"ca_chain" mapstructure:"ca_chain"` + PrivateKey string `json:"private_key" mapstructure:"private_key"` + PrivateKeyType string `json:"private_key_type" mapstructure:"private_key_type"` + Serial string `json:"serial" mapstructure:"serial_number"` + Expire time.Time `json:"expire" mapstructure:"-"` } // Agent represents the Vault PKI interface. type Agent interface { // IssueCert issues certificate on PKI - IssueCert(cn, ttl string) (Cert, error) + IssueCert(cn string, ttl string) (Cert, error) // Read retrieves certificate from PKI Read(serial string) (Cert, error) @@ -97,7 +104,7 @@ func NewVaultClient(token, host, path, role string) (Agent, error) { return &p, nil } -func (p *pkiAgent) IssueCert(cn, ttl string) (Cert, error) { +func (p *pkiAgent) IssueCert(cn string, ttl string) (Cert, error) { cReq := certReq{ CommonName: cn, TTL: ttl, @@ -114,13 +121,18 @@ func (p *pkiAgent) IssueCert(cn, ttl string) (Cert, error) { s, err := p.client.Logical().Write(p.issueURL, certIssueReq) if err != nil { - return Cert{}, err + return Cert{}, errors.Wrap(errFailedVaultCertIssue, err) } cert := Cert{} if err = mapstructure.Decode(s.Data, &cert); err != nil { return Cert{}, errors.Wrap(errFailedCertDecoding, err) } + pubCert, err := p.parseCert(cert.Certificate) + if err != nil { + return Cert{}, errors.Wrap(errFailedCertDecoding, err) + } + cert.Expire = pubCert.NotAfter return cert, nil } @@ -128,7 +140,7 @@ func (p *pkiAgent) IssueCert(cn, ttl string) (Cert, error) { func (p *pkiAgent) Read(serial string) (Cert, error) { s, err := p.client.Logical().Read(p.readURL + serial) if err != nil { - return Cert{}, err + return Cert{}, errors.Wrap(errFailedVaultRead, err) } cert := Cert{} if err = mapstructure.Decode(s.Data, &cert); err != nil { @@ -145,6 +157,9 @@ func (p *pkiAgent) Revoke(serial string) (time.Time, error) { var certRevokeReq map[string]interface{} data, err := json.Marshal(cReq) if err != nil { + if expSerialNotFound.Match([]byte(err.Error())) { + return time.Time{}, errors.ErrNotFound + } return time.Time{}, err } if err := json.Unmarshal(data, &certRevokeReq); err != nil { @@ -162,3 +177,11 @@ func (p *pkiAgent) Revoke(serial string) (time.Time, error) { return time.Unix(0, int64(rev)*int64(time.Second)), nil } + +func (c *pkiAgent) parseCert(data string) (*x509.Certificate, error) { + block, _ := pem.Decode([]byte(data)) + if block == nil { + return nil, fmt.Errorf("failed to decode client certificate") + } + return x509.ParseCertificate(block.Bytes) +} diff --git a/certs/postgres/certs.go b/certs/postgres/certs.go index d91bfd6532..b164a7c448 100644 --- a/certs/postgres/certs.go +++ b/certs/postgres/certs.go @@ -7,17 +7,18 @@ import ( "context" "database/sql" "fmt" + "strings" "time" - "github.com/jackc/pgerrcode" - "github.com/jackc/pgx/v5/pgconn" - - "github.com/jmoiron/sqlx" + _ "github.com/jackc/pgx/v5/stdlib" // required for SQL access "github.com/mainflux/mainflux/certs" - "github.com/mainflux/mainflux/logger" + pgClient "github.com/mainflux/mainflux/internal/clients/postgres" + "github.com/mainflux/mainflux/internal/sqlxt" "github.com/mainflux/mainflux/pkg/errors" ) +var errInvalidRevocationTime = errors.New("invalid revocation time") + var _ certs.Repository = (*certsRepository)(nil) // Cert holds info on expiration date for specific cert issued for specific Thing. @@ -28,40 +29,117 @@ type Cert struct { } type certsRepository struct { - db *sqlx.DB - log logger.Logger + db sqlxt.Database } // NewRepository instantiates a PostgreSQL implementation of certs // repository. -func NewRepository(db *sqlx.DB, log logger.Logger) certs.Repository { - return &certsRepository{db: db, log: log} +func NewRepository(db sqlxt.Database) certs.Repository { + return &certsRepository{ + db: db, + } } -func (cr certsRepository) RetrieveAll(ctx context.Context, ownerID string, offset, limit uint64) (certs.Page, error) { - q := `SELECT thing_id, owner_id, serial, expire FROM certs WHERE owner_id = $1 ORDER BY expire LIMIT $2 OFFSET $3;` - rows, err := cr.db.Query(q, ownerID, limit, offset) +func (cr certsRepository) Save(ctx context.Context, cert certs.Cert) error { + + q := `INSERT INTO certs + (id, name, owner_id, thing_id, serial, private_key, certificate, ca_chain, issuing_ca, ttl, expire) + VALUES + (:id, :name, :owner_id, :thing_id, :serial, :private_key, :certificate, :ca_chain, :issuing_ca, :ttl, :expire) + ` + dbc, err := CertToDbCert(cert) if err != nil { - cr.log.Error(fmt.Sprintf("Failed to retrieve configs due to %s", err)) - return certs.Page{}, err + return err + } + if _, err, txErr := cr.db.NamedCUDContext(ctx, q, dbc); err != nil || txErr != nil { + err = pgClient.CheckError(err, pgClient.Create) + return errors.Wrap(err, txErr) + } + return nil +} + +func (cr certsRepository) Update(ctx context.Context, certID string, cert certs.Cert) error { + q := ` + UPDATE + certs + SET + serial = :serial, + private_key = :private_key, + certificate = :certificate, + ca_chain = :ca_chain, + issuing_ca = :issuing_ca, + expire = :expire, + revocation = :revocation + WHERE id = :id AND owner_id = :owner_id + ` + dbc, err := CertToDbCert(cert) + if err != nil { + return err + } + if _, err, txErr := cr.db.NamedCUDContext(ctx, q, dbc); err != nil || txErr != nil { + err = pgClient.CheckError(err, pgClient.Update) + return errors.Wrap(err, txErr) + } + return nil +} + +func (cr certsRepository) Remove(ctx context.Context, ownerID, certID string) error { + q := `DELETE FROM certs WHERE id = :id` + + dbc, err := CertToDbCert(certs.Cert{ID: certID}) + if err != nil { + return err + } + if _, err, txErr := cr.db.NamedCUDContext(ctx, q, dbc); err != nil || txErr != nil { + err = pgClient.CheckError(err, pgClient.Remove) + return errors.Wrap(err, txErr) + } + return nil +} + +func (cr certsRepository) Retrieve(ctx context.Context, ownerID, certID, thingID, serial, name string, status certs.Status, offset uint64, limit int64) (certs.Page, error) { + q := ` + SELECT + id, name, owner_id, thing_id, serial, private_key, certificate, ca_chain, issuing_ca, ttl, expire, revocation + FROM + certs + WHERE owner_id = :owner_id + %s + ORDER BY expire %s; + ` + + q = fmt.Sprintf(q, whereClause(certID, thingID, serial, name, status), orderClause(limit)) + + params := map[string]interface{}{ + "limit": limit, + "offset": offset, + "owner_id": ownerID, + "id": certID, + "thing_id": thingID, + "serial": serial, + "name": name, + } + + rows, err := cr.db.NamedQueryContext(ctx, q, params) + if err != nil { + return certs.Page{}, pgClient.CheckError(err, pgClient.View) } defer rows.Close() certificates := []certs.Cert{} for rows.Next() { - c := certs.Cert{} - if err := rows.Scan(&c.ThingID, &c.OwnerID, &c.Serial, &c.Expire); err != nil { - cr.log.Error(fmt.Sprintf("Failed to read retrieved config due to %s", err)) - return certs.Page{}, err - + dbcs := dbCert{} + if err := rows.StructScan(&dbcs); err != nil { + return certs.Page{}, pgClient.CheckError(err, pgClient.View) } - certificates = append(certificates, c) + certificates = append(certificates, dbcs.ToCert()) + } + if len(certificates) < 1 { + return certs.Page{}, errors.ErrNotFound } - q = `SELECT COUNT(*) FROM certs WHERE owner_id = $1` - var total uint64 - if err := cr.db.QueryRow(q, ownerID).Scan(&total); err != nil { - cr.log.Error(fmt.Sprintf("Failed to count certs due to %s", err)) + total, err := cr.RetrieveCount(ctx, ownerID, certID, thingID, serial, name, status) + if err != nil { return certs.Page{}, err } @@ -73,131 +151,200 @@ func (cr certsRepository) RetrieveAll(ctx context.Context, ownerID string, offse }, nil } -func (cr certsRepository) Save(ctx context.Context, cert certs.Cert) (string, error) { - q := `INSERT INTO certs (thing_id, owner_id, serial, expire) VALUES (:thing_id, :owner_id, :serial, :expire)` - - tx, err := cr.db.Beginx() - if err != nil { - return "", errors.Wrap(errors.ErrCreateEntity, err) +func (cr certsRepository) RetrieveCount(ctx context.Context, ownerID, certID, thingID, serial, name string, status certs.Status) (uint64, error) { + qc := ` + SELECT + COUNT(*) + FROM + certs + WHERE owner_id = :owner_id + %s + ; + ` + params := map[string]interface{}{ + "owner_id": ownerID, + "id": certID, + "thing_id": thingID, + "serial": serial, + "name": name, } - - dbcrt := toDBCert(cert) - - if _, err := tx.NamedExec(q, dbcrt); err != nil { - e := err - if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == pgerrcode.UniqueViolation { - e = errors.New("error conflict") - } - - cr.rollback("Failed to insert a Cert", tx, err) - - return "", errors.Wrap(errors.ErrCreateEntity, e) - } - - if err := tx.Commit(); err != nil { - cr.rollback("Failed to commit Config save", tx, err) + qc = fmt.Sprintf(qc, whereClause(certID, thingID, serial, name, status)) + total, err := cr.db.NamedTotalQueryContext(ctx, qc, params) + if err != nil { + return 0, pgClient.CheckError(err, pgClient.View) } - - return cert.Serial, nil + return total, nil } -func (cr certsRepository) Remove(ctx context.Context, ownerID, serial string) error { - if _, err := cr.RetrieveBySerial(ctx, ownerID, serial); err != nil { - return errors.Wrap(errors.ErrRemoveEntity, err) - } - q := `DELETE FROM certs WHERE serial = :serial` - var c certs.Cert - c.Serial = serial - dbcrt := toDBCert(c) - if _, err := cr.db.NamedExecContext(ctx, q, dbcrt); err != nil { - return errors.Wrap(errors.ErrRemoveEntity, err) - } - return nil -} +func (cr certsRepository) RetrieveThingCerts(ctx context.Context, thingID string) (certs.Page, error) { + q := ` + SELECT + id, name, owner_id, thing_id, serial, private_key, certificate, ca_chain, issuing_ca, ttl, expire, revocation + FROM + certs + WHERE thing_id = :thing_id + ORDER BY expire; + ` -func (cr certsRepository) RetrieveByThing(ctx context.Context, ownerID, thingID string, offset, limit uint64) (certs.Page, error) { - q := `SELECT thing_id, owner_id, serial, expire FROM certs WHERE owner_id = $1 AND thing_id = $2 ORDER BY expire LIMIT $3 OFFSET $4;` - rows, err := cr.db.Query(q, ownerID, thingID, limit, offset) + params := certs.Cert{ThingID: thingID} + + rows, err := cr.db.NamedQueryContext(ctx, q, params) if err != nil { - cr.log.Error(fmt.Sprintf("Failed to retrieve configs due to %s", err)) - return certs.Page{}, err + return certs.Page{}, pgClient.CheckError(err, pgClient.View) } defer rows.Close() certificates := []certs.Cert{} for rows.Next() { - c := certs.Cert{} - if err := rows.Scan(&c.ThingID, &c.OwnerID, &c.Serial, &c.Expire); err != nil { - cr.log.Error(fmt.Sprintf("Failed to read retrieved config due to %s", err)) - return certs.Page{}, err - + dbcs := dbCert{} + if err := rows.StructScan(&dbcs); err != nil { + return certs.Page{}, pgClient.CheckError(err, pgClient.View) } - certificates = append(certificates, c) + certificates = append(certificates, dbcs.ToCert()) } - q = `SELECT COUNT(*) FROM certs WHERE owner_id = $1 AND thing_id = $2` - var total uint64 - if err := cr.db.QueryRow(q, ownerID, thingID).Scan(&total); err != nil { - cr.log.Error(fmt.Sprintf("Failed to count certs due to %s", err)) - return certs.Page{}, err + qc := ` + SELECT + COUNT(*) + FROM + certs + WHERE thing_id = :thing_id + ` + total, err := cr.db.NamedTotalQueryContext(ctx, qc, params) + if err != nil { + return certs.Page{}, pgClient.CheckError(err, pgClient.View) } return certs.Page{ Total: total, - Limit: limit, - Offset: offset, + Limit: 0, + Offset: 0, Certs: certificates, }, nil } -func (cr certsRepository) RetrieveBySerial(ctx context.Context, ownerID, serialID string) (certs.Cert, error) { - q := `SELECT thing_id, owner_id, serial, expire FROM certs WHERE owner_id = $1 AND serial = $2` - var dbcrt dbCert - var c certs.Cert +func (cr certsRepository) RemoveThingCerts(ctx context.Context, thingID string) error { + q := `DELETE FROM certs WHERE thing_id = thingID` + dbc, err := CertToDbCert(certs.Cert{ThingID: thingID}) + if err != nil { + return err + } + if _, err, txErr := cr.db.NamedCUDContext(ctx, q, dbc); err != nil || txErr != nil { + err = pgClient.CheckError(err, pgClient.Remove) + return errors.Wrap(err, txErr) + } + return nil +} - if err := cr.db.QueryRowxContext(ctx, q, ownerID, serialID).StructScan(&dbcrt); err != nil { +type dbCert struct { + ID string `db:"id"` + Name string `db:"name"` + OwnerID string `db:"owner_id"` + ThingID string `db:"thing_id"` + Serial string `db:"serial"` + Certificate string `db:"certificate"` + PrivateKey string `db:"private_key"` + CAChain string `db:"ca_chain"` + IssuingCA string `db:"issuing_ca"` + TTL string `db:"ttl"` + Expire time.Time `db:"expire"` + Revocation sql.NullTime `db:"revocation"` +} - pqErr, ok := err.(*pgconn.PgError) - if err == sql.ErrNoRows || ok && pgerrcode.InvalidTextRepresentation == pqErr.Code { - return c, errors.Wrap(errors.ErrNotFound, err) +func (c *dbCert) ToCert() certs.Cert { + var rev time.Time + if c.Revocation.Valid { + rev = c.Revocation.Time + } + return certs.Cert{ + ID: c.ID, + Name: c.Name, + OwnerID: c.OwnerID, + ThingID: c.ThingID, + Serial: c.Serial, + Certificate: c.Certificate, + PrivateKey: c.PrivateKey, + CAChain: c.CAChain, + IssuingCA: c.IssuingCA, + TTL: c.TTL, + Expire: c.Expire, + Revocation: rev, + } +} + +func CertToDbCert(c certs.Cert) (dbCert, error) { + var revokeTime sql.NullTime + if !c.Revocation.IsZero() { + if err := revokeTime.Scan(c.Revocation); err != nil { + return dbCert{}, errors.Wrap(errInvalidRevocationTime, err) } + } + fmt.Println(revokeTime) + return dbCert{ + ID: c.ID, + Name: c.Name, + OwnerID: c.OwnerID, + ThingID: c.ThingID, + Serial: c.Serial, + Certificate: c.Certificate, + PrivateKey: c.PrivateKey, + CAChain: c.CAChain, + IssuingCA: c.IssuingCA, + TTL: c.TTL, + Expire: c.Expire, + Revocation: revokeTime, + }, nil +} - return c, errors.Wrap(errors.ErrViewEntity, err) +func whereClause(certID, thingID, serial, name string, status certs.Status) string { + var clause []string + if certID != "" { + clause = append(clause, " id = :id ") } - c = toCert(dbcrt) - return c, nil -} + if thingID != "" { + clause = append(clause, " thing_id = :thing_id ") + } -func (cr certsRepository) rollback(content string, tx *sqlx.Tx, err error) { - cr.log.Error(fmt.Sprintf("%s %s", content, err)) + if serial != "" { + clause = append(clause, " serial = :serial ") + } - if err := tx.Rollback(); err != nil { - cr.log.Error(fmt.Sprintf("Failed to rollback due to %s", err)) + if name != "" { + clause = append(clause, " name = :name ") } -} -type dbCert struct { - ThingID string `db:"thing_id"` - Serial string `db:"serial"` - Expire time.Time `db:"expire"` - OwnerID string `db:"owner_id"` + if sf := statusFilter(status); sf != "" { + clause = append(clause, sf) + } + + c := strings.Join(clause, " AND ") + if c != "" { + c = " AND " + c + } + return c } -func toDBCert(c certs.Cert) dbCert { - return dbCert{ - ThingID: c.ThingID, - OwnerID: c.OwnerID, - Serial: c.Serial, - Expire: c.Expire, +func orderClause(limit int64) string { + var clause []string + if limit >= 0 { + clause = append(clause, " LIMIT :limit ") } + clause = append(clause, " OFFSET :offset ") + return strings.Join(clause, " ") } -func toCert(cdb dbCert) certs.Cert { - var c certs.Cert - c.OwnerID = cdb.OwnerID - c.ThingID = cdb.ThingID - c.Serial = cdb.Serial - c.Expire = cdb.Expire - return c +func statusFilter(status certs.Status) string { + var filterQuery string + switch status { + case certs.ActiveCerts: + filterQuery = "revocation is NULL" + case certs.RevokedCerts: + filterQuery = "revocation is NOT NULL" + case certs.AllCerts: + fallthrough + default: + filterQuery = "" + } + return filterQuery } diff --git a/certs/postgres/init.go b/certs/postgres/init.go index 04aa8c0306..1271f37881 100644 --- a/certs/postgres/init.go +++ b/certs/postgres/init.go @@ -26,6 +26,38 @@ func Migration() *migrate.MemoryMigrationSource { "DROP TABLE IF EXISTS certs;", }, }, + + { + Id: "certs_2", + Up: []string{ + ` + ALTER TABLE certs DROP CONSTRAINT certs_pkey; + ALTER TABLE certs ADD COLUMN id UUID NOT NULL; + ALTER TABLE certs ADD COLUMN name VARCHAR(254) NOT NULL; + ALTER TABLE certs ADD COLUMN certificate TEXT NOT NULL; + ALTER TABLE certs ADD COLUMN private_key TEXT NOT NULL; + ALTER TABLE certs ADD COLUMN ca_chain TEXT NOT NULL; + ALTER TABLE certs ADD COLUMN issuing_ca TEXT NOT NULL; + ALTER TABLE certs ADD COLUMN ttl VARCHAR(254) NOT NULL; + ALTER TABLE certs ADD COLUMN revocation TIMESTAMPTZ NULL; + ALTER TABLE certs ADD PRIMARY KEY (name, thing_id, owner_id); + `, + }, + Down: []string{ + ` + ALTER TABLE certs DROP CONSTRAINT certs_pkey; + ALTER TABLE certs DROP COLUMN id data_type; + ALTER TABLE certs DROP COLUMN name data_type; + ALTER TABLE certs DROP COLUMN certificate data_type; + ALTER TABLE certs DROP COLUMN private_key data_type; + ALTER TABLE certs DROP COLUMN ca_chain data_type; + ALTER TABLE certs DROP COLUMN issuing_ca data_type; + ALTER TABLE certs DROP COLUMN ttl data_type; + ALTER TABLE certs DROP COLUMN revocation data_type; + ALTER TABLE certs ADD PRIMARY KEY (thing_id, owner_id, serial); + `, + }, + }, }, } } diff --git a/certs/redis/consumer/doc.go b/certs/redis/consumer/doc.go new file mode 100644 index 0000000000..37dc99411b --- /dev/null +++ b/certs/redis/consumer/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Mainflux +// SPDX-License-Identifier: Apache-2.0 + +// Package consumer contains events consumer for events +// published by Things service. +package consumer diff --git a/certs/redis/consumer/event.go b/certs/redis/consumer/event.go new file mode 100644 index 0000000000..6b19d2111e --- /dev/null +++ b/certs/redis/consumer/event.go @@ -0,0 +1,4 @@ +// Copyright (c) Mainflux +// SPDX-License-Identifier: Apache-2.0 + +package consumer diff --git a/certs/redis/consumer/stream.go b/certs/redis/consumer/stream.go new file mode 100644 index 0000000000..28d24e52ea --- /dev/null +++ b/certs/redis/consumer/stream.go @@ -0,0 +1,80 @@ +// Copyright (c) Mainflux +// SPDX-License-Identifier: Apache-2.0 + +package consumer + +import ( + "context" + + "github.com/go-redis/redis/v8" + "github.com/mainflux/mainflux/certs" + "github.com/mainflux/mainflux/logger" +) + +const ( + stream = "mainflux.things" + group = "mainflux.certs" + + thingPrefix = "thing." + thingRemove = thingPrefix + "remove" + + exists = "BUSYGROUP Consumer Group name already exists" +) + +// Subscriber represents event source for things and channels provisioning. +type Subscriber interface { + // Subscribes to given subject and receives events. + Subscribe(context.Context, string) error +} + +type eventStore struct { + svc certs.Service + client *redis.Client + consumer string + logger logger.Logger +} + +// NewEventStore returns new event store instance. +func NewEventStore(svc certs.Service, client *redis.Client, consumer string, log logger.Logger) Subscriber { + return eventStore{ + svc: svc, + client: client, + consumer: consumer, + logger: log, + } +} + +func (es eventStore) Subscribe(ctx context.Context, subject string) error { + err := es.client.XGroupCreateMkStream(ctx, stream, group, "$").Err() + if err != nil && err.Error() != exists { + return err + } + + for { + streams, err := es.client.XReadGroup(ctx, &redis.XReadGroupArgs{ + Group: group, + Consumer: es.consumer, + Streams: []string{stream, ">"}, + Count: 100, + }).Result() + if err != nil || len(streams) == 0 { + continue + } + + for _, msg := range streams[0].Messages { + event := msg.Values + + switch event["operation"] { + case thingRemove: + // rte := decodeRemoveThing(event) + // err := make(chan error) + // go es.svc.EventHandlerDeleteThing(ctx, rte.id, err) + // for e := range err { + // es.logger.Info(fmt.Sprintf("Error on thing remove event handled , Thing ID %s error : %v", rte.id, e)) + // } + } + + es.client.XAck(ctx, stream, group, msg.ID) + } + } +} diff --git a/certs/service.go b/certs/service.go index deabf727fc..8bd494fbce 100644 --- a/certs/service.go +++ b/certs/service.go @@ -5,6 +5,9 @@ package certs import ( "context" + "crypto/x509" + "encoding/pem" + "strings" "time" "github.com/mainflux/mainflux" @@ -13,14 +16,25 @@ import ( mfsdk "github.com/mainflux/mainflux/pkg/sdk/go" ) +// Key types and format : https://developer.hashicorp.com/vault/api-docs/secret/pki#key_type +const ( + caChainJoinSep = "\n\n" +) + var ( - // ErrFailedCertCreation failed to create certificate - ErrFailedCertCreation = errors.New("failed to create client certificate") + ErrThingRetrieve = errors.New("failed to retrieve thing details") + + ErrPKIIssue = errors.New("failed to issue certificate in PKI") + + errPKIRevoke = errors.New("failed to revoke certificate in PKI") + + errRepoRetrieve = errors.New("failed to retrieve certificate from repo") + + errRepoUpdate = errors.New("failed to update certificate from repo") - // ErrFailedCertRevocation failed to revoke certificate - ErrFailedCertRevocation = errors.New("failed to revoke certificate") + errRepoRemove = errors.New("failed to remove the certificate from db") - errFailedToRemoveCertFromDB = errors.New("failed to remove cert serial from db") + errParseCert = errors.New("failed to parse the certificate, invalid certificate") ) var _ Service = (*certsService)(nil) @@ -29,35 +43,49 @@ var _ Service = (*certsService)(nil) // implementation, and all of its decorators (e.g. logging & metrics). type Service interface { // IssueCert issues certificate for given thing id if access is granted with token - IssueCert(ctx context.Context, token, thingID, ttl string) (Cert, error) + IssueCert(ctx context.Context, token, thingID, name, ttl string) (Cert, error) - // ListCerts lists certificates issued for a given thing ID - ListCerts(ctx context.Context, token, thingID string, offset, limit uint64) (Page, error) + // ViewCert retrieves the certificate issued for a given certificate ID + ViewCert(ctx context.Context, token, certID string) (Cert, error) - // ListSerials lists certificate serial IDs issued for a given thing ID - ListSerials(ctx context.Context, token, thingID string, offset, limit uint64) (Page, error) + // RenewCert the expired certificate from certs repo + RenewCert(ctx context.Context, token, certID string) (Cert, error) - // ViewCert retrieves the certificate issued for a given serial ID - ViewCert(ctx context.Context, token, serialID string) (Cert, error) + // RevokeCert revokes a certificate for a given certificate ID + RevokeCert(ctx context.Context, token, certID string) error - // RevokeCert revokes a certificate for a given serial ID - RevokeCert(ctx context.Context, token, serialID string) (Revoke, error) + // RemoveCert revoke and delete entry the certificate for a given certificate ID + RemoveCert(ctx context.Context, token, certID string) error + + // ListCerts lists certificates issued for a given certificate ID + ListCerts(ctx context.Context, token, certID, thingID, serial, name string, status Status, offset, limit uint64) (Page, error) + + // RevokeThingCerts revokes a all the certificates for a given thing ID with given limited count + RevokeThingCerts(ctx context.Context, token, thingID string, limit int64) (uint64, error) + + // RenewThingCerts renew all the certificates for a given thing ID with given limited count + RenewThingCerts(ctx context.Context, token, thingID string, limit int64) (uint64, error) + + // RemoveThingCerts revoke and delete entries of all the certificate for a given thing ID with given limited count + RemoveThingCerts(ctx context.Context, token, certID string, limit int64) (uint64, error) } type certsService struct { - auth mainflux.AuthServiceClient - certsRepo Repository - sdk mfsdk.SDK - pki pki.Agent + auth mainflux.AuthServiceClient + idProvider mainflux.IDProvider + repo Repository + sdk mfsdk.SDK + pki pki.Agent } -// New returns new Certs service -func New(auth mainflux.AuthServiceClient, certs Repository, sdk mfsdk.SDK, pki pki.Agent) Service { +// New returns new Certs service. +func New(auth mainflux.AuthServiceClient, repo Repository, idp mainflux.IDProvider, pki pki.Agent, sdk mfsdk.SDK) Service { return &certsService{ - certsRepo: certs, - sdk: sdk, - auth: auth, - pki: pki, + repo: repo, + idProvider: idp, + auth: auth, + pki: pki, + sdk: sdk, } } @@ -68,18 +96,21 @@ type Revoke struct { // Cert defines the certificate paremeters type Cert struct { - OwnerID string `json:"owner_id" mapstructure:"owner_id"` - ThingID string `json:"thing_id" mapstructure:"thing_id"` - ClientCert string `json:"client_cert" mapstructure:"certificate"` - IssuingCA string `json:"issuing_ca" mapstructure:"issuing_ca"` - CAChain []string `json:"ca_chain" mapstructure:"ca_chain"` - ClientKey string `json:"client_key" mapstructure:"private_key"` - PrivateKeyType string `json:"private_key_type" mapstructure:"private_key_type"` - Serial string `json:"serial" mapstructure:"serial_number"` - Expire time.Time `json:"expire" mapstructure:"-"` -} - -func (cs *certsService) IssueCert(ctx context.Context, token, thingID string, ttl string) (Cert, error) { + ID string `json:"id" db:"id"` + Name string `json:"name" db:"name"` + OwnerID string `json:"owner_id" db:"owner_id"` + ThingID string `json:"thing_id" db:"thing_id"` + Serial string `json:"serial" db:"serial"` + Certificate string `json:"certificate" db:"certificate"` + PrivateKey string `json:"private_key" db:"private_key"` + CAChain string `json:"ca_chain" db:"ca_chain"` + IssuingCA string `json:"issuing_ca" db:"issuing_ca"` + TTL string `json:"ttl" db:"ttl"` + Expire time.Time `json:"expire" db:"expire"` + Revocation time.Time `json:"revocation" db:"revocation"` +} + +func (cs *certsService) IssueCert(ctx context.Context, token, thingID string, name string, ttl string) (Cert, error) { owner, err := cs.auth.Identify(ctx, &mainflux.Token{Value: token}) if err != nil { return Cert{}, err @@ -87,116 +118,236 @@ func (cs *certsService) IssueCert(ctx context.Context, token, thingID string, tt thing, err := cs.sdk.Thing(thingID, token) if err != nil { - return Cert{}, errors.Wrap(ErrFailedCertCreation, err) + return Cert{}, errors.Wrap(ErrThingRetrieve, err) + } + + id, err := cs.idProvider.ID() + if err != nil { + return Cert{}, err } cert, err := cs.pki.IssueCert(thing.Key, ttl) if err != nil { - return Cert{}, errors.Wrap(ErrFailedCertCreation, err) + return Cert{}, errors.Wrap(ErrPKIIssue, err) } c := Cert{ - ThingID: thingID, - OwnerID: owner.GetId(), - ClientCert: cert.ClientCert, - IssuingCA: cert.IssuingCA, - CAChain: cert.CAChain, - ClientKey: cert.ClientKey, - PrivateKeyType: cert.PrivateKeyType, - Serial: cert.Serial, - Expire: time.Unix(0, int64(cert.Expire)*int64(time.Second)), + ID: id, + Name: name, + ThingID: thingID, + OwnerID: owner.GetId(), + Certificate: cert.Certificate, + IssuingCA: cert.IssuingCA, + CAChain: strings.Join(cert.CAChain, caChainJoinSep), + PrivateKey: cert.PrivateKey, + Serial: cert.Serial, + TTL: ttl, + Expire: cert.Expire, + } + + err = cs.repo.Save(context.Background(), c) + if err != nil { + return Cert{}, err } + return c, nil +} - _, err = cs.certsRepo.Save(context.Background(), c) - return c, err +func (cs *certsService) ListCerts(ctx context.Context, token, certID, thingID, serial, name string, status Status, offset, limit uint64) (Page, error) { + p, _, err := cs.identifyAndRetrieve(ctx, token, certID, thingID, serial, name, status, offset, int64(limit)) + return p, err } -func (cs *certsService) RevokeCert(ctx context.Context, token, thingID string) (Revoke, error) { - var revoke Revoke - u, err := cs.auth.Identify(ctx, &mainflux.Token{Value: token}) +func (cs *certsService) ViewCert(ctx context.Context, token, certID string) (Cert, error) { + cp, u, err := cs.identifyAndRetrieve(ctx, token, certID, "", "", "", AllCerts, 0, 1) + if err != nil { + return Cert{}, err + } + if len(cp.Certs) < 1 { + return Cert{}, errors.ErrNotFound + } + + cert := cp.Certs[0] + if time.Until(cert.Expire) < time.Duration(1*time.Hour) { + cert, err = cs.renewAndUpdate(ctx, u.GetId(), cert) + if err != nil { + return Cert{}, err + } + } + return cert, nil +} + +func (cs *certsService) RenewCert(ctx context.Context, token, certID string) (Cert, error) { + cp, u, err := cs.identifyAndRetrieve(ctx, token, certID, "", "", "", AllCerts, 0, 1) if err != nil { - return revoke, err + return Cert{}, err } - thing, err := cs.sdk.Thing(thingID, token) + + // ToDo don't renew before revoke , To check revoke is zero logic should be time.Now().Sub(revokeTime) != time.Now() + return cs.renewAndUpdate(ctx, u.GetId(), cp.Certs[0]) +} + +func (cs *certsService) RevokeCert(ctx context.Context, token, certID string) error { + cp, u, err := cs.identifyAndRetrieve(ctx, token, certID, "", "", "", AllCerts, 0, 1) if err != nil { - return revoke, errors.Wrap(ErrFailedCertRevocation, err) + return err } - // TODO: Replace offset and limit - offset, limit := uint64(0), uint64(10000) - cp, err := cs.certsRepo.RetrieveByThing(ctx, u.GetId(), thing.ID, offset, limit) + return cs.revokeAndUpdate(ctx, u.GetId(), cp.Certs[0]) +} + +func (cs *certsService) RemoveCert(ctx context.Context, token, certID string) error { + cp, u, err := cs.identifyAndRetrieve(ctx, token, certID, "", "", "", AllCerts, 0, 1) if err != nil { - return revoke, errors.Wrap(ErrFailedCertRevocation, err) + return err } + return cs.revokeAndRemove(ctx, u.GetId(), cp.Certs[0]) +} - for _, c := range cp.Certs { - revTime, err := cs.pki.Revoke(c.Serial) - if err != nil { - return revoke, errors.Wrap(ErrFailedCertRevocation, err) +func (cs *certsService) RenewThingCerts(ctx context.Context, token, thingID string, limit int64) (uint64, error) { + cp, u, err := cs.identifyAndRetrieve(ctx, token, "", thingID, "", "", RevokedCerts, 0, limit) + if err != nil { + if errors.Contains(err, errors.ErrNotFound) { + return 0, nil } - revoke.RevocationTime = revTime - if err = cs.certsRepo.Remove(context.Background(), u.GetId(), c.Serial); err != nil { - return revoke, errors.Wrap(errFailedToRemoveCertFromDB, err) + return 0, err + } + + for _, cert := range cp.Certs { + // ToDo don't renew before revoke , To check revoke is zero logic should be time.Now().Sub(revokeTime) != time.Now() + _, err := cs.renewAndUpdate(ctx, u.GetId(), cert) + if err != nil { + return 0, err } } + c, err := cs.repo.RetrieveCount(ctx, u.GetId(), "", thingID, "", "", RevokedCerts) + if err != nil { + return 0, err + } - return revoke, nil + return c, nil } -func (cs *certsService) ListCerts(ctx context.Context, token, thingID string, offset, limit uint64) (Page, error) { - u, err := cs.auth.Identify(ctx, &mainflux.Token{Value: token}) +func (cs *certsService) RevokeThingCerts(ctx context.Context, token, thingID string, limit int64) (uint64, error) { + cp, u, err := cs.identifyAndRetrieve(ctx, token, "", thingID, "", "", ActiveCerts, 0, limit) + if err != nil { + if errors.Contains(err, errors.ErrNotFound) { + return 0, nil + } + return 0, err + } + + for _, cert := range cp.Certs { + err := cs.revokeAndUpdate(ctx, u.GetId(), cert) + if err != nil { + return 0, err + } + } + + c, err := cs.repo.RetrieveCount(ctx, u.GetId(), "", thingID, "", "", ActiveCerts) if err != nil { - return Page{}, err + return 0, err } + return c, nil +} - cp, err := cs.certsRepo.RetrieveByThing(ctx, u.GetId(), thingID, offset, limit) +func (cs *certsService) RemoveThingCerts(ctx context.Context, token, thingID string, limit int64) (uint64, error) { + cp, u, err := cs.identifyAndRetrieve(ctx, token, "", thingID, "", "", AllCerts, 0, limit) if err != nil { - return Page{}, err + return 0, err } - for i, cert := range cp.Certs { - vcert, err := cs.pki.Read(cert.Serial) + for _, cert := range cp.Certs { + err := cs.revokeAndRemove(ctx, u.GetId(), cert) if err != nil { - return Page{}, err + return 0, err } - cp.Certs[i].ClientCert = vcert.ClientCert - cp.Certs[i].ClientKey = vcert.ClientKey } - return cp, nil -} - -func (cs *certsService) ListSerials(ctx context.Context, token, thingID string, offset, limit uint64) (Page, error) { - u, err := cs.auth.Identify(ctx, &mainflux.Token{Value: token}) + c, err := cs.repo.RetrieveCount(ctx, u.GetId(), "", thingID, "", "", AllCerts) if err != nil { - return Page{}, err + return 0, err } - return cs.certsRepo.RetrieveByThing(ctx, u.GetId(), thingID, offset, limit) + return c, nil } -func (cs *certsService) ViewCert(ctx context.Context, token, serialID string) (Cert, error) { +func (cs *certsService) identifyAndRetrieve(ctx context.Context, token, certID, thingID, serial, name string, status Status, offset uint64, limit int64) (Page, *mainflux.UserIdentity, error) { u, err := cs.auth.Identify(ctx, &mainflux.Token{Value: token}) if err != nil { - return Cert{}, err + return Page{}, u, errors.Wrap(errors.ErrAuthentication, err) } + cp, err := cs.repo.Retrieve(ctx, u.GetId(), certID, thingID, serial, name, status, offset, limit) - cert, err := cs.certsRepo.RetrieveBySerial(ctx, u.GetId(), serialID) if err != nil { - return Cert{}, err + return Page{}, u, errors.Wrap(errRepoRetrieve, err) } + return cp, u, nil +} - vcert, err := cs.pki.Read(serialID) +func (cs *certsService) renewAndUpdate(ctx context.Context, ownerID string, cert Cert) (Cert, error) { + xCert, err := parseCert(cert.Certificate) if err != nil { - return Cert{}, err + return Cert{}, errors.Wrap(errParseCert, err) + } + pkiCert, err := cs.pki.IssueCert(xCert.Subject.CommonName, cert.TTL) + if err != nil { + return Cert{}, errors.Wrap(ErrPKIIssue, err) } - c := Cert{ - ThingID: cert.ThingID, - ClientCert: vcert.ClientCert, - Serial: cert.Serial, - Expire: cert.Expire, + cert.CAChain = strings.Join(pkiCert.CAChain, caChainJoinSep) + cert.Certificate = pkiCert.Certificate + cert.Expire = pkiCert.Expire + cert.IssuingCA = pkiCert.IssuingCA + cert.PrivateKey = pkiCert.PrivateKey + cert.Serial = pkiCert.Serial + cert.Revocation = time.Time{} + + if err = cs.repo.Update(context.Background(), ownerID, cert); err != nil { + return Cert{}, errors.Wrap(errRepoUpdate, err) } + return cert, nil +} - return c, nil +func (cs *certsService) revokeAndUpdate(ctx context.Context, ownerID string, c Cert) error { + if c.Revocation.IsZero() { + revTime, err := cs.pki.Revoke(c.Serial) + if err != nil { + return errors.Wrap(errPKIRevoke, err) + } + + c.Revocation = revTime + if err = cs.repo.Update(context.Background(), ownerID, c); err != nil { + return errors.Wrap(errRepoUpdate, err) + } + } + + return nil +} + +func (cs *certsService) revokeAndRemove(ctx context.Context, ownerID string, c Cert) error { + if time.Until(c.Revocation) < 0 { + revTime, err := cs.pki.Revoke(c.Serial) + if err != nil { + return errors.Wrap(errPKIRevoke, err) + } + c.Revocation = revTime + } + + if err := cs.repo.Remove(context.Background(), ownerID, c.ID); err != nil { + return errors.Wrap(errRepoRemove, err) + } + return nil +} + +func parseCert(certificate string) (*x509.Certificate, error) { + block, _ := pem.Decode([]byte(certificate)) + if block == nil { + return nil, errParseCert + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, err + } + return cert, nil } diff --git a/certs/service_test.go b/certs/service_test.go index ab564729c8..8ca8045c54 100644 --- a/certs/service_test.go +++ b/certs/service_test.go @@ -14,7 +14,6 @@ import ( "strconv" "strings" "testing" - "time" "github.com/mainflux/mainflux" bsmocks "github.com/mainflux/mainflux/bootstrap/mocks" @@ -23,6 +22,7 @@ import ( "github.com/mainflux/mainflux/logger" "github.com/mainflux/mainflux/pkg/errors" mfsdk "github.com/mainflux/mainflux/pkg/sdk/go" + "github.com/mainflux/mainflux/pkg/uuid" "github.com/mainflux/mainflux/things" httpapi "github.com/mainflux/mainflux/things/api/things/http" thmocks "github.com/mainflux/mainflux/things/mocks" @@ -36,16 +36,15 @@ const ( email = "user@example.com" token = "token" thingsNum = 1 + name = "certificate name" thingKey = "thingKey" thingID = "1" - ttl = "1h" - certNum = 10 - - cfgAuthTimeout = "1s" + ttl = "1h" caPath = "../docker/ssl/certs/ca.crt" caKeyPath = "../docker/ssl/certs/ca.key" cfgSignHoursValid = "24h" + cfgSignRSABits = 2048 ) func newService(tokens map[string]string) (certs.Service, error) { @@ -66,14 +65,10 @@ func newService(tokens map[string]string) (certs.Service, error) { return nil, err } - authTimeout, err := time.ParseDuration(cfgAuthTimeout) - if err != nil { - return nil, err - } - - pki := mocks.NewPkiAgent(tlsCert, caCert, cfgSignHoursValid, authTimeout) + idp := uuid.NewMock() + pki := mocks.NewPkiAgent(tlsCert, caCert, cfgSignRSABits, cfgSignHoursValid) - return certs.New(auth, repo, sdk, pki), nil + return certs.New(auth, repo, idp, pki, sdk), nil } func newThingsService(auth mainflux.AuthServiceClient) things.Service { @@ -95,6 +90,7 @@ func TestIssueCert(t *testing.T) { require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err)) cases := []struct { + name string token string desc string thingID string @@ -105,262 +101,38 @@ func TestIssueCert(t *testing.T) { { desc: "issue new cert", token: token, + name: name, thingID: thingID, ttl: ttl, err: nil, }, { desc: "issue new cert for non existing thing id", - token: token, - thingID: "2", - ttl: ttl, - err: certs.ErrFailedCertCreation, - }, - { - desc: "issue new cert for non existing thing id", + name: name, token: wrongValue, thingID: thingID, ttl: ttl, err: errors.ErrAuthentication, }, - } - - for _, tc := range cases { - c, err := svc.IssueCert(context.Background(), tc.token, tc.thingID, tc.ttl) - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - cert, _ := readCert([]byte(c.ClientCert)) - if cert != nil { - assert.True(t, strings.Contains(cert.Subject.CommonName, thingKey), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } - } - -} - -func TestRevokeCert(t *testing.T) { - svc, err := newService(map[string]string{token: email}) - require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err)) - - _, err = svc.IssueCert(context.Background(), token, thingID, ttl) - require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err)) - - cases := []struct { - token string - desc string - thingID string - err error - }{ { - desc: "revoke cert", - token: token, - thingID: thingID, - err: nil, - }, - { - desc: "revoke cert for invalid token", - token: wrongValue, - thingID: thingID, - err: errors.ErrAuthentication, - }, - { - desc: "revoke cert for invalid thing id", + desc: "issue new cert for non existing thing id", + name: name, token: token, thingID: "2", - err: certs.ErrFailedCertRevocation, - }, - } - - for _, tc := range cases { - _, err := svc.RevokeCert(context.Background(), tc.token, tc.thingID) - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } - -} - -func TestListCerts(t *testing.T) { - svc, err := newService(map[string]string{token: email}) - require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err)) - - for i := 0; i < certNum; i++ { - _, err = svc.IssueCert(context.Background(), token, thingID, ttl) - require.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err)) - } - - cases := []struct { - token string - desc string - thingID string - offset uint64 - limit uint64 - size uint64 - err error - }{ - { - desc: "list all certs with valid token", - token: token, - thingID: thingID, - offset: 0, - limit: certNum, - size: certNum, - err: nil, - }, - { - desc: "list all certs with invalid token", - token: wrongValue, - thingID: thingID, - offset: 0, - limit: certNum, - size: 0, - err: errors.ErrAuthentication, - }, - { - desc: "list half certs with valid token", - token: token, - thingID: thingID, - offset: certNum / 2, - limit: certNum, - size: certNum / 2, - err: nil, - }, - { - desc: "list last cert with valid token", - token: token, - thingID: thingID, - offset: certNum - 1, - limit: certNum, - size: 1, - err: nil, + ttl: ttl, + err: certs.ErrThingRetrieve, }, } for _, tc := range cases { - page, err := svc.ListCerts(context.Background(), tc.token, tc.thingID, tc.offset, tc.limit) - size := uint64(len(page.Certs)) - assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.size, size)) - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } -} - -func TestListSerials(t *testing.T) { - svc, err := newService(map[string]string{token: email}) - require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err)) - - var issuedCerts []certs.Cert - for i := 0; i < certNum; i++ { - cert, err := svc.IssueCert(context.Background(), token, thingID, ttl) - require.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err)) - - crt := certs.Cert{ - OwnerID: cert.OwnerID, - ThingID: cert.ThingID, - Serial: cert.Serial, - Expire: cert.Expire, + c, err := svc.IssueCert(context.Background(), tc.token, tc.thingID, tc.name, tc.ttl) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err)) + cert, _ := readCert([]byte(c.Certificate)) + if cert != nil { + assert.True(t, strings.Contains(cert.Subject.CommonName, thingKey), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) } - issuedCerts = append(issuedCerts, crt) - } - - cases := []struct { - token string - desc string - thingID string - offset uint64 - limit uint64 - certs []certs.Cert - err error - }{ - { - desc: "list all certs with valid token", - token: token, - thingID: thingID, - offset: 0, - limit: certNum, - certs: issuedCerts, - err: nil, - }, - { - desc: "list all certs with invalid token", - token: wrongValue, - thingID: thingID, - offset: 0, - limit: certNum, - certs: nil, - err: errors.ErrAuthentication, - }, - { - desc: "list half certs with valid token", - token: token, - thingID: thingID, - offset: certNum / 2, - limit: certNum, - certs: issuedCerts[certNum/2:], - err: nil, - }, - { - desc: "list last cert with valid token", - token: token, - thingID: thingID, - offset: certNum - 1, - limit: certNum, - certs: []certs.Cert{issuedCerts[certNum-1]}, - err: nil, - }, } - for _, tc := range cases { - page, err := svc.ListSerials(context.Background(), tc.token, tc.thingID, tc.offset, tc.limit) - assert.Equal(t, tc.certs, page.Certs, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.certs, page.Certs)) - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } -} - -func TestViewCert(t *testing.T) { - svc, err := newService(map[string]string{token: email}) - require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err)) - - ic, err := svc.IssueCert(context.Background(), token, thingID, ttl) - require.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err)) - - cert := certs.Cert{ - ThingID: thingID, - ClientCert: ic.ClientCert, - Serial: ic.Serial, - Expire: ic.Expire, - } - - cases := []struct { - token string - desc string - serialID string - cert certs.Cert - err error - }{ - { - desc: "list cert with valid token and serial", - token: token, - serialID: cert.Serial, - cert: cert, - err: nil, - }, - { - desc: "list cert with invalid token", - token: wrongValue, - serialID: cert.Serial, - cert: certs.Cert{}, - err: errors.ErrAuthentication, - }, - { - desc: "list cert with invalid serial", - token: token, - serialID: wrongValue, - cert: certs.Cert{}, - err: errors.ErrNotFound, - }, - } - - for _, tc := range cases { - cert, err := svc.ViewCert(context.Background(), tc.token, tc.serialID) - assert.Equal(t, tc.cert, cert, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.cert, cert)) - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } } func newThingsServer(svc things.Service) *httptest.Server { diff --git a/certs/tracing/certs.go b/certs/tracing/certs.go new file mode 100644 index 0000000000..cf6bf52ebc --- /dev/null +++ b/certs/tracing/certs.go @@ -0,0 +1,104 @@ +// Copyright (c) Mainflux +// SPDX-License-Identifier: Apache-2.0 + +package tracing + +import ( + "context" + + "github.com/mainflux/mainflux/certs" + opentracing "github.com/opentracing/opentracing-go" +) + +const ( + saveCertsOp = "save_certs" + updateCertsOp = "update_certs" + retrieveCertsOp = "retrieve_certs" + removeCertsOp = "retrieve_certs" + retrieveThingCertsOp = "retrieve_thing_certs" + removeThingCertsOp = "retrieve_thing_certs" +) + +var ( + _ certs.Repository = (*certsRepositoryMiddleware)(nil) +) + +type certsRepositoryMiddleware struct { + tracer opentracing.Tracer + repo certs.Repository +} + +// ChannelRepositoryMiddleware tracks request and their latency, and adds spans +// to context. +func CertsRepositoryMiddleware(tracer opentracing.Tracer, repo certs.Repository) certs.Repository { + return certsRepositoryMiddleware{ + tracer: tracer, + repo: repo, + } +} + +func (crm certsRepositoryMiddleware) Save(ctx context.Context, cert certs.Cert) error { + span := createSpan(ctx, crm.tracer, saveCertsOp) + defer span.Finish() + ctx = opentracing.ContextWithSpan(ctx, span) + + return crm.repo.Save(ctx, cert) +} + +func (crm certsRepositoryMiddleware) Update(ctx context.Context, ownerID string, cert certs.Cert) error { + span := createSpan(ctx, crm.tracer, updateCertsOp) + defer span.Finish() + ctx = opentracing.ContextWithSpan(ctx, span) + + return crm.repo.Update(ctx, ownerID, cert) +} + +func (crm certsRepositoryMiddleware) Retrieve(ctx context.Context, ownerID, certID, name, thingID, serial string, status certs.Status, offset uint64, limit int64) (certs.Page, error) { + span := createSpan(ctx, crm.tracer, retrieveCertsOp) + defer span.Finish() + ctx = opentracing.ContextWithSpan(ctx, span) + + return crm.repo.Retrieve(ctx, ownerID, certID, name, thingID, serial, status, offset, limit) +} + +func (crm certsRepositoryMiddleware) RetrieveCount(ctx context.Context, ownerID, certID, name, thingID, serial string, status certs.Status) (uint64, error) { + span := createSpan(ctx, crm.tracer, removeThingCertsOp) + defer span.Finish() + ctx = opentracing.ContextWithSpan(ctx, span) + + return crm.repo.RetrieveCount(ctx, ownerID, certID, name, thingID, serial, status) +} + +func (crm certsRepositoryMiddleware) Remove(ctx context.Context, ownerID, certID string) error { + span := createSpan(ctx, crm.tracer, removeCertsOp) + defer span.Finish() + ctx = opentracing.ContextWithSpan(ctx, span) + + return crm.repo.Remove(ctx, ownerID, certID) +} + +func (crm certsRepositoryMiddleware) RetrieveThingCerts(ctx context.Context, thingID string) (certs.Page, error) { + span := createSpan(ctx, crm.tracer, retrieveThingCertsOp) + defer span.Finish() + ctx = opentracing.ContextWithSpan(ctx, span) + + return crm.repo.RetrieveThingCerts(ctx, thingID) +} + +func (crm certsRepositoryMiddleware) RemoveThingCerts(ctx context.Context, thingID string) error { + span := createSpan(ctx, crm.tracer, removeThingCertsOp) + defer span.Finish() + ctx = opentracing.ContextWithSpan(ctx, span) + + return crm.repo.RemoveThingCerts(ctx, thingID) +} + +func createSpan(ctx context.Context, tracer opentracing.Tracer, opName string) opentracing.Span { + if parentSpan := opentracing.SpanFromContext(ctx); parentSpan != nil { + return tracer.StartSpan( + opName, + opentracing.ChildOf(parentSpan.Context()), + ) + } + return tracer.StartSpan(opName) +} diff --git a/cmd/certs/main.go b/cmd/certs/main.go index 32b9e9685e..29340806b3 100644 --- a/cmd/certs/main.go +++ b/cmd/certs/main.go @@ -10,30 +10,37 @@ import ( "log" "os" - "github.com/go-redis/redis/v8" - "github.com/mainflux/mainflux" "github.com/mainflux/mainflux/certs" "github.com/mainflux/mainflux/certs/api" + "github.com/mainflux/mainflux/certs/eventhandlers" vault "github.com/mainflux/mainflux/certs/pki" certsPg "github.com/mainflux/mainflux/certs/postgres" "github.com/mainflux/mainflux/internal" + "github.com/mainflux/mainflux/internal/clients/events/things" + redisClient "github.com/mainflux/mainflux/internal/clients/redis" "github.com/mainflux/mainflux/internal/env" "github.com/mainflux/mainflux/internal/server" httpserver "github.com/mainflux/mainflux/internal/server/http" + "github.com/mainflux/mainflux/internal/sqlxt" "github.com/mainflux/mainflux/logger" "golang.org/x/sync/errgroup" - "github.com/jmoiron/sqlx" authClient "github.com/mainflux/mainflux/internal/clients/grpc/auth" pgClient "github.com/mainflux/mainflux/internal/clients/postgres" "github.com/mainflux/mainflux/pkg/errors" mfsdk "github.com/mainflux/mainflux/pkg/sdk/go" + "github.com/mainflux/mainflux/pkg/uuid" ) const ( - svcName = "certs" - envPrefix = "MF_CERTS_" - envPrefixHttp = "MF_CERTS_HTTP_" + svcName = "certs" + esGroup = "mainflux.certs" + esConsumer = "certs" + + envPrefix = "MF_CERTS_" + envPrefixHttp = "MF_CERTS_HTTP_" + envPrefixES = "MF_CERTS_ES_" + defDB = "certs" defSvcHttpPort = "8204" ) @@ -83,6 +90,8 @@ func main() { if err != nil { logger.Error("Failed to load CA certificates for issuing client certs") } + _ = tlsCert + _ = caCert if cfg.PkiHost == "" { log.Fatalf("No host specified for PKI engine") @@ -107,7 +116,23 @@ func main() { defer authHandler.Close() logger.Info("Successfully connected to auth grpc server " + authHandler.Secure()) - svc := newService(auth, db, logger, nil, tlsCert, caCert, cfg, pkiClient) + dbt := sqlxt.NewDatabase(db) + + certsRepo := certsPg.NewRepository(dbt) + + config := mfsdk.Config{ + CertsURL: cfg.CertsURL, + ThingsURL: cfg.ThingsURL, + } + sdk := mfsdk.NewSDK(config) + + idProvider := uuid.New() + + svc := certs.New(auth, certsRepo, idProvider, pkiClient, sdk) + + svc = api.NewLoggingMiddleware(svc, logger) + counter, latency := internal.MakeMetrics(svcName, "api") + svc = api.MetricsMiddleware(svc, counter, latency) httpServerConfig := server.Config{Port: defSvcHttpPort} if err := env.Parse(&httpServerConfig, env.Options{Prefix: envPrefixHttp, AltPrefix: envPrefix}); err != nil { @@ -115,10 +140,23 @@ func main() { } hs := httpserver.New(ctx, cancel, svcName, httpServerConfig, api.MakeHandler(svc, logger), logger) + thingsESClient, err := redisClient.Setup(envPrefixES) + if err != nil { + log.Fatalf(err.Error()) + } + defer thingsESClient.Close() + + certsThingsHandler := eventhandlers.NewThingsEventHandlers(certsRepo, pkiClient) + te := things.NewEventStore(certsThingsHandler, thingsESClient, esConsumer, logger) + g.Go(func() error { return hs.Start() }) + g.Go(func() error { + return te.Subscribe(ctx, esGroup) + }) + g.Go(func() error { return server.StopSignalHandler(ctx, cancel, logger, svcName, hs) }) @@ -128,20 +166,6 @@ func main() { } } -func newService(auth mainflux.AuthServiceClient, db *sqlx.DB, logger logger.Logger, esClient *redis.Client, tlsCert tls.Certificate, x509Cert *x509.Certificate, cfg config, pkiAgent vault.Agent) certs.Service { - certsRepo := certsPg.NewRepository(db, logger) - config := mfsdk.Config{ - CertsURL: cfg.CertsURL, - ThingsURL: cfg.ThingsURL, - } - sdk := mfsdk.NewSDK(config) - svc := certs.New(auth, certsRepo, sdk, pkiAgent) - svc = api.NewLoggingMiddleware(svc, logger) - counter, latency := internal.MakeMetrics(svcName, "api") - svc = api.MetricsMiddleware(svc, counter, latency) - return svc -} - func loadCertificates(conf config) (tls.Certificate, *x509.Certificate, error) { var tlsCert tls.Certificate var caCert *x509.Certificate diff --git a/docker/.env b/docker/.env index ebdb547855..bb3b517ef7 100644 --- a/docker/.env +++ b/docker/.env @@ -139,19 +139,19 @@ MF_BOOTSTRAP_DB_SSL_MODE=disable ### Provision MF_PROVISION_CONFIG_FILE=/configs/config.toml MF_PROVISION_LOG_LEVEL=debug -MF_PROVISION_HTTP_PORT=8190 +MF_PROVISION_HTTP_PORT=8191 MF_PROVISION_ENV_CLIENTS_TLS=false MF_PROVISION_SERVER_CERT= MF_PROVISION_SERVER_KEY= -MF_PROVISION_USERS_LOCATION=http://users:8180 -MF_PROVISION_THINGS_LOCATION=http://things:8182 +MF_PROVISION_USERS_LOCATION=http://mainflux-users:8180 +MF_PROVISION_THINGS_LOCATION=http://mainflux-things:8182 MF_PROVISION_USER= MF_PROVISION_PASS= MF_PROVISION_API_KEY= -MF_PROVISION_CERTS_SVC_URL=http://certs:8204 +MF_PROVISION_CERTS_SVC_URL=http://mainflux-certs:8204 MF_PROVISION_X509_PROVISIONING=false -MF_PROVISION_BS_SVC_URL=http://bootstrap:8202/things -MF_PROVISION_BS_SVC_WHITELIST_URL=http://bootstrap:8202/things/state +MF_PROVISION_BS_SVC_URL=http://mainflux-bootstrap:8202/things +MF_PROVISION_BS_SVC_WHITELIST_URL=http://mainflux-bootstrap:8202/things/state MF_PROVISION_BS_CONFIG_PROVISIONING=true MF_PROVISION_BS_AUTO_WHITELIST=true MF_PROVISION_BS_CONTENT= @@ -178,7 +178,20 @@ MF_CERTS_SIGN_CA_PATH=/etc/ssl/certs/ca.crt MF_CERTS_SIGN_CA_KEY_PATH=/etc/ssl/certs/ca.key MF_CERTS_SIGN_HOURS_VALID=2048h MF_CERTS_SIGN_RSA_BITS=2048 -MF_CERTS_VAULT_HOST=http://vault:8200 +MF_CERTS_VAULT_HOST=http://mainflux-vault:8200 +MF_CERTS_BOOTSTRAP_URL=http://mainflux-bootstrap:8202 +MF_CERTS_BSCLIENT_USER=admin@example.com +MF_CERTS_BSCLIENT_PASS=12345678 +MF_CERTS_EVENT_CONSUMER=certs +MF_CERTS_USERS_TOKEN= +MF_CERTS_AUTO_RENEW_INTERVAL=10s +MF_CERTS_AUTO_RENEW=true +MF_CERTS_AUTO_RENEW_UDPATE_BS=true +MF_CERTS_STOP_ON_AUTO_RENEW_ERROR=false +MF_CERTS_USERS_URL=http://mainflux-users:8180 +MF_CERTS_ES_URL=mainflux-es-redis:6379 +MF_CERTS_ES_PASS= +MF_CERTS_ES_DB=0 ### Vault diff --git a/docker/addons/certs/docker-compose.yml b/docker/addons/certs/docker-compose.yml index d0bb00c768..20f09125c9 100644 --- a/docker/addons/certs/docker-compose.yml +++ b/docker/addons/certs/docker-compose.yml @@ -70,6 +70,9 @@ services: MF_AUTH_GRPC_URL: ${MF_AUTH_GRPC_URL} MF_AUTH_GRPC_TIMEOUT: ${MF_AUTH_GRPC_TIMEOUT} MF_CERTS_VAULT_HOST: ${MF_CERTS_VAULT_HOST} + MF_CERTS_ES_URL: ${MF_CERTS_ES_URL} + MF_CERTS_ES_PASS: ${MF_CERTS_ES_PASS} + MF_CERTS_ES_DB: ${MF_CERTS_ES_DB} volumes: - ../../ssl/certs/ca.key:/etc/ssl/certs/ca.key - ../../ssl/certs/ca.crt:/etc/ssl/certs/ca.crt diff --git a/internal/apiutil/errors.go b/internal/apiutil/errors.go index 137d76454f..06c508c40d 100644 --- a/internal/apiutil/errors.go +++ b/internal/apiutil/errors.go @@ -101,4 +101,10 @@ var ( // ErrBootstrapState indicates an invalid boostrap state. ErrBootstrapState = errors.New("invalid bootstrap state") + + // ErrMissingThingKey indicates an missing thing key + ErrMissingThingKey = errors.New("missing thing key") + + // ErrMissingThingID indicates an missing thing ID + ErrMissingThingID = errors.New("missing thing ID") ) diff --git a/internal/apiutil/transport.go b/internal/apiutil/transport.go index 57488d6ad6..f921a28319 100644 --- a/internal/apiutil/transport.go +++ b/internal/apiutil/transport.go @@ -74,6 +74,26 @@ func ReadUintQuery(r *http.Request, key string, def uint64) (uint64, error) { return val, nil } +// ReadIntQuery reads the value of uint64 http query parameters for a given key +func ReadIntQuery(r *http.Request, key string, def int64) (int64, error) { + vals := bone.GetQuery(r, key) + if len(vals) > 1 { + return 0, errors.ErrInvalidQueryParams + } + + if len(vals) == 0 { + return def, nil + } + + strval := vals[0] + val, err := strconv.ParseInt(strval, 10, 64) + if err != nil { + return 0, errors.ErrInvalidQueryParams + } + + return val, nil +} + // ReadStringQuery reads the value of string http query parameters for a given key func ReadStringQuery(r *http.Request, key string, def string) (string, error) { vals := bone.GetQuery(r, key) diff --git a/internal/clients/events/things/client.go b/internal/clients/events/things/client.go new file mode 100644 index 0000000000..a127ae155b --- /dev/null +++ b/internal/clients/events/things/client.go @@ -0,0 +1,183 @@ +package things + +import ( + "context" + "fmt" + + "github.com/go-redis/redis/v8" + "github.com/mainflux/mainflux/logger" + "github.com/mitchellh/mapstructure" +) + +const ( + stream = "mainflux.things" + + thingPrefix = "thing." + thingCreate = thingPrefix + "create" + thingUpdate = thingPrefix + "update" + thingRemove = thingPrefix + "remove" + thingConnect = thingPrefix + "connect" + thingDisconnect = thingPrefix + "disconnect" + + channelPrefix = "channel." + channelCreate = channelPrefix + "create" + channelUpdate = channelPrefix + "update" + channelRemove = channelPrefix + "remove" + + exists = "BUSYGROUP Consumer Group name already exists" + + msgEventMessage = "Failed to parse the event message %s : %v" + msgEventHandler = "Failed to execute the event handler of event %s : %v" +) + +type EventHandler interface { + ThingCreated(ctx context.Context, cte CreateThingEvent) error + ThingUpdated(ctx context.Context, ute UpdateThingEvent) error + ThingRemoved(ctx context.Context, rte RemoveThingEvent) error + + ChannelCreated(ctx context.Context, cce CreateChannelEvent) error + ChannelUpdated(ctx context.Context, uce UpdateChannelEvent) error + ChannelRemoved(ctx context.Context, rce RemoveChannelEvent) error + + ThingConnected(ctx context.Context, cte ConnectThingEvent) error + ThingDisconnected(ctx context.Context, dte DisconnectThingEvent) error +} + +type Event struct { + eh EventHandler + client *redis.Client + consumer string + logger logger.Logger +} + +func NewEventStore(eh EventHandler, client *redis.Client, consumer string, log logger.Logger) Event { + return Event{ + eh: eh, + client: client, + consumer: consumer, + logger: log, + } +} + +func (e Event) Subscribe(ctx context.Context, group string) error { + err := e.client.XGroupCreateMkStream(ctx, stream, group, "$").Err() + if err != nil && err.Error() != exists { + return err + } + + for { + select { + case <-ctx.Done(): + return nil + default: + streams, err := e.client.XReadGroup(ctx, &redis.XReadGroupArgs{ + Group: group, + Consumer: e.consumer, + Streams: []string{stream, ">"}, + Count: 100, + }).Result() + if err != nil || len(streams) == 0 { + continue + } + + for _, msg := range streams[0].Messages { + event := msg.Values + + switch event["operation"] { + case thingCreate: + cte := CreateThingEvent{} + if err := decodeEvent(event, &cte); err != nil { + e.logger.Error(fmt.Sprintf(msgEventMessage, thingCreate, err)) + break + } + if err = e.eh.ThingCreated(ctx, cte); err != nil { + e.logger.Error(fmt.Sprintf(msgEventHandler, thingCreate, err)) + break + } + + case thingUpdate: + ute := UpdateThingEvent{} + if err := decodeEvent(event, &ute); err != nil { + e.logger.Error(fmt.Sprintf(msgEventMessage, thingUpdate, err)) + break + } + if err = e.eh.ThingUpdated(ctx, ute); err != nil { + e.logger.Error(fmt.Sprintf(msgEventHandler, thingUpdate, err)) + break + } + + case thingRemove: + rte := RemoveThingEvent{} + if err := decodeEvent(event, &rte); err != nil { + e.logger.Error(fmt.Sprintf(msgEventMessage, thingRemove, err)) + break + } + if err = e.eh.ThingRemoved(ctx, rte); err != nil { + e.logger.Error(fmt.Sprintf(msgEventHandler, thingRemove, err)) + break + } + + case channelCreate: + cce := CreateChannelEvent{} + if err := decodeEvent(event, &cce); err != nil { + e.logger.Error(fmt.Sprintf(msgEventMessage, channelCreate, err)) + break + } + if err = e.eh.ChannelCreated(ctx, cce); err != nil { + e.logger.Error(fmt.Sprintf(msgEventHandler, channelCreate, err)) + break + } + + case channelUpdate: + uce := UpdateChannelEvent{} + if err := decodeEvent(event, &uce); err != nil { + e.logger.Error(fmt.Sprintf(msgEventMessage, channelUpdate, err)) + break + } + if err = e.eh.ChannelUpdated(ctx, uce); err != nil { + e.logger.Error(fmt.Sprintf(msgEventHandler, channelUpdate, err)) + break + } + + case channelRemove: + rce := RemoveChannelEvent{} + if err := decodeEvent(event, &rce); err != nil { + e.logger.Error(fmt.Sprintf(msgEventMessage, channelRemove, err)) + break + } + if err = e.eh.ChannelRemoved(ctx, rce); err != nil { + e.logger.Error(fmt.Sprintf(msgEventHandler, channelRemove, err)) + break + } + + case thingConnect: + cte := ConnectThingEvent{} + if err := decodeEvent(event, &cte); err != nil { + e.logger.Error(fmt.Sprintf(msgEventMessage, thingConnect, err)) + break + } + if err = e.eh.ThingConnected(ctx, cte); err != nil { + e.logger.Error(fmt.Sprintf(msgEventHandler, thingConnect, err)) + break + } + + case thingDisconnect: + dte := DisconnectThingEvent{} + if err := decodeEvent(event, &dte); err != nil { + e.logger.Error(fmt.Sprintf(msgEventMessage, thingConnect, err)) + break + } + if err = e.eh.ThingDisconnected(ctx, dte); err != nil { + e.logger.Error(fmt.Sprintf(msgEventHandler, thingConnect, err)) + break + } + } + e.client.XAck(ctx, stream, group, msg.ID) + } + } + } +} + +func decodeEvent[T Type](event map[string]interface{}, obj *T) error { + return mapstructure.Decode(event, obj) +} diff --git a/internal/clients/events/things/event_types.go b/internal/clients/events/things/event_types.go new file mode 100644 index 0000000000..477aab9d17 --- /dev/null +++ b/internal/clients/events/things/event_types.go @@ -0,0 +1,49 @@ +package things + +type CreateThingEvent struct { + ID string `mapstructure:"id"` + Owner string `mapstructure:"owner"` + Name string `mapstructure:"name"` + Metadata map[string]interface{} `mapstructure:"metadata"` +} + +type UpdateThingEvent struct { + ID string `mapstructure:"id"` + Name string `mapstructure:"name"` + Metadata map[string]interface{} `mapstructure:"metadata"` +} + +type RemoveThingEvent struct { + ID string `mapstructure:"id"` +} + +type CreateChannelEvent struct { + ID string `mapstructure:"id"` + Owner string `mapstructure:"owner"` + Name string `mapstructure:"name"` + Metadata map[string]interface{} `mapstructure:"metadata"` +} + +type UpdateChannelEvent struct { + ID string `mapstructure:"id"` + Name string `mapstructure:"name"` + Metadata map[string]interface{} `mapstructure:"metadata"` +} + +type RemoveChannelEvent struct { + ID string `mapstructure:"id"` +} + +type ConnectThingEvent struct { + ChanID string `mapstructure:"chan_id"` + ThingID string `mapstructure:"thing_id"` +} + +type DisconnectThingEvent struct { + ChanID string `mapstructure:"chan_id"` + ThingID string `mapstructure:"thing_id"` +} + +type Type interface { + CreateThingEvent | UpdateThingEvent | RemoveThingEvent | CreateChannelEvent | UpdateChannelEvent | RemoveChannelEvent | ConnectThingEvent | DisconnectThingEvent +} diff --git a/internal/clients/postgres/checkerror.go b/internal/clients/postgres/checkerror.go new file mode 100644 index 0000000000..817ba15636 --- /dev/null +++ b/internal/clients/postgres/checkerror.go @@ -0,0 +1,43 @@ +package postgres + +import ( + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5/pgconn" + "github.com/mainflux/mainflux/pkg/errors" +) + +type Operation int + +const ( + Create Operation = iota + View + Update + Remove +) + +func CheckError(err error, op Operation) error { + if pErr, ok := err.(*pgconn.PgError); ok { + switch pErr.Code { + case pgerrcode.UniqueViolation: + return errors.Wrap(errors.ErrConflict, err) + case pgerrcode.InvalidTextRepresentation: + return errors.Wrap(errors.ErrMalformedEntity, err) + case pgerrcode.ForeignKeyViolation: + return errors.Wrap(errors.ErrConflict, err) + case pgerrcode.StringDataRightTruncationDataException: + return errors.Wrap(errors.ErrMalformedEntity, err) + } + + switch op { + case Create: + return errors.Wrap(errors.ErrCreateEntity, pErr) + case View: + return errors.Wrap(errors.ErrViewEntity, pErr) + case Update: + return errors.Wrap(errors.ErrUpdateEntity, pErr) + case Remove: + return errors.Wrap(errors.ErrRemoveEntity, pErr) + } + } + return err +} diff --git a/internal/sqlxt/database.go b/internal/sqlxt/database.go new file mode 100644 index 0000000000..89639a1d8a --- /dev/null +++ b/internal/sqlxt/database.go @@ -0,0 +1,120 @@ +package sqlxt + +import ( + "context" + "database/sql" + + "github.com/jmoiron/sqlx" + "github.com/mainflux/mainflux/pkg/errors" + "github.com/opentracing/opentracing-go" +) + +var ( + ErrRollback = errors.New("failed to rollback transaction") + ErrCommit = errors.New("failed to commit transaction") + ErrResultRowsAffected = errors.New("failed to get result rows affected") +) +var _ Database = (*database)(nil) + +type database struct { + db *sqlx.DB +} + +// Database provides a database interface +type Database interface { + NamedCUDContext(ctx context.Context, query string, args interface{}) (int64, error, error) + NamedTotalQueryContext(ctx context.Context, query string, params interface{}) (uint64, error) + NamedExecContext(context.Context, string, interface{}) (sql.Result, error) + QueryRowxContext(context.Context, string, ...interface{}) *sqlx.Row + NamedQueryContext(context.Context, string, interface{}) (*sqlx.Rows, error) + GetContext(context.Context, interface{}, string, ...interface{}) error + BeginTxx(context.Context, *sql.TxOptions) (*sqlx.Tx, error) +} + +// NewDatabase creates a ThingDatabase instance +func NewDatabase(db *sqlx.DB) Database { + return &database{ + db: db, + } +} + +func (dm database) NamedCUDContext(ctx context.Context, query string, args interface{}) (int64, error, error) { + tx, err := dm.BeginTxx(ctx, nil) + if err != nil { + return 0, err, nil + } + result, err := tx.NamedExecContext(ctx, query, args) + if err != nil { + errRoll := tx.Rollback() + if errRoll != nil { + return 0, err, errors.Wrap(ErrRollback, errRoll) + } + return 0, err, nil + } + + if err := tx.Commit(); err != nil { + return 0, nil, errors.Wrap(ErrCommit, err) + } + + count, err := result.RowsAffected() + if err != nil { + return count, nil, errors.Wrap(ErrResultRowsAffected, err) + } + + return count, nil, nil +} + +func (dm database) NamedTotalQueryContext(ctx context.Context, query string, params interface{}) (uint64, error) { + rows, err := dm.NamedQueryContext(ctx, query, params) + if err != nil { + return 0, err + } + defer rows.Close() + total := uint64(0) + if rows.Next() { + if err := rows.Scan(&total); err != nil { + return 0, err + } + } + return total, nil +} + +func (dm database) NamedExecContext(ctx context.Context, query string, args interface{}) (sql.Result, error) { + addSpanTags(ctx, query) + return dm.db.NamedExecContext(ctx, query, args) +} + +func (dm database) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row { + addSpanTags(ctx, query) + return dm.db.QueryRowxContext(ctx, query, args...) +} + +func (dm database) NamedQueryContext(ctx context.Context, query string, args interface{}) (*sqlx.Rows, error) { + addSpanTags(ctx, query) + return dm.db.NamedQueryContext(ctx, query, args) +} + +func (dm database) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + addSpanTags(ctx, query) + return dm.db.GetContext(ctx, dest, query, args...) +} + +func (dm database) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*sqlx.Tx, error) { + span := opentracing.SpanFromContext(ctx) + if span != nil { + span.SetTag("span.kind", "client") + span.SetTag("peer.service", "postgres") + span.SetTag("db.type", "sql") + } + return dm.db.BeginTxx(ctx, opts) +} + +func addSpanTags(ctx context.Context, query string) { + span := opentracing.SpanFromContext(ctx) + if span != nil { + span.SetTag("sql.statement", query) + span.SetTag("span.kind", "client") + span.SetTag("peer.service", "postgres") + span.SetTag("db.type", "sql") + } +} diff --git a/pkg/errors/sdk_errors.go b/pkg/errors/sdk_errors.go index e039310e24..baf03f6b92 100644 --- a/pkg/errors/sdk_errors.go +++ b/pkg/errors/sdk_errors.go @@ -33,6 +33,7 @@ type sdkError struct { statusCode int } + func (ce *sdkError) Error() string { if ce == nil { return "" @@ -69,6 +70,7 @@ func NewSDKErrorWithStatus(err error, statusCode int) SDKError { } } + // CheckError will check the HTTP response status code and matches it with the given status codes. // Since multiple status codes can be valid, we can pass multiple status codes to the function. // The function then checks for errors in the HTTP response. diff --git a/pkg/sdk/go/bootstrap.go b/pkg/sdk/go/bootstrap.go index 55aae01f2f..be6ef399a8 100644 --- a/pkg/sdk/go/bootstrap.go +++ b/pkg/sdk/go/bootstrap.go @@ -16,7 +16,7 @@ import ( const configsEndpoint = "configs" const bootstrapEndpoint = "bootstrap" const whitelistEndpoint = "state" -const bootstrapCertsEndpoint = "configs/certs" +const bootstrapCertsEndpoint = "things/configs/certs" // BootstrapConfig represents Configuration entity. It wraps information about external entity // as well as info about corresponding Mainflux entities. @@ -116,6 +116,7 @@ func (sdk mfSDK) UpdateBootstrapCerts(id, clientCert, clientKey, ca, token strin return errors.NewSDKError(err) } + _, _, sdkerr := sdk.processRequest(http.MethodPatch, url, token, string(CTJSON), data, http.StatusOK) return sdkerr }