From 1291ba7bfc277763ff111ad126ed6de6628b9c79 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 20 Nov 2024 14:45:05 +0300 Subject: [PATCH 01/10] Intial implementation of CSR Signed-off-by: nyagamunene --- certs.go | 36 ++++++++++++++++++++++++++++++ postgres/{ => certs}/certs.go | 0 postgres/{ => certs}/certs_test.go | 0 postgres/{ => certs}/init.go | 0 postgres/{ => certs}/setup_test.go | 0 postgres/csr/csr.go | 0 postgres/csr/init.go | 34 ++++++++++++++++++++++++++++ 7 files changed, 70 insertions(+) rename postgres/{ => certs}/certs.go (100%) rename postgres/{ => certs}/certs_test.go (100%) rename postgres/{ => certs}/init.go (100%) rename postgres/{ => certs}/setup_test.go (100%) create mode 100644 postgres/csr/csr.go create mode 100644 postgres/csr/init.go diff --git a/certs.go b/certs.go index 7a88bee..673b6bb 100644 --- a/certs.go +++ b/certs.go @@ -32,6 +32,30 @@ type PageMetadata struct { EntityID string `json:"entity_id,omitempty" db:"entity_id"` } +type CSRMetadata struct { + CommonName string `json:"common_name"` + Organization []string `json:"organization"` + OrganizationalUnit []string `json:"organizational_unit"` + Country []string `json:"country"` + Province []string `json:"province"` + Locality []string `json:"locality"` + StreetAddress []string `json:"street_address"` + PostalCode []string `json:"postal_code"` + EmailAddress string `json:"email_address"` + DNSNames []string `json:"dns_names"` + IPAddresses []string `json:"ip_addresses"` +} + +type CSR struct { + CSR []byte `json:"csr"` + PrivateKey []byte `json:"private_key"` + EntityID string `json:"entity_id"` + Status string `json:"status"` + SubmittedAt time.Time `json:"submitted_at"` + ProcessedAt time.Time `json:"processed_at"` + SerialNumber string `json:"serial_number"` +} + type Service interface { // RenewCert renews a certificate from the database. RenewCert(ctx context.Context, serialNumber string) error @@ -73,6 +97,18 @@ type Service interface { // RemoveCert deletes a cert for a provided entityID. RemoveCert(ctx context.Context, entityId string) error + + // CreateCSR creates a new Certificate Signing Request + CreateCSR(ctx context.Context, metadata CSRMetadata, entityID string) (CSR, error) + + // ProcessCSR processes a pending CSR and either approves or rejects it + ProcessCSR(ctx context.Context, csrID string, approve bool) error + + // ListCSRs returns a list of CSRs based on filter criteria + ListCSRs(ctx context.Context, entityID string, status string) ([]CSR, error) + + // RetrieveCSR retrieves a specific CSR by ID + RetrieveCSR(ctx context.Context, csrID string) (CSR, error) } type Repository interface { diff --git a/postgres/certs.go b/postgres/certs/certs.go similarity index 100% rename from postgres/certs.go rename to postgres/certs/certs.go diff --git a/postgres/certs_test.go b/postgres/certs/certs_test.go similarity index 100% rename from postgres/certs_test.go rename to postgres/certs/certs_test.go diff --git a/postgres/init.go b/postgres/certs/init.go similarity index 100% rename from postgres/init.go rename to postgres/certs/init.go diff --git a/postgres/setup_test.go b/postgres/certs/setup_test.go similarity index 100% rename from postgres/setup_test.go rename to postgres/certs/setup_test.go diff --git a/postgres/csr/csr.go b/postgres/csr/csr.go new file mode 100644 index 0000000..e69de29 diff --git a/postgres/csr/init.go b/postgres/csr/init.go new file mode 100644 index 0000000..9f52b7a --- /dev/null +++ b/postgres/csr/init.go @@ -0,0 +1,34 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + _ "github.com/jackc/pgx/v5/stdlib" + migrate "github.com/rubenv/sql-migrate" +) + +func Migration() *migrate.MemoryMigrationSource { + return &migrate.MemoryMigrationSource{ + Migrations: []*migrate.Migration{ + { + Id: "certs_1", + Up: []string{ + `CREATE TABLE IF NOT EXISTS certs ( + serial_number VARCHAR(40) UNIQUE NOT NULL, + certificate TEXT, + key TEXT, + revoked BOOLEAN, + expiry_time TIMESTAMP, + entity_id VARCHAR(36), + type TEXT CHECK (type IN ('RootCA', 'IntermediateCA', 'ClientCert')), + PRIMARY KEY (serial_number) + )`, + }, + Down: []string{ + "DROP TABLE certs", + }, + }, + }, + } +} From be29338e224e1c5770871014edc0f16aa720848f Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 20 Nov 2024 18:22:40 +0300 Subject: [PATCH 02/10] Add repository for CSR Signed-off-by: nyagamunene --- certs.go | 26 ++++++-- cmd/certs/main.go | 6 +- postgres/csr/csr.go | 154 +++++++++++++++++++++++++++++++++++++++++++ postgres/csr/init.go | 22 +++---- service.go | 142 +++++++++++++++++++++++++++++++++++++-- 5 files changed, 324 insertions(+), 26 deletions(-) diff --git a/certs.go b/certs.go index 673b6bb..c1c221f 100644 --- a/certs.go +++ b/certs.go @@ -47,13 +47,18 @@ type CSRMetadata struct { } type CSR struct { - CSR []byte `json:"csr"` - PrivateKey []byte `json:"private_key"` - EntityID string `json:"entity_id"` - Status string `json:"status"` - SubmittedAt time.Time `json:"submitted_at"` - ProcessedAt time.Time `json:"processed_at"` - SerialNumber string `json:"serial_number"` + CSR []byte `json:"csr" db:"csr"` + PrivateKey []byte `json:"private_key" db:"private_key"` + EntityID string `json:"entity_id" db:"entity_id"` + Status string `json:"status" db:"status"` + SubmittedAt time.Time `json:"submitted_at" db:"submitted_at"` + ProcessedAt time.Time `json:"processed_at" db:"processed_at"` + SerialNumber string `json:"serial_number" db:"serial_number"` +} + +type CSRPage struct { + PageMetadata + CSRs []CSR } type Service interface { @@ -133,3 +138,10 @@ type Repository interface { // RemoveCert deletes cert from database. RemoveCert(ctx context.Context, entityId string) error } + +type CSRRepository interface { + CreateCSR(context.Context, CSR) error + UpdateCSR(context.Context, CSR) error + ListCSRs(context.Context, PageMetadata) (CSRPage, error) + RetrieveCSR(context.Context, string) (CSR, error) +} diff --git a/cmd/certs/main.go b/cmd/certs/main.go index e596483..6047ac1 100644 --- a/cmd/certs/main.go +++ b/cmd/certs/main.go @@ -24,7 +24,8 @@ import ( grpcserver "github.com/absmach/certs/internal/server/grpc" httpserver "github.com/absmach/certs/internal/server/http" "github.com/absmach/certs/internal/uuid" - cpostgres "github.com/absmach/certs/postgres" + cpostgres "github.com/absmach/certs/postgres/certs" + csrpostgres "github.com/absmach/certs/postgres/csr" "github.com/absmach/certs/tracing" "github.com/caarlos0/env/v10" "github.com/jmoiron/sqlx" @@ -146,7 +147,8 @@ func main() { func newService(ctx context.Context, db *sqlx.DB, tracer trace.Tracer, logger *slog.Logger, dbConfig pgClient.Config, config *certs.Config) (certs.Service, error) { database := postgres.NewDatabase(db, dbConfig, tracer) repo := cpostgres.NewRepository(database) - svc, err := certs.NewService(ctx, repo, config) + csrRepo := csrpostgres.NewRepository(database) + svc, err := certs.NewService(ctx, repo, csrRepo, config) if err != nil { return nil, err } diff --git a/postgres/csr/csr.go b/postgres/csr/csr.go index e69de29..f4d52dc 100644 --- a/postgres/csr/csr.go +++ b/postgres/csr/csr.go @@ -0,0 +1,154 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/absmach/certs" + "github.com/absmach/certs/errors" + "github.com/absmach/certs/internal/postgres" + "github.com/jackc/pgx/v5/pgconn" +) + +// Postgres error codes: +// https://www.postgresql.org/docs/current/errcodes-appendix.html +const ( + errDuplicate = "23505" // unique_violation + errTruncation = "22001" // string_data_right_truncation + errFK = "23503" // foreign_key_violation + errInvalid = "22P02" // invalid_text_representation + errUntranslatable = "22P05" // untranslatable_character + errInvalidChar = "22021" // character_not_in_repertoire +) + +var ( + ErrConflict = errors.New("entity already exists") + ErrMalformedEntity = errors.New("malformed entity") + ErrCreateEntity = errors.New("failed to create entity") +) + +type CSRRepo struct { + db postgres.Database +} + +func NewRepository(db postgres.Database) certs.CSRRepository { + return CSRRepo{ + db: db, + } +} + +func (repo CSRRepo) CreateCSR(ctx context.Context, cert certs.CSR) error { + q := ` + INSERT INTO certs (serial_number, csr, private_key, entity_id, status, submitted_at, processed_at) + VALUES (:serial_number, :csr, :private_key, :entity_id, :status, :submitted_at, :processed_at)` + _, err := repo.db.NamedExecContext(ctx, q, cert) + if err != nil { + return handleError(certs.ErrCreateEntity, err) + } + return nil +} + +func (repo CSRRepo) UpdateCSR(ctx context.Context, cert certs.CSR) error { + q := `UPDATE certs SET certificate = :certificate, key = :key, revoked = :revoked, expiry_time = :expiry_time WHERE serial_number = :serial_number` + res, err := repo.db.NamedExecContext(ctx, q, cert) + if err != nil { + return handleError(certs.ErrUpdateEntity, err) + } + count, err := res.RowsAffected() + if err != nil { + return errors.Wrap(certs.ErrUpdateEntity, err) + } + if count == 0 { + return certs.ErrNotFound + } + return nil +} + +func (repo CSRRepo) RetrieveCSR(ctx context.Context,id string) (certs.CSR, error) { + q := `SELECT serial_number, certificate, key, entity_id, revoked, expiry_time FROM certs WHERE serial_number = $1` + var csr certs.CSR + if err := repo.db.QueryRowxContext(ctx, q, id).StructScan(&csr); err != nil { + if err == sql.ErrNoRows { + return certs.CSR{}, errors.Wrap(certs.ErrNotFound, err) + } + return certs.CSR{}, errors.Wrap(certs.ErrViewEntity, err) + } + return csr, nil +} + +func (repo CSRRepo) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { + q := `SELECT serial_number, status, submitted_at, processed_at, entity_id FROM certs %s LIMIT :limit OFFSET :offset` + var condition string + if pm.EntityID != "" { + condition = `WHERE entity_id = :entity_id` + } else { + condition = `` + } + q = fmt.Sprintf(q, condition) + var csrs []certs.CSR + + params := map[string]interface{}{ + "limit": pm.Limit, + "offset": pm.Offset, + "entity_id": pm.EntityID, + } + rows, err := repo.db.NamedQueryContext(ctx, q, params) + if err != nil { + return certs.CSRPage{}, handleError(certs.ErrViewEntity, err) + } + defer rows.Close() + + for rows.Next() { + csr := &certs.CSR{} + if err := rows.StructScan(csr); err != nil { + return certs.CSRPage{}, errors.Wrap(certs.ErrViewEntity, err) + } + + csrs = append(csrs, *csr) + } + + q = fmt.Sprintf(`SELECT COUNT(*) FROM certs %s LIMIT :limit OFFSET :offset`, condition) + pm.Total, err = repo.total(ctx, q, params) + if err != nil { + return certs.CSRPage{}, errors.Wrap(certs.ErrViewEntity, err) + } + return certs.CSRPage{ + PageMetadata: pm, + CSRs: csrs, + }, nil +} + +func (repo CSRRepo) total(ctx context.Context, query string, params interface{}) (uint64, error) { + rows, err := repo.db.NamedQueryContext(ctx, query, params) + if err != nil { + return 0, err + } + defer rows.Close() + total := uint64(0) + if rows.Next() { + if err := rows.Scan(&total); err != nil { + return 0, err + } + } + return total, nil +} + +func handleError(wrapper, err error) error { + pqErr, ok := err.(*pgconn.PgError) + if ok { + switch pqErr.Code { + case errDuplicate: + return errors.Wrap(ErrConflict, err) + case errInvalid, errInvalidChar, errTruncation, errUntranslatable: + return errors.Wrap(ErrMalformedEntity, err) + case errFK: + return errors.Wrap(ErrCreateEntity, err) + } + } + + return errors.Wrap(wrapper, err) +} diff --git a/postgres/csr/init.go b/postgres/csr/init.go index 9f52b7a..ab5f113 100644 --- a/postgres/csr/init.go +++ b/postgres/csr/init.go @@ -12,21 +12,21 @@ func Migration() *migrate.MemoryMigrationSource { return &migrate.MemoryMigrationSource{ Migrations: []*migrate.Migration{ { - Id: "certs_1", + Id: "csr_1", Up: []string{ - `CREATE TABLE IF NOT EXISTS certs ( - serial_number VARCHAR(40) UNIQUE NOT NULL, - certificate TEXT, - key TEXT, - revoked BOOLEAN, - expiry_time TIMESTAMP, - entity_id VARCHAR(36), - type TEXT CHECK (type IN ('RootCA', 'IntermediateCA', 'ClientCert')), - PRIMARY KEY (serial_number) + `CREATE TABLE IF NOT EXISTS csr ( + serial_number VARCHAR(40), + csr TEXT, + private_key TEXT, + entity_id VARCHAR(36), + status BOOLEAN, + submitted_at TIMESTAMP, + processed_at TIMESTAMP, + PRIMARY KEY (entity_id) )`, }, Down: []string{ - "DROP TABLE certs", + "DROP TABLE csr", }, }, }, diff --git a/service.go b/service.go index c1dc667..91b81c2 100644 --- a/service.go +++ b/service.go @@ -124,16 +124,18 @@ type SubjectOptions struct { type service struct { repo Repository + csrRepo CSRRepository rootCA *CA intermediateCA *CA } var _ Service = (*service)(nil) -func NewService(ctx context.Context, repo Repository, config *Config) (Service, error) { +func NewService(ctx context.Context, repo Repository, csrRepo CSRRepository, config *Config) (Service, error) { var svc service svc.repo = repo + svc.csrRepo = csrRepo if err := svc.loadCACerts(ctx); err != nil { return &svc, err } @@ -159,12 +161,17 @@ func NewService(ctx context.Context, repo Repository, config *Config) (Service, // using the provided template and the generated private key. // The certificate is then stored in the repository using the CreateCert method. // If the root CA is not found, it returns an error. -func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions) (Certificate, error) { - privKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) - if err != nil { - return Certificate{}, err +func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key ...*rsa.PrivateKey) (Certificate, error) { + var err error + privKey := &rsa.PrivateKey{} + if len(key) == 0 { + privKey, err = rsa.GenerateKey(rand.Reader, PrivateKeyBytes) + if err != nil { + return Certificate{}, err + } else { + privKey = key[0] + } } - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { return Certificate{}, err @@ -469,6 +476,115 @@ func (s *service) GetChainCA(ctx context.Context, token string) (Certificate, er return s.getConcatCAs(ctx) } +func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, entityID string) (CSR, error) { + privKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) + if err != nil { + return CSR{}, errors.Wrap(ErrCreateEntity, err) + } + + template := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: metadata.CommonName, + Organization: metadata.Organization, + OrganizationalUnit: metadata.OrganizationalUnit, + Country: metadata.Country, + Province: metadata.Province, + Locality: metadata.Locality, + StreetAddress: metadata.StreetAddress, + PostalCode: metadata.PostalCode, + }, + DNSNames: metadata.DNSNames, + } + + for _, ip := range metadata.IPAddresses { + parsedIP := net.ParseIP(ip) + if parsedIP != nil { + template.IPAddresses = append(template.IPAddresses, parsedIP) + } + } + + csrBytes, err := x509.CreateCertificateRequest(rand.Reader, template, privKey) + if err != nil { + return CSR{}, errors.Wrap(ErrCreateEntity, err) + } + + csrPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: csrBytes, + }) + + privKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privKey), + }) + + csr := CSR{ + CSR: csrPEM, + PrivateKey: privKeyPEM, + EntityID: entityID, + Status: "pending", + SubmittedAt: time.Now(), + } + + if err := s.csrRepo.CreateCSR(ctx, csr); err != nil { + return CSR{}, errors.Wrap(ErrCreateEntity, err) + } + + return csr, nil +} + +func (s *service) ProcessCSR(ctx context.Context, csrID string, approve bool) error { + csr, err := s.csrRepo.RetrieveCSR(ctx, csrID) + if err != nil { + return errors.Wrap(ErrViewEntity, err) + } + + if !approve { + csr.Status = "rejected" + csr.ProcessedAt = time.Now() + return s.csrRepo.UpdateCSR(ctx, csr) + } + + block, _ := pem.Decode(csr.CSR) + if block == nil { + return errors.New("failed to parse CSR PEM") + } + + parsedCSR, err := x509.ParseCertificateRequest(block.Bytes) + if err != nil { + return errors.Wrap(ErrMalformedEntity, err) + } + + if err := parsedCSR.CheckSignature(); err != nil { + return errors.Wrap(ErrMalformedEntity, err) + } + + privKey, err := extractPrivateKey(csr.PrivateKey) + if err != nil { + return errors.Wrap(ErrMalformedEntity, err) + } + + cert, err := s.IssueCert(ctx, csr.EntityID, "", nil, SubjectOptions{ + CommonName: parsedCSR.Subject.CommonName, + Organization: parsedCSR.Subject.Organization, + OrganizationalUnit: parsedCSR.Subject.OrganizationalUnit, + Country: parsedCSR.Subject.Country, + Province: parsedCSR.Subject.Province, + Locality: parsedCSR.Subject.Locality, + StreetAddress: parsedCSR.Subject.StreetAddress, + PostalCode: parsedCSR.Subject.PostalCode, + }, privKey) + if err != nil { + return errors.Wrap(ErrCreateEntity, err) + } + + csr.Status = "approved" + csr.ProcessedAt = time.Now() + csr.SerialNumber = cert.SerialNumber + + return s.csrRepo.UpdateCSR(ctx, csr) +} + func (s *service) getConcatCAs(ctx context.Context) (Certificate, error) { intermediateCert, err := s.repo.RetrieveCert(ctx, s.intermediateCA.SerialNumber) if err != nil { @@ -791,3 +907,17 @@ func (s *service) loadCACerts(ctx context.Context) error { } return nil } + +func extractPrivateKey(pemKey []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(pemKey) + if block == nil { + return nil, errors.New("failed to parse private key PEM") + } + + privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + + return privKey, nil +} From 9bb2204ab85898b9ba74d32c81c4dd9541c2b7ba Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Thu, 21 Nov 2024 19:55:14 +0300 Subject: [PATCH 03/10] Add endpoints Signed-off-by: nyagamunene --- api/http/endpoint.go | 73 +++++++++++ api/http/errors.go | 3 - api/http/requests.go | 52 +++++++- api/http/responses.go | 70 +++++++++++ api/http/transport.go | 26 ++++ api/logging.go | 53 +++++++- api/metrics.go | 37 +++++- certs.go | 86 ++++++++++++- cmd/certs/main.go | 8 +- mocks/service.go | 281 +++++++++++++++++++++++++++++++++++++++--- postgres/csr/csr.go | 6 +- postgres/csr/init.go | 4 +- service.go | 113 ++++++----------- tracing/certs.go | 29 ++++- 14 files changed, 729 insertions(+), 112 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index 2f43f13..409afb2 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -309,3 +309,76 @@ func viewCAEndpoint(svc certs.Service) endpoint.Endpoint { }, nil } } + +func createCSREndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(createCSRReq) + if err := req.validate(); err != nil { + return createCSRRes{created: false}, err + } + + csr, err := svc.CreateCSR(ctx, req.metadata, req.entityID, req.privKey) + if err != nil { + return createCSRRes{created: false}, err + } + + return createCSRRes{ + created: true, + CSR: csr, + }, nil + } +} + +func processCSREndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(processCSRReq) + if err := req.validate(); err != nil { + return processCSRRes{processed: false}, err + } + + err = svc.ProcessCSR(ctx, req.csrID, req.approve) + if err != nil { + return processCSRRes{processed: false}, err + } + + return processCSRRes{ + processed: true, + }, nil + } +} + +func listCSRsEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(listCSRsReq) + if err := req.validate(); err != nil { + return listCSRsRes{}, err + } + + cp, err := svc.ListCSRs(ctx, req.entityID, req.status) + if err != nil { + return listCSRsRes{}, err + } + + return listCSRsRes{ + cp, + }, nil + } +} + +func retrieveCSREndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(retrieveCSRReq) + if err := req.validate(); err != nil { + return retrieveCSRRes{}, err + } + + csr, err := svc.RetrieveCSR(ctx, req.csrID) + if err != nil { + return retrieveCSRRes{}, err + } + + return retrieveCSRRes{ + CSR: csr, + }, nil + } +} diff --git a/api/http/errors.go b/api/http/errors.go index bdcff3d..e4123dc 100644 --- a/api/http/errors.go +++ b/api/http/errors.go @@ -32,7 +32,4 @@ var ( // ErrMissingCN indicates missing common name. ErrMissingCN = errors.New("missing common name") - - // ErrEmptyEntityID indicates that the entity id is empty. - ErrEmptyEntityID = errors.New("missing entity id") ) diff --git a/api/http/requests.go b/api/http/requests.go index 4031b30..b59e106 100644 --- a/api/http/requests.go +++ b/api/http/requests.go @@ -4,6 +4,8 @@ package http import ( + "crypto/rsa" + "github.com/absmach/certs" "github.com/absmach/certs/errors" "golang.org/x/crypto/ocsp" @@ -38,7 +40,7 @@ type deleteReq struct { func (req deleteReq) validate() error { if req.entityID == "" { - return errors.Wrap(certs.ErrMalformedEntity, ErrEmptyEntityID) + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) } return nil } @@ -87,3 +89,51 @@ func (req ocspReq) validate() error { } return nil } + +type createCSRReq struct { + metadata certs.CSRMetadata + entityID string + privKey *rsa.PrivateKey +} + +func (req createCSRReq) validate() error { + if req.entityID == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) + } + return nil +} + +type processCSRReq struct { + csrID string + approve bool +} + +func (req processCSRReq) validate() error { + if req.csrID == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) + } + return nil +} + +type listCSRsReq struct { + entityID string + status string +} + +func (req listCSRsReq) validate() error { + if req.entityID == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) + } + return nil +} + +type retrieveCSRReq struct { + csrID string +} + +func (req retrieveCSRReq) validate() error { + if req.csrID == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) + } + return nil +} diff --git a/api/http/responses.go b/api/http/responses.go index f53ce6c..1f708a4 100644 --- a/api/http/responses.go +++ b/api/http/responses.go @@ -9,6 +9,7 @@ import ( "net/http" "time" + "github.com/absmach/certs" "golang.org/x/crypto/ocsp" ) @@ -201,3 +202,72 @@ type fileDownloadRes struct { Filename string ContentType string } + +type createCSRRes struct { + certs.CSR + created bool +} + +func (res createCSRRes) Code() int { + if res.created { + return http.StatusCreated + } + + return http.StatusOK +} + +func (res createCSRRes) Headers() map[string]string { + return map[string]string{} +} + +func (res createCSRRes) Empty() bool { + return false +} + +type processCSRRes struct { + processed bool +} + +func (res processCSRRes) Code() int { + return http.StatusOK +} + +func (res processCSRRes) Headers() map[string]string { + return map[string]string{} +} + +func (res processCSRRes) Empty() bool { + return true +} + +type listCSRsRes struct { + certs.CSRPage +} + +func (res listCSRsRes) Code() int { + return http.StatusOK +} + +func (res listCSRsRes) Headers() map[string]string { + return map[string]string{} +} + +func (res listCSRsRes) Empty() bool { + return false +} + +type retrieveCSRRes struct { + certs.CSR +} + +func (res retrieveCSRRes) Code() int { + return http.StatusOK +} + +func (res retrieveCSRRes) Headers() map[string]string { + return map[string]string{} +} + +func (res retrieveCSRRes) Empty() bool { + return false +} diff --git a/api/http/transport.go b/api/http/transport.go index 1d0df55..79719cc 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -137,6 +137,32 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http encodeCADownloadResponse, opts..., ), "download_ca").ServeHTTP) + r.Route("/csr", func(r chi.Router) { + r.Post("/", otelhttp.NewHandler(kithttp.NewServer( + createCSREndpoint(svc), + decodeDownloadCA, + encodeCADownloadResponse, + opts..., + ), "").ServeHTTP) + r.Patch("/", otelhttp.NewHandler(kithttp.NewServer( + processCSREndpoint(svc), + decodeDownloadCA, + encodeCADownloadResponse, + opts..., + ), "").ServeHTTP) + r.Get("/retrieve/{id}", otelhttp.NewHandler(kithttp.NewServer( + retrieveCSREndpoint(svc), + decodeDownloadCA, + encodeCADownloadResponse, + opts..., + ), "").ServeHTTP) + r.Get("/list", otelhttp.NewHandler(kithttp.NewServer( + listCSRsEndpoint(svc), + decodeDownloadCA, + encodeCADownloadResponse, + opts..., + ), "").ServeHTTP) + }) }) r.Get("/health", certs.Health("certs", instanceID)) diff --git a/api/logging.go b/api/logging.go index 731784d..05852e1 100644 --- a/api/logging.go +++ b/api/logging.go @@ -5,6 +5,7 @@ package api import ( "context" + "crypto/rsa" "crypto/x509" "fmt" "log/slog" @@ -85,7 +86,7 @@ func (lm *loggingMiddleware) RetrieveCAToken(ctx context.Context) (tokenString s return lm.svc.RetrieveCAToken(ctx) } -func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (cert certs.Certificate, err error) { +func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...*rsa.PrivateKey) (cert certs.Certificate, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method issue_cert for took %s to complete", time.Since(begin)) if err != nil { @@ -94,7 +95,7 @@ func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string } lm.logger.Info(message) }(time.Now()) - return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options) + return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey...) } func (lm *loggingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (cp certs.CertificatePage, err error) { @@ -180,3 +181,51 @@ func (lm *loggingMiddleware) GetChainCA(ctx context.Context, token string) (cert }(time.Now()) return lm.svc.GetChainCA(ctx, token) } + +func (lm *loggingMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetadata, entityID string, key ...*rsa.PrivateKey) (csr certs.CSR, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method create_csr took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.CreateCSR(ctx, meta, entityID, key...) +} + +func (lm *loggingMiddleware) ProcessCSR(ctx context.Context, csrID string, approve bool) (err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method process_csr took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.ProcessCSR(ctx, csrID, approve) +} + +func (lm *loggingMiddleware) ListCSRs(ctx context.Context, entityID string, status string) (cp certs.CSRPage, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method list_csrs took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.ListCSRs(ctx, entityID, status) +} + +func (lm *loggingMiddleware) RetrieveCSR(ctx context.Context, csrID string) (csr certs.CSR, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method retrieve_csr took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.RetrieveCSR(ctx, csrID) +} diff --git a/api/metrics.go b/api/metrics.go index a5506c3..1aa5e7b 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -5,6 +5,7 @@ package api import ( "context" + "crypto/rsa" "crypto/x509" "time" @@ -71,12 +72,12 @@ func (mm *metricsMiddleware) RetrieveCAToken(ctx context.Context) (string, error return mm.svc.RetrieveCAToken(ctx) } -func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { +func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...*rsa.PrivateKey) (certs.Certificate, error) { defer func(begin time.Time) { mm.counter.With("method", "issue_certificate").Add(1) mm.latency.With("method", "issue_certificate").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options) + return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey...) } func (mm *metricsMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) { @@ -135,3 +136,35 @@ func (mm *metricsMiddleware) GetChainCA(ctx context.Context, token string) (cert }(time.Now()) return mm.svc.GetChainCA(ctx, token) } + +func (mm *metricsMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetadata, entityID string, key ...*rsa.PrivateKey) (certs.CSR, error) { + defer func(begin time.Time) { + mm.counter.With("method", "create_csr").Add(1) + mm.latency.With("method", "create_csr").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.CreateCSR(ctx, meta, entityID, key...) +} + +func (mm *metricsMiddleware) ProcessCSR(ctx context.Context, csrID string, approve bool) error { + defer func(begin time.Time) { + mm.counter.With("method", "process_csr").Add(1) + mm.latency.With("method", "process_csr").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.ProcessCSR(ctx, csrID, approve) +} + +func (mm *metricsMiddleware) ListCSRs(ctx context.Context, entityID string, status string) (certs.CSRPage, error) { + defer func(begin time.Time) { + mm.counter.With("method", "list_csrs").Add(1) + mm.latency.With("method", "list_csrs").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.ListCSRs(ctx, entityID, status) +} + +func (mm *metricsMiddleware) RetrieveCSR(ctx context.Context, csrID string) (certs.CSR, error) { + defer func(begin time.Time) { + mm.counter.With("method", "retrieve_csr").Add(1) + mm.latency.With("method", "retrieve_csr").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.RetrieveCSR(ctx, csrID) +} diff --git a/certs.go b/certs.go index c1c221f..6ae4906 100644 --- a/certs.go +++ b/certs.go @@ -5,10 +5,62 @@ package certs import ( "context" + "crypto/rsa" "crypto/x509" + "net" "time" + + "github.com/absmach/certs/errors" +) + +type CertType int + +const ( + RootCA CertType = iota + IntermediateCA + ClientCert +) + +const ( + Root = "RootCA" + Inter = "IntermediateCA" + Client = "ClientCert" + Unknown = "Unknown" ) +func (c CertType) String() string { + switch c { + case RootCA: + return Root + case IntermediateCA: + return Inter + case ClientCert: + return Client + default: + return Unknown + } +} + +func CertTypeFromString(s string) (CertType, error) { + switch s { + case Root: + return RootCA, nil + case Inter: + return IntermediateCA, nil + case Client: + return ClientCert, nil + default: + return -1, errors.New("unknown cert type") + } +} + +type CA struct { + Type CertType + Certificate *x509.Certificate + PrivateKey *rsa.PrivateKey + SerialNumber string +} + type Certificate struct { SerialNumber string `db:"serial_number"` Certificate []byte `db:"certificate"` @@ -30,6 +82,7 @@ type PageMetadata struct { Offset uint64 `json:"offset,omitempty" db:"offset"` Limit uint64 `json:"limit,omitempty" db:"limit"` EntityID string `json:"entity_id,omitempty" db:"entity_id"` + Status string `json:"status,omitempty" db:"status"` } type CSRMetadata struct { @@ -44,9 +97,11 @@ type CSRMetadata struct { EmailAddress string `json:"email_address"` DNSNames []string `json:"dns_names"` IPAddresses []string `json:"ip_addresses"` + EmailAddresses []string `json:"email_addresses"` } type CSR struct { + ID string `json:"id" db:"id"` CSR []byte `json:"csr" db:"csr"` PrivateKey []byte `json:"private_key" db:"private_key"` EntityID string `json:"entity_id" db:"entity_id"` @@ -61,6 +116,31 @@ type CSRPage struct { CSRs []CSR } +type SubjectOptions struct { + CommonName string + Organization []string `json:"organization"` + OrganizationalUnit []string `json:"organizational_unit"` + Country []string `json:"country"` + Province []string `json:"province"` + Locality []string `json:"locality"` + StreetAddress []string `json:"street_address"` + PostalCode []string `json:"postal_code"` +} + +type Config struct { + CommonName string `yaml:"common_name"` + Organization []string `yaml:"organization"` + OrganizationalUnit []string `yaml:"organizational_unit"` + Country []string `yaml:"country"` + Province []string `yaml:"province"` + Locality []string `yaml:"locality"` + StreetAddress []string `yaml:"street_address"` + PostalCode []string `yaml:"postal_code"` + DNSNames []string `yaml:"dns_names"` + IPAddresses []net.IP `yaml:"ip_addresses"` + ValidityPeriod string `yaml:"validity_period"` +} + type Service interface { // RenewCert renews a certificate from the database. RenewCert(ctx context.Context, serialNumber string) error @@ -86,7 +166,7 @@ type Service interface { RetrieveCAToken(ctx context.Context) (string, error) // IssueCert issues a certificate from the database. - IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions) (Certificate, error) + IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions, privKey ...*rsa.PrivateKey) (Certificate, error) // OCSP retrieves the OCSP response for a certificate. OCSP(ctx context.Context, serialNumber string) (*Certificate, int, *x509.Certificate, error) @@ -104,13 +184,13 @@ type Service interface { RemoveCert(ctx context.Context, entityId string) error // CreateCSR creates a new Certificate Signing Request - CreateCSR(ctx context.Context, metadata CSRMetadata, entityID string) (CSR, error) + CreateCSR(ctx context.Context, metadata CSRMetadata, entityID string, privKey ...*rsa.PrivateKey) (CSR, error) // ProcessCSR processes a pending CSR and either approves or rejects it ProcessCSR(ctx context.Context, csrID string, approve bool) error // ListCSRs returns a list of CSRs based on filter criteria - ListCSRs(ctx context.Context, entityID string, status string) ([]CSR, error) + ListCSRs(ctx context.Context, entityID string, status string) (CSRPage, error) // RetrieveCSR retrieves a specific CSR by ID RetrieveCSR(ctx context.Context, csrID string) (CSR, error) diff --git a/cmd/certs/main.go b/cmd/certs/main.go index 6047ac1..c174da5 100644 --- a/cmd/certs/main.go +++ b/cmd/certs/main.go @@ -79,7 +79,10 @@ func main() { if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefix}); err != nil { logger.Error(err.Error()) } - db, err := pgClient.Setup(dbConfig, *cpostgres.Migration()) + cm := cpostgres.Migration() + sm := csrpostgres.Migration() + cm.Migrations = append(cm.Migrations, sm.Migrations...) + db, err := pgClient.Setup(dbConfig, *cm) if err != nil { log.Fatalf(fmt.Sprintf("Failed to connect to %s database: %s", svcName, err)) } @@ -148,7 +151,8 @@ func newService(ctx context.Context, db *sqlx.DB, tracer trace.Tracer, logger *s database := postgres.NewDatabase(db, dbConfig, tracer) repo := cpostgres.NewRepository(database) csrRepo := csrpostgres.NewRepository(database) - svc, err := certs.NewService(ctx, repo, csrRepo, config) + idp := uuid.New() + svc, err := certs.NewService(ctx, repo, csrRepo, config, idp) if err != nil { return nil, err } diff --git a/mocks/service.go b/mocks/service.go index aa91db8..deb6ff8 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -12,6 +12,8 @@ import ( mock "github.com/stretchr/testify/mock" + rsa "crypto/rsa" + x509 "crypto/x509" ) @@ -28,6 +30,79 @@ func (_m *MockService) EXPECT() *MockService_Expecter { return &MockService_Expecter{mock: &_m.Mock} } +// CreateCSR provides a mock function with given fields: ctx, metadata, entityID, privKey +func (_m *MockService) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, entityID string, privKey ...*rsa.PrivateKey) (certs.CSR, error) { + _va := make([]interface{}, len(privKey)) + for _i := range privKey { + _va[_i] = privKey[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, metadata, entityID) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for CreateCSR") + } + + var r0 certs.CSR + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, string, ...*rsa.PrivateKey) (certs.CSR, error)); ok { + return rf(ctx, metadata, entityID, privKey...) + } + if rf, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, string, ...*rsa.PrivateKey) certs.CSR); ok { + r0 = rf(ctx, metadata, entityID, privKey...) + } else { + r0 = ret.Get(0).(certs.CSR) + } + + if rf, ok := ret.Get(1).(func(context.Context, certs.CSRMetadata, string, ...*rsa.PrivateKey) error); ok { + r1 = rf(ctx, metadata, entityID, privKey...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockService_CreateCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCSR' +type MockService_CreateCSR_Call struct { + *mock.Call +} + +// CreateCSR is a helper method to define mock.On call +// - ctx context.Context +// - metadata certs.CSRMetadata +// - entityID string +// - privKey ...*rsa.PrivateKey +func (_e *MockService_Expecter) CreateCSR(ctx interface{}, metadata interface{}, entityID interface{}, privKey ...interface{}) *MockService_CreateCSR_Call { + return &MockService_CreateCSR_Call{Call: _e.mock.On("CreateCSR", + append([]interface{}{ctx, metadata, entityID}, privKey...)...)} +} + +func (_c *MockService_CreateCSR_Call) Run(run func(ctx context.Context, metadata certs.CSRMetadata, entityID string, privKey ...*rsa.PrivateKey)) *MockService_CreateCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]*rsa.PrivateKey, len(args)-3) + for i, a := range args[3:] { + if a != nil { + variadicArgs[i] = a.(*rsa.PrivateKey) + } + } + run(args[0].(context.Context), args[1].(certs.CSRMetadata), args[2].(string), variadicArgs...) + }) + return _c +} + +func (_c *MockService_CreateCSR_Call) Return(_a0 certs.CSR, _a1 error) *MockService_CreateCSR_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockService_CreateCSR_Call) RunAndReturn(run func(context.Context, certs.CSRMetadata, string, ...*rsa.PrivateKey) (certs.CSR, error)) *MockService_CreateCSR_Call { + _c.Call.Return(run) + return _c +} + // GenerateCRL provides a mock function with given fields: ctx, caType func (_m *MockService) GenerateCRL(ctx context.Context, caType certs.CertType) ([]byte, error) { ret := _m.Called(ctx, caType) @@ -201,9 +276,16 @@ func (_c *MockService_GetEntityID_Call) RunAndReturn(run func(context.Context, s return _c } -// IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs, option -func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions) (certs.Certificate, error) { - ret := _m.Called(ctx, entityID, ttl, ipAddrs, option) +// IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs, option, privKey +func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...*rsa.PrivateKey) (certs.Certificate, error) { + _va := make([]interface{}, len(privKey)) + for _i := range privKey { + _va[_i] = privKey[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, entityID, ttl, ipAddrs, option) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) if len(ret) == 0 { panic("no return value specified for IssueCert") @@ -211,17 +293,17 @@ func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl strin var r0 certs.Certificate var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions) (certs.Certificate, error)); ok { - return rf(ctx, entityID, ttl, ipAddrs, option) + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) (certs.Certificate, error)); ok { + return rf(ctx, entityID, ttl, ipAddrs, option, privKey...) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions) certs.Certificate); ok { - r0 = rf(ctx, entityID, ttl, ipAddrs, option) + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) certs.Certificate); ok { + r0 = rf(ctx, entityID, ttl, ipAddrs, option, privKey...) } else { r0 = ret.Get(0).(certs.Certificate) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions) error); ok { - r1 = rf(ctx, entityID, ttl, ipAddrs, option) + if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) error); ok { + r1 = rf(ctx, entityID, ttl, ipAddrs, option, privKey...) } else { r1 = ret.Error(1) } @@ -240,13 +322,21 @@ type MockService_IssueCert_Call struct { // - ttl string // - ipAddrs []string // - option certs.SubjectOptions -func (_e *MockService_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, option interface{}) *MockService_IssueCert_Call { - return &MockService_IssueCert_Call{Call: _e.mock.On("IssueCert", ctx, entityID, ttl, ipAddrs, option)} +// - privKey ...*rsa.PrivateKey +func (_e *MockService_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, option interface{}, privKey ...interface{}) *MockService_IssueCert_Call { + return &MockService_IssueCert_Call{Call: _e.mock.On("IssueCert", + append([]interface{}{ctx, entityID, ttl, ipAddrs, option}, privKey...)...)} } -func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...*rsa.PrivateKey)) *MockService_IssueCert_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]string), args[4].(certs.SubjectOptions)) + variadicArgs := make([]*rsa.PrivateKey, len(args)-5) + for i, a := range args[5:] { + if a != nil { + variadicArgs[i] = a.(*rsa.PrivateKey) + } + } + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]string), args[4].(certs.SubjectOptions), variadicArgs...) }) return _c } @@ -256,7 +346,65 @@ func (_c *MockService_IssueCert_Call) Return(_a0 certs.Certificate, _a1 error) * return _c } -func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions) (certs.Certificate, error)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) (certs.Certificate, error)) *MockService_IssueCert_Call { + _c.Call.Return(run) + return _c +} + +// ListCSRs provides a mock function with given fields: ctx, entityID, status +func (_m *MockService) ListCSRs(ctx context.Context, entityID string, status string) (certs.CSRPage, error) { + ret := _m.Called(ctx, entityID, status) + + if len(ret) == 0 { + panic("no return value specified for ListCSRs") + } + + var r0 certs.CSRPage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (certs.CSRPage, error)); ok { + return rf(ctx, entityID, status) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) certs.CSRPage); ok { + r0 = rf(ctx, entityID, status) + } else { + r0 = ret.Get(0).(certs.CSRPage) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, entityID, status) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockService_ListCSRs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCSRs' +type MockService_ListCSRs_Call struct { + *mock.Call +} + +// ListCSRs is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - status string +func (_e *MockService_Expecter) ListCSRs(ctx interface{}, entityID interface{}, status interface{}) *MockService_ListCSRs_Call { + return &MockService_ListCSRs_Call{Call: _e.mock.On("ListCSRs", ctx, entityID, status)} +} + +func (_c *MockService_ListCSRs_Call) Run(run func(ctx context.Context, entityID string, status string)) *MockService_ListCSRs_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockService_ListCSRs_Call) Return(_a0 certs.CSRPage, _a1 error) *MockService_ListCSRs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockService_ListCSRs_Call) RunAndReturn(run func(context.Context, string, string) (certs.CSRPage, error)) *MockService_ListCSRs_Call { _c.Call.Return(run) return _c } @@ -393,6 +541,54 @@ func (_c *MockService_OCSP_Call) RunAndReturn(run func(context.Context, string) return _c } +// ProcessCSR provides a mock function with given fields: ctx, csrID, approve +func (_m *MockService) ProcessCSR(ctx context.Context, csrID string, approve bool) error { + ret := _m.Called(ctx, csrID, approve) + + if len(ret) == 0 { + panic("no return value specified for ProcessCSR") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, bool) error); ok { + r0 = rf(ctx, csrID, approve) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockService_ProcessCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ProcessCSR' +type MockService_ProcessCSR_Call struct { + *mock.Call +} + +// ProcessCSR is a helper method to define mock.On call +// - ctx context.Context +// - csrID string +// - approve bool +func (_e *MockService_Expecter) ProcessCSR(ctx interface{}, csrID interface{}, approve interface{}) *MockService_ProcessCSR_Call { + return &MockService_ProcessCSR_Call{Call: _e.mock.On("ProcessCSR", ctx, csrID, approve)} +} + +func (_c *MockService_ProcessCSR_Call) Run(run func(ctx context.Context, csrID string, approve bool)) *MockService_ProcessCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(bool)) + }) + return _c +} + +func (_c *MockService_ProcessCSR_Call) Return(_a0 error) *MockService_ProcessCSR_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockService_ProcessCSR_Call) RunAndReturn(run func(context.Context, string, bool) error) *MockService_ProcessCSR_Call { + _c.Call.Return(run) + return _c +} + // RemoveCert provides a mock function with given fields: ctx, entityId func (_m *MockService) RemoveCert(ctx context.Context, entityId string) error { ret := _m.Called(ctx, entityId) @@ -543,6 +739,63 @@ func (_c *MockService_RetrieveCAToken_Call) RunAndReturn(run func(context.Contex return _c } +// RetrieveCSR provides a mock function with given fields: ctx, csrID +func (_m *MockService) RetrieveCSR(ctx context.Context, csrID string) (certs.CSR, error) { + ret := _m.Called(ctx, csrID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveCSR") + } + + var r0 certs.CSR + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (certs.CSR, error)); ok { + return rf(ctx, csrID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) certs.CSR); ok { + r0 = rf(ctx, csrID) + } else { + r0 = ret.Get(0).(certs.CSR) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, csrID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockService_RetrieveCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveCSR' +type MockService_RetrieveCSR_Call struct { + *mock.Call +} + +// RetrieveCSR is a helper method to define mock.On call +// - ctx context.Context +// - csrID string +func (_e *MockService_Expecter) RetrieveCSR(ctx interface{}, csrID interface{}) *MockService_RetrieveCSR_Call { + return &MockService_RetrieveCSR_Call{Call: _e.mock.On("RetrieveCSR", ctx, csrID)} +} + +func (_c *MockService_RetrieveCSR_Call) Run(run func(ctx context.Context, csrID string)) *MockService_RetrieveCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockService_RetrieveCSR_Call) Return(_a0 certs.CSR, _a1 error) *MockService_RetrieveCSR_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockService_RetrieveCSR_Call) RunAndReturn(run func(context.Context, string) (certs.CSR, error)) *MockService_RetrieveCSR_Call { + _c.Call.Return(run) + return _c +} + // RetrieveCert provides a mock function with given fields: ctx, token, serialNumber func (_m *MockService) RetrieveCert(ctx context.Context, token string, serialNumber string) (certs.Certificate, []byte, error) { ret := _m.Called(ctx, token, serialNumber) diff --git a/postgres/csr/csr.go b/postgres/csr/csr.go index f4d52dc..5f92c8e 100644 --- a/postgres/csr/csr.go +++ b/postgres/csr/csr.go @@ -43,8 +43,8 @@ func NewRepository(db postgres.Database) certs.CSRRepository { func (repo CSRRepo) CreateCSR(ctx context.Context, cert certs.CSR) error { q := ` - INSERT INTO certs (serial_number, csr, private_key, entity_id, status, submitted_at, processed_at) - VALUES (:serial_number, :csr, :private_key, :entity_id, :status, :submitted_at, :processed_at)` + INSERT INTO certs (id, serial_number, csr, private_key, entity_id, status, submitted_at, processed_at) + VALUES (:id, :serial_number, :csr, :private_key, :entity_id, :status, :submitted_at, :processed_at)` _, err := repo.db.NamedExecContext(ctx, q, cert) if err != nil { return handleError(certs.ErrCreateEntity, err) @@ -69,7 +69,7 @@ func (repo CSRRepo) UpdateCSR(ctx context.Context, cert certs.CSR) error { } func (repo CSRRepo) RetrieveCSR(ctx context.Context,id string) (certs.CSR, error) { - q := `SELECT serial_number, certificate, key, entity_id, revoked, expiry_time FROM certs WHERE serial_number = $1` + q := `SELECT serial_number, certificate, key, entity_id, revoked, expiry_time FROM certs WHERE id = $1` var csr certs.CSR if err := repo.db.QueryRowxContext(ctx, q, id).StructScan(&csr); err != nil { if err == sql.ErrNoRows { diff --git a/postgres/csr/init.go b/postgres/csr/init.go index ab5f113..0f1534b 100644 --- a/postgres/csr/init.go +++ b/postgres/csr/init.go @@ -15,14 +15,14 @@ func Migration() *migrate.MemoryMigrationSource { Id: "csr_1", Up: []string{ `CREATE TABLE IF NOT EXISTS csr ( + id VARCHAR(36) PRIMARY KEY, serial_number VARCHAR(40), csr TEXT, private_key TEXT, entity_id VARCHAR(36), status BOOLEAN, submitted_at TIMESTAMP, - processed_at TIMESTAMP, - PRIMARY KEY (entity_id) + processed_at TIMESTAMP )`, }, Down: []string{ diff --git a/service.go b/service.go index 91b81c2..bd9f25e 100644 --- a/service.go +++ b/service.go @@ -16,6 +16,7 @@ import ( "time" "github.com/absmach/certs/errors" + "github.com/absmach/certs/internal/uuid" "github.com/golang-jwt/jwt" "golang.org/x/crypto/ocsp" ) @@ -32,68 +33,6 @@ const ( downloadTokenExpiry = time.Minute * 5 ) -type CertType int - -const ( - RootCA CertType = iota - IntermediateCA - ClientCert -) - -const ( - Root = "RootCA" - Inter = "IntermediateCA" - Client = "ClientCert" - Unknown = "Unknown" -) - -func (c CertType) String() string { - switch c { - case RootCA: - return Root - case IntermediateCA: - return Inter - case ClientCert: - return Client - default: - return Unknown - } -} - -func CertTypeFromString(s string) (CertType, error) { - switch s { - case Root: - return RootCA, nil - case Inter: - return IntermediateCA, nil - case Client: - return ClientCert, nil - default: - return -1, errors.New("unknown cert type") - } -} - -type CA struct { - Type CertType - Certificate *x509.Certificate - PrivateKey *rsa.PrivateKey - SerialNumber string -} - -type Config struct { - CommonName string `yaml:"common_name"` - Organization []string `yaml:"organization"` - OrganizationalUnit []string `yaml:"organizational_unit"` - Country []string `yaml:"country"` - Province []string `yaml:"province"` - Locality []string `yaml:"locality"` - StreetAddress []string `yaml:"street_address"` - PostalCode []string `yaml:"postal_code"` - DNSNames []string `yaml:"dns_names"` - IPAddresses []net.IP `yaml:"ip_addresses"` - ValidityPeriod string `yaml:"validity_period"` -} - var ( serialNumberLimit = new(big.Int).Lsh(big.NewInt(1), 128) ErrNotFound = errors.New("entity not found") @@ -111,31 +50,22 @@ var ( ErrInvalidLength = errors.New("invalid length of serial numbers") ) -type SubjectOptions struct { - CommonName string - Organization []string `json:"organization"` - OrganizationalUnit []string `json:"organizational_unit"` - Country []string `json:"country"` - Province []string `json:"province"` - Locality []string `json:"locality"` - StreetAddress []string `json:"street_address"` - PostalCode []string `json:"postal_code"` -} - type service struct { repo Repository csrRepo CSRRepository rootCA *CA intermediateCA *CA + idProvider uuid.IDProvider } var _ Service = (*service)(nil) -func NewService(ctx context.Context, repo Repository, csrRepo CSRRepository, config *Config) (Service, error) { +func NewService(ctx context.Context, repo Repository, csrRepo CSRRepository, config *Config, idp uuid.IDProvider) (Service, error) { var svc service svc.repo = repo svc.csrRepo = csrRepo + svc.idProvider = idp if err := svc.loadCACerts(ctx); err != nil { return &svc, err } @@ -476,10 +406,23 @@ func (s *service) GetChainCA(ctx context.Context, token string) (Certificate, er return s.getConcatCAs(ctx) } -func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, entityID string) (CSR, error) { - privKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) +func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, entityID string, privateKey ...*rsa.PrivateKey) (CSR, error) { + var privKey *rsa.PrivateKey + var err error + + // Check if a private key is provided else generate a new private key. + if len(privateKey) > 0 && privateKey[0] != nil { + privKey = privateKey[0] + } else { + privKey, err = rsa.GenerateKey(rand.Reader, PrivateKeyBytes) + if err != nil { + return CSR{}, errors.Wrap(ErrCreateEntity, err) + } + } + + csrID, err := s.idProvider.ID() if err != nil { - return CSR{}, errors.Wrap(ErrCreateEntity, err) + return CSR{}, err } template := &x509.CertificateRequest{ @@ -493,7 +436,8 @@ func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, entityID StreetAddress: metadata.StreetAddress, PostalCode: metadata.PostalCode, }, - DNSNames: metadata.DNSNames, + EmailAddresses: metadata.EmailAddresses, + DNSNames: metadata.DNSNames, } for _, ip := range metadata.IPAddresses { @@ -519,6 +463,7 @@ func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, entityID }) csr := CSR{ + ID: csrID, CSR: csrPEM, PrivateKey: privKeyPEM, EntityID: entityID, @@ -585,6 +530,18 @@ func (s *service) ProcessCSR(ctx context.Context, csrID string, approve bool) er return s.csrRepo.UpdateCSR(ctx, csr) } +func (s *service) ListCSRs(ctx context.Context, entityID string, status string) (CSRPage, error) { + pm := PageMetadata{ + EntityID: entityID, + Status: status, + } + return s.csrRepo.ListCSRs(ctx, pm) +} + +func (s *service) RetrieveCSR(ctx context.Context, csrID string) (CSR, error) { + return s.csrRepo.RetrieveCSR(ctx, csrID) +} + func (s *service) getConcatCAs(ctx context.Context) (Certificate, error) { intermediateCert, err := s.repo.RetrieveCert(ctx, s.intermediateCA.SerialNumber) if err != nil { diff --git a/tracing/certs.go b/tracing/certs.go index efac814..9a10765 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -5,6 +5,7 @@ package tracing import ( "context" + "crypto/rsa" "crypto/x509" "github.com/absmach/certs" @@ -53,10 +54,10 @@ func (tm *tracingMiddleware) RetrieveCAToken(ctx context.Context) (string, error return tm.svc.RetrieveCAToken(ctx) } -func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { +func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...*rsa.PrivateKey) (certs.Certificate, error) { ctx, span := tm.tracer.Start(ctx, "issue_cert") defer span.End() - return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options) + return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey...) } func (tm *tracingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) { @@ -100,3 +101,27 @@ func (tm *tracingMiddleware) GetChainCA(ctx context.Context, token string) (cert defer span.End() return tm.svc.GetChainCA(ctx, token) } + +func (tm *tracingMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetadata, entityID string, key ...*rsa.PrivateKey) (certs.CSR, error) { + ctx, span := tm.tracer.Start(ctx, "create_csr") + defer span.End() + return tm.svc.CreateCSR(ctx, meta, entityID, key...) +} + +func (tm *tracingMiddleware) ProcessCSR(ctx context.Context, csrID string, approve bool) error { + ctx, span := tm.tracer.Start(ctx, "process_csr") + defer span.End() + return tm.svc.ProcessCSR(ctx, csrID, approve) +} + +func (tm *tracingMiddleware) ListCSRs(ctx context.Context, entityID string, status string) (certs.CSRPage, error) { + ctx, span := tm.tracer.Start(ctx, "list_csrs") + defer span.End() + return tm.svc.ListCSRs(ctx, entityID, status) +} + +func (tm *tracingMiddleware) RetrieveCSR(ctx context.Context, csrID string) (certs.CSR, error) { + ctx, span := tm.tracer.Start(ctx, "retrieve_csr") + defer span.End() + return tm.svc.RetrieveCSR(ctx, csrID) +} From 74ef63f0497938b30c4fba21cdc125fe57ef5c79 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 27 Nov 2024 00:31:14 +0300 Subject: [PATCH 04/10] add sdk support Signed-off-by: nyagamunene --- api/http/transport.go | 172 +++++++++++++++++++++++++++++++++++++++--- certs.go | 1 - sdk/sdk.go | 170 +++++++++++++++++++++++++++++++++++++++-- 3 files changed, 326 insertions(+), 17 deletions(-) diff --git a/api/http/transport.go b/api/http/transport.go index 79719cc..ab949f9 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -32,6 +32,18 @@ const ( limitKey = "limit" entityKey = "entity_id" commonName = "common_name" + organization = "organization" + orgUnit = "organization_unit" + country = "country" + province = "province" + locality = "locality" + streetAddress = "street_address" + postalCode = "postal_code" + emailAddresses = "email_addresses" + dnsNames = "dns_names" + ipAddresses = "ip_addresses" + approve = "approve" + status = "status" token = "token" ocspStatusParam = "force_status" entityIDParam = "entityID" @@ -140,26 +152,26 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http r.Route("/csr", func(r chi.Router) { r.Post("/", otelhttp.NewHandler(kithttp.NewServer( createCSREndpoint(svc), - decodeDownloadCA, - encodeCADownloadResponse, + decodeCreateCSR, + EncodeResponse, opts..., ), "").ServeHTTP) - r.Patch("/", otelhttp.NewHandler(kithttp.NewServer( + r.Patch("/{id}", otelhttp.NewHandler(kithttp.NewServer( processCSREndpoint(svc), - decodeDownloadCA, - encodeCADownloadResponse, + decodeUpdateCSR, + EncodeResponse, opts..., ), "").ServeHTTP) - r.Get("/retrieve/{id}", otelhttp.NewHandler(kithttp.NewServer( + r.Get("/{id}", otelhttp.NewHandler(kithttp.NewServer( retrieveCSREndpoint(svc), - decodeDownloadCA, - encodeCADownloadResponse, + decodeRetrieveCSR, + EncodeResponse, opts..., ), "").ServeHTTP) r.Get("/list", otelhttp.NewHandler(kithttp.NewServer( listCSRsEndpoint(svc), - decodeDownloadCA, - encodeCADownloadResponse, + decodeListCSR, + EncodeResponse, opts..., ), "").ServeHTTP) }) @@ -287,6 +299,128 @@ func decodeListCerts(_ context.Context, r *http.Request) (interface{}, error) { return req, nil } +func decodeCreateCSR(_ context.Context, r *http.Request) (interface{}, error) { + o, err := readStringQuery(r, organization, "") + if err != nil { + return nil, err + } + + ou, err := readStringQuery(r, orgUnit, "") + if err != nil { + return nil, err + } + + c, err := readStringQuery(r, country, "") + if err != nil { + return nil, err + } + + p, err := readStringQuery(r, province, "") + if err != nil { + return nil, err + } + + l, err := readStringQuery(r, locality, "") + if err != nil { + return nil, err + } + + s, err := readStringQuery(r, streetAddress, "") + if err != nil { + return nil, err + } + + po, err := readStringQuery(r, postalCode, "") + if err != nil { + return nil, err + } + + entity, err := readStringQuery(r, entityKey, "") + if err != nil { + return nil, err + } + + cn, err := readStringQuery(r, commonName, "") + if err != nil { + return nil, err + } + + e, err := readStringQuery(r, emailAddresses, "") + if err != nil { + return nil, err + } + + d, err := readStringQuery(r, dnsNames, "") + if err != nil { + return nil, err + } + + i, err := readStringQuery(r, ipAddresses, "") + if err != nil { + return nil, err + } + + req := createCSRReq{ + metadata: certs.CSRMetadata{ + CommonName: cn, + Organization: []string{o}, + OrganizationalUnit: []string{ou}, + Country: []string{c}, + Province: []string{p}, + Locality: []string{l}, + StreetAddress: []string{s}, + PostalCode: []string{po}, + DNSNames: []string{d}, + IPAddresses: []string{i}, + EmailAddresses: []string{e}, + }, + entityID: entity, + // privKey: , + } + + return req, nil +} + +func decodeUpdateCSR(_ context.Context, r *http.Request) (interface{}, error) { + app, err := readBoolQuery(r, approve, false) + if err != nil { + return nil, err + } + + req := processCSRReq{ + csrID: chi.URLParam(r, "id"), + approve: app, + } + + return req, nil +} + +func decodeRetrieveCSR(_ context.Context, r *http.Request) (interface{}, error) { + req := retrieveCSRReq{ + csrID: chi.URLParam(r, "id"), + } + + return req, nil +} + +func decodeListCSR(_ context.Context, r *http.Request) (interface{}, error) { + s, err := readStringQuery(r, status, "all") + if err != nil { + return nil, err + } + e, err := readStringQuery(r, entityKey, "") + if err != nil { + return nil, err + } + + req := listCSRsReq{ + entityID: e, + status: s, + } + + return req, nil +} + // EncodeResponse encodes successful response. func EncodeResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { if ar, ok := response.(Response); ok { @@ -456,3 +590,21 @@ func readNumQuery(r *http.Request, key string, def uint64) (uint64, error) { } return v, nil } + +func readBoolQuery(r *http.Request, key string, def bool) (bool, error) { + vals := r.URL.Query()[key] + if len(vals) > 1 { + return false, ErrInvalidQueryParams + } + + if len(vals) == 0 { + return def, nil + } + + b, err := strconv.ParseBool(vals[0]) + if err != nil { + return false, errors.Wrap(ErrInvalidQueryParams, err) + } + + return b, nil +} diff --git a/certs.go b/certs.go index 6ae4906..c14fc76 100644 --- a/certs.go +++ b/certs.go @@ -94,7 +94,6 @@ type CSRMetadata struct { Locality []string `json:"locality"` StreetAddress []string `json:"street_address"` PostalCode []string `json:"postal_code"` - EmailAddress string `json:"email_address"` DNSNames []string `json:"dns_names"` IPAddresses []string `json:"ip_addresses"` EmailAddresses []string `json:"email_addresses"` diff --git a/sdk/sdk.go b/sdk/sdk.go index 9bf3ff3..ee30fb1 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -27,6 +27,7 @@ import ( const ( certsEndpoint = "certs" + csrEndpoint = "csr" issueCertEndpoint = "certs/issue" emptyOCSPbody = 22 ) @@ -75,12 +76,22 @@ func (c CertStatus) MarshalJSON() ([]byte, error) { } type PageMetadata struct { - Total uint64 `json:"total,omitempty"` - Offset uint64 `json:"offset,omitempty"` - Limit uint64 `json:"limit,omitempty"` - EntityID string `json:"entity_id,omitempty"` - Token string `json:"token,omitempty"` - CommonName string `json:"common_name,omitempty"` + Total uint64 `json:"total,omitempty"` + Offset uint64 `json:"offset,omitempty"` + Limit uint64 `json:"limit,omitempty"` + EntityID string `json:"entity_id,omitempty"` + Token string `json:"token,omitempty"` + CommonName string `json:"common_name,omitempty"` + Organization []string `json:"organization,omitempty"` + OrganizationalUnit []string `json:"organizational_unit,omitempty"` + Country []string `json:"country,omitempty"` + Province []string `json:"province,omitempty"` + Locality []string `json:"locality,omitempty"` + StreetAddress []string `json:"street_address,omitempty"` + PostalCode []string `json:"postal_code,omitempty"` + DNSNames []string `json:"dns_names,omitempty"` + IPAddresses []string `json:"ip_addresses,omitempty"` + EmailAddresses []string `json:"email_addresses,omitempty"` } type Options struct { @@ -148,6 +159,36 @@ type OCSPResponse struct { IssuerHash string `json:"issuer_hash,omitempty"` } +type CSRMetadata struct { + CommonName string `json:"common_name"` + Organization []string `json:"organization"` + OrganizationalUnit []string `json:"organizational_unit"` + Country []string `json:"country"` + Province []string `json:"province"` + Locality []string `json:"locality"` + StreetAddress []string `json:"street_address"` + PostalCode []string `json:"postal_code"` + DNSNames []string `json:"dns_names"` + IPAddresses []string `json:"ip_addresses"` + EmailAddresses []string `json:"email_addresses"` +} + +type CSR struct { + ID string `json:"id,omitempty"` + CSR []byte `json:"csr,omitempty"` + PrivateKey []byte `json:"private_key,omitempty"` + EntityID string `json:"entity_id,omitempty"` + Status string `json:"status,omitempty"` + SubmittedAt time.Time `json:"submitted_at,omitempty"` + ProcessedAt time.Time `json:"processed_at,omitempty"` + SerialNumber string `json:"serial_number,omitempty"` +} + +type CSRPage struct { + PageMetadata + CSRs []CSR +} + type SDK interface { // IssueCert issues a certificate for a thing required for mTLS. // @@ -232,6 +273,33 @@ type SDK interface { // response, _ := sdk.GetCAToken() // fmt.Println(response) GetCAToken() (Token, errors.SDKError) + + // CreateCSR creates a new Certificate Signing Request + // + // example: + // pm = sdk.CSRMetadata{CommonName: "common_name", "entity_id" } + // reponse, _ := sdk.CreateCSR(pm, "privKeyPath") + // fmt.Println(response) + CreateCSR(metadata CSRMetadata, entityID string, privKeyPath string) (CSR, errors.SDKError) + + // SignCSR processes a pending CSR and either signs or rejects it + // + // example: + // err := sdk.SignCSR( "csr_id", "privKeyPath") + // fmt.Println(err) + SignCSR(csrID string, sign bool) errors.SDKError + + // ListCSRs returns a list of CSRs based on filter criteria + // + // reponse, _ := sdk.ListCSRs("entity_id", "pending") + // fmt.Println(response) + ListCSRs(entityID string, status string) (CSRPage, errors.SDKError) + + // RetrieveCSR retrieves a specific CSR by ID + // + // reponse, _ := sdk.RetrieveCSR("csr_id") + // fmt.Println(response) + RetrieveCSR(csrID string) (CSR, errors.SDKError) } func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (Certificate, errors.SDKError) { @@ -500,6 +568,83 @@ func (sdk mgSDK) GetCAToken() (Token, errors.SDKError) { return tk, nil } +func (sdk mgSDK) CreateCSR(pm PageMetadata, privKeyPath string) (CSR, errors.SDKError) { + r := csrReq{ + Organization: pm.Organization, + OrganizationalUnit: pm.OrganizationalUnit, + Country: pm.Country, + Province: pm.Province, + Locality: pm.Locality, + StreetAddress: pm.StreetAddress, + PostalCode: pm.PostalCode, + DNSNames: pm.DNSNames, + IPAddresses: pm.IPAddresses, + EmailAddresses: pm.EmailAddresses, + } + d, err := json.Marshal(r) + if err != nil { + return CSR{}, errors.NewSDKError(err) + } + url := fmt.Sprintf("%s/%s", issueCertEndpoint, entityID) + + _, body, sdkerr := sdk.processRequest(http.MethodPost, url, nil, nil, http.StatusOK) + if sdkerr != nil { + return CSR{}, sdkerr + } + + var csr CSR + if err := json.Unmarshal(body, &csr); err != nil { + return CSR{}, errors.NewSDKError(err) + } + return csr, nil +} + +func (sdk mgSDK) SignCSR(csrID string, sign bool) errors.SDKError { + url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/download-ca", certsEndpoint), pm) + if err != nil { + return errors.NewSDKError(err) + } + + _, _, sdkerr := sdk.processRequest(http.MethodPatch, url, nil, nil, http.StatusOK) + if sdkerr != nil { + return sdkerr + } + return nil +} + +func (sdk mgSDK) ListCSRs(entityID string, status string) (CSRPage, errors.SDKError) { + url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/list", csrEndpoint), pm) + if err != nil { + return CSRPage{}, errors.NewSDKError(err) + } + + _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusOK) + if sdkerr != nil { + return CSRPage{}, sdkerr + } + + var cp CSRPage + if err := json.Unmarshal(body, &cp); err != nil { + return CSRPage{}, errors.NewSDKError(err) + } + return cp, nil +} + +func (sdk mgSDK) RetrieveCSR(csrID string) (CSR, errors.SDKError) { + url := fmt.Sprintf("%s/%s/%s", sdk.certsURL, csrEndpoint, csrID) + + _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusOK) + if sdkerr != nil { + return CSR{}, sdkerr + } + + var csr CSR + if err := json.Unmarshal(body, &csr); err != nil { + return CSR{}, errors.NewSDKError(err) + } + return csr, nil +} + func NewSDK(conf Config) SDK { return &mgSDK{ certsURL: conf.CertsURL, @@ -604,3 +749,16 @@ type certReq struct { TTL string `json:"ttl"` Options Options `json:"options"` } + +type csrReq struct { + Organization []string `json:"organization"` + OrganizationalUnit []string `json:"organizational_unit"` + Country []string `json:"country"` + Province []string `json:"province"` + Locality []string `json:"locality"` + StreetAddress []string `json:"street_address"` + PostalCode []string `json:"postal_code"` + DNSNames []string `json:"dns_names"` + IPAddresses []string `json:"ip_addresses"` + EmailAddresses []string `json:"email_addresses"` +} From 7fd0e599ccdc80c280d5c0e1d126004a410ce8a2 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 27 Nov 2024 02:45:38 +0300 Subject: [PATCH 05/10] fix tests Signed-off-by: nyagamunene --- api/http/endpoint.go | 14 +-- api/http/requests.go | 15 ++- api/http/responses.go | 10 +- api/http/transport.go | 91 +-------------- api/logging.go | 4 +- api/metrics.go | 4 +- certs.go | 5 +- certs_test.go | 32 ++++-- cli/certs.go | 121 +++++++++++++++++++- mockery.yaml | 4 + mocks/csr.go | 249 ++++++++++++++++++++++++++++++++++++++++++ mocks/service.go | 96 ++++++++-------- mocks/uuid.go | 35 ++++++ sdk/mocks/sdk.go | 224 +++++++++++++++++++++++++++++++++++++ sdk/sdk.go | 26 +++-- service.go | 11 +- tracing/certs.go | 4 +- 17 files changed, 760 insertions(+), 185 deletions(-) create mode 100644 mocks/csr.go create mode 100644 mocks/uuid.go diff --git a/api/http/endpoint.go b/api/http/endpoint.go index 409afb2..29919a9 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -317,7 +317,7 @@ func createCSREndpoint(svc certs.Service) endpoint.Endpoint { return createCSRRes{created: false}, err } - csr, err := svc.CreateCSR(ctx, req.metadata, req.entityID, req.privKey) + csr, err := svc.CreateCSR(ctx, req.Metadata, req.Metadata.EntityID, req.privKey) if err != nil { return createCSRRes{created: false}, err } @@ -329,19 +329,19 @@ func createCSREndpoint(svc certs.Service) endpoint.Endpoint { } } -func processCSREndpoint(svc certs.Service) endpoint.Endpoint { +func signCSREndpoint(svc certs.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { - req := request.(processCSRReq) + req := request.(SignCSRReq) if err := req.validate(); err != nil { - return processCSRRes{processed: false}, err + return signCSRRes{processed: false}, err } - err = svc.ProcessCSR(ctx, req.csrID, req.approve) + err = svc.SignCSR(ctx, req.csrID, req.approve) if err != nil { - return processCSRRes{processed: false}, err + return signCSRRes{processed: false}, err } - return processCSRRes{ + return signCSRRes{ processed: true, }, nil } diff --git a/api/http/requests.go b/api/http/requests.go index b59e106..2af4a96 100644 --- a/api/http/requests.go +++ b/api/http/requests.go @@ -91,24 +91,23 @@ func (req ocspReq) validate() error { } type createCSRReq struct { - metadata certs.CSRMetadata - entityID string - privKey *rsa.PrivateKey + Metadata certs.CSRMetadata `json:"metadata"` + privKey *rsa.PrivateKey } func (req createCSRReq) validate() error { - if req.entityID == "" { + if req.Metadata.EntityID == "" { return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) } return nil } -type processCSRReq struct { - csrID string +type SignCSRReq struct { + csrID string approve bool } -func (req processCSRReq) validate() error { +func (req SignCSRReq) validate() error { if req.csrID == "" { return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) } @@ -117,7 +116,7 @@ func (req processCSRReq) validate() error { type listCSRsReq struct { entityID string - status string + status string } func (req listCSRsReq) validate() error { diff --git a/api/http/responses.go b/api/http/responses.go index 1f708a4..f743a83 100644 --- a/api/http/responses.go +++ b/api/http/responses.go @@ -143,7 +143,7 @@ func (res listCertsRes) Empty() bool { type viewCertRes struct { SerialNumber string `json:"serial_number,omitempty"` Certificate string `json:"certificate,omitempty"` - Key string `json:"key,omitempty,omitempty"` + Key string `json:"key,omitempty"` Revoked bool `json:"revoked,omitempty"` ExpiryTime time.Time `json:"expiry_time,omitempty"` EntityID string `json:"entity_id,omitempty"` @@ -224,19 +224,19 @@ func (res createCSRRes) Empty() bool { return false } -type processCSRRes struct { +type signCSRRes struct { processed bool } -func (res processCSRRes) Code() int { +func (res signCSRRes) Code() int { return http.StatusOK } -func (res processCSRRes) Headers() map[string]string { +func (res signCSRRes) Headers() map[string]string { return map[string]string{} } -func (res processCSRRes) Empty() bool { +func (res signCSRRes) Empty() bool { return true } diff --git a/api/http/transport.go b/api/http/transport.go index ab949f9..52622a8 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -32,16 +32,6 @@ const ( limitKey = "limit" entityKey = "entity_id" commonName = "common_name" - organization = "organization" - orgUnit = "organization_unit" - country = "country" - province = "province" - locality = "locality" - streetAddress = "street_address" - postalCode = "postal_code" - emailAddresses = "email_addresses" - dnsNames = "dns_names" - ipAddresses = "ip_addresses" approve = "approve" status = "status" token = "token" @@ -157,7 +147,7 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http opts..., ), "").ServeHTTP) r.Patch("/{id}", otelhttp.NewHandler(kithttp.NewServer( - processCSREndpoint(svc), + signCSREndpoint(svc), decodeUpdateCSR, EncodeResponse, opts..., @@ -300,84 +290,11 @@ func decodeListCerts(_ context.Context, r *http.Request) (interface{}, error) { } func decodeCreateCSR(_ context.Context, r *http.Request) (interface{}, error) { - o, err := readStringQuery(r, organization, "") - if err != nil { - return nil, err - } - - ou, err := readStringQuery(r, orgUnit, "") - if err != nil { - return nil, err - } - - c, err := readStringQuery(r, country, "") - if err != nil { - return nil, err - } - - p, err := readStringQuery(r, province, "") - if err != nil { - return nil, err - } - - l, err := readStringQuery(r, locality, "") - if err != nil { - return nil, err - } - - s, err := readStringQuery(r, streetAddress, "") - if err != nil { - return nil, err - } - - po, err := readStringQuery(r, postalCode, "") - if err != nil { + req := createCSRReq{} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { return nil, err } - entity, err := readStringQuery(r, entityKey, "") - if err != nil { - return nil, err - } - - cn, err := readStringQuery(r, commonName, "") - if err != nil { - return nil, err - } - - e, err := readStringQuery(r, emailAddresses, "") - if err != nil { - return nil, err - } - - d, err := readStringQuery(r, dnsNames, "") - if err != nil { - return nil, err - } - - i, err := readStringQuery(r, ipAddresses, "") - if err != nil { - return nil, err - } - - req := createCSRReq{ - metadata: certs.CSRMetadata{ - CommonName: cn, - Organization: []string{o}, - OrganizationalUnit: []string{ou}, - Country: []string{c}, - Province: []string{p}, - Locality: []string{l}, - StreetAddress: []string{s}, - PostalCode: []string{po}, - DNSNames: []string{d}, - IPAddresses: []string{i}, - EmailAddresses: []string{e}, - }, - entityID: entity, - // privKey: , - } - return req, nil } @@ -387,7 +304,7 @@ func decodeUpdateCSR(_ context.Context, r *http.Request) (interface{}, error) { return nil, err } - req := processCSRReq{ + req := SignCSRReq{ csrID: chi.URLParam(r, "id"), approve: app, } diff --git a/api/logging.go b/api/logging.go index 05852e1..905ed28 100644 --- a/api/logging.go +++ b/api/logging.go @@ -194,7 +194,7 @@ func (lm *loggingMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetada return lm.svc.CreateCSR(ctx, meta, entityID, key...) } -func (lm *loggingMiddleware) ProcessCSR(ctx context.Context, csrID string, approve bool) (err error) { +func (lm *loggingMiddleware) SignCSR(ctx context.Context, csrID string, approve bool) (err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method process_csr took %s to complete", time.Since(begin)) if err != nil { @@ -203,7 +203,7 @@ func (lm *loggingMiddleware) ProcessCSR(ctx context.Context, csrID string, appro } lm.logger.Info(message) }(time.Now()) - return lm.svc.ProcessCSR(ctx, csrID, approve) + return lm.svc.SignCSR(ctx, csrID, approve) } func (lm *loggingMiddleware) ListCSRs(ctx context.Context, entityID string, status string) (cp certs.CSRPage, err error) { diff --git a/api/metrics.go b/api/metrics.go index 1aa5e7b..9747e10 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -145,12 +145,12 @@ func (mm *metricsMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetada return mm.svc.CreateCSR(ctx, meta, entityID, key...) } -func (mm *metricsMiddleware) ProcessCSR(ctx context.Context, csrID string, approve bool) error { +func (mm *metricsMiddleware) SignCSR(ctx context.Context, csrID string, approve bool) error { defer func(begin time.Time) { mm.counter.With("method", "process_csr").Add(1) mm.latency.With("method", "process_csr").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.ProcessCSR(ctx, csrID, approve) + return mm.svc.SignCSR(ctx, csrID, approve) } func (mm *metricsMiddleware) ListCSRs(ctx context.Context, entityID string, status string) (certs.CSRPage, error) { diff --git a/certs.go b/certs.go index c14fc76..59ae29a 100644 --- a/certs.go +++ b/certs.go @@ -87,6 +87,7 @@ type PageMetadata struct { type CSRMetadata struct { CommonName string `json:"common_name"` + EntityID string `json:"entity_id"` Organization []string `json:"organization"` OrganizationalUnit []string `json:"organizational_unit"` Country []string `json:"country"` @@ -185,8 +186,8 @@ type Service interface { // CreateCSR creates a new Certificate Signing Request CreateCSR(ctx context.Context, metadata CSRMetadata, entityID string, privKey ...*rsa.PrivateKey) (CSR, error) - // ProcessCSR processes a pending CSR and either approves or rejects it - ProcessCSR(ctx context.Context, csrID string, approve bool) error + // SignCSR processes a pending CSR and either approves or rejects it + SignCSR(ctx context.Context, csrID string, approve bool) error // ListCSRs returns a list of CSRs based on filter criteria ListCSRs(ctx context.Context, entityID string, status string) (CSRPage, error) diff --git a/certs_test.go b/certs_test.go index 28616e2..184773f 100644 --- a/certs_test.go +++ b/certs_test.go @@ -34,10 +34,12 @@ var ( func TestIssueCert(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -76,12 +78,14 @@ func TestIssueCert(t *testing.T) { func TestRevokeCert(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() invalidSerialNumber := "invalid serial number" repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -130,10 +134,12 @@ func TestRevokeCert(t *testing.T) { func TestGetCertDownloadToken(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -160,6 +166,8 @@ func TestGetCertDownloadToken(t *testing.T) { func TestGetCert(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{ExpiresAt: time.Now().Add(time.Minute * 5).Unix(), Issuer: certs.Organization, Subject: "certs"}) validToken, err := jwtToken.SignedString([]byte(serialNumber)) @@ -167,7 +175,7 @@ func TestGetCert(t *testing.T) { repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -211,6 +219,8 @@ func TestGetCert(t *testing.T) { func TestRenewCert(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() serialNumber := big.NewInt(1) expiredSerialNumber := big.NewInt(2) @@ -262,7 +272,7 @@ func TestRenewCert(t *testing.T) { repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -351,10 +361,12 @@ func TestRenewCert(t *testing.T) { func TestGetEntityID(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -382,10 +394,12 @@ func TestGetEntityID(t *testing.T) { func TestListCerts(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -419,6 +433,8 @@ func TestListCerts(t *testing.T) { func TestGenerateCRL(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() privateKey, _ := rsa.GenerateKey(rand.Reader, 2048) template := &x509.Certificate{ @@ -440,7 +456,7 @@ func TestGenerateCRL(t *testing.T) { {Type: certs.IntermediateCA, Certificate: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}), Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})}, }, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() diff --git a/cli/certs.go b/cli/certs.go index c1bd3bf..4159236 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -5,8 +5,10 @@ package cli import ( "encoding/json" + "fmt" "os" + "github.com/absmach/certs/errors" ctxsdk "github.com/absmach/certs/sdk" "github.com/spf13/cobra" ) @@ -20,7 +22,7 @@ func SetSDK(s ctxsdk.SDK) { var cmdCerts = []cobra.Command{ { - Use: "get [all | ]", + Use: "get [all | ]", Short: "Get certificate", Long: `Gets a certificate for a given entity ID or all certificates.`, Run: func(cmd *cobra.Command, args []string) { @@ -238,6 +240,123 @@ var cmdCerts = []cobra.Command{ logJSONCmd(*cmd, token) }, }, + { + Use: "csr ", + Short: "Create CSR", + Long: `Creates a CSR.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) > 3 || len(args) == 0 { + logUsageCmd(*cmd, cmd.Use) + return + } + + var pm ctxsdk.PageMetadata + if err := json.Unmarshal([]byte(args[0]), &pm); err != nil { + logErrorCmd(*cmd, err) + return + } + var csr ctxsdk.CSR + var err error + if len(args) == 1 { + csr, err = sdk.CreateCSR(pm, "") + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, csr) + return + } + csr, err = sdk.CreateCSR(pm, args[1]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + logJSONCmd(*cmd, csr) + }, + }, + { + Use: "sign ", + Short: "Sign CSR", + Long: `Signs a CSR for a given csr id.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 2 { + logUsageCmd(*cmd, cmd.Use) + return + } + var sign bool + switch args[1] { + case "true": + sign = true + case "false": + sign = false + default: + logErrorCmd(*cmd, errors.NewSDKError(fmt.Errorf("unknown type"))) + return + } + + err := sdk.SignCSR(args[0], sign) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logOKCmd(*cmd) + }, + }, + { + Use: "get-csr [all | ] ", + Short: "Get csr", + Long: `Gets CSRs for a given entity ID or all CSR.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 2 { + logUsageCmd(*cmd, cmd.Use) + return + } + if args[0] == "all" { + pm := ctxsdk.PageMetadata{ + Limit: Limit, + Offset: Offset, + Status: args[1], + } + page, err := sdk.ListCSRs(pm) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, page) + return + } + pm := ctxsdk.PageMetadata{ + EntityID: args[0], + Limit: Limit, + Offset: Offset, + Status: args[1], + } + page, err := sdk.ListCSRs(pm) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, page) + }, + }, + { + Use: "view-csr ", + Short: "View CSR", + Long: `Views a CSR for a given csr id.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 1 { + logUsageCmd(*cmd, cmd.Use) + return + } + cert, err := sdk.RetrieveCSR(args[0]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, cert) + }, + }, } // NewCertsCmd returns certificate command. diff --git a/mockery.yaml b/mockery.yaml index 8db5b4a..dfbf09a 100644 --- a/mockery.yaml +++ b/mockery.yaml @@ -16,6 +16,10 @@ packages: config: dir: "{{.InterfaceDir}}/mocks" filename: "repository.go" + CSRRepository: + config: + dir: "{{.InterfaceDir}}/mocks" + filename: "csr.go" github.com/absmach/certs/sdk: interfaces: SDK: diff --git a/mocks/csr.go b/mocks/csr.go new file mode 100644 index 0000000..79340c1 --- /dev/null +++ b/mocks/csr.go @@ -0,0 +1,249 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + certs "github.com/absmach/certs" + + mock "github.com/stretchr/testify/mock" +) + +// MockCSRRepository is an autogenerated mock type for the CSRRepository type +type MockCSRRepository struct { + mock.Mock +} + +type MockCSRRepository_Expecter struct { + mock *mock.Mock +} + +func (_m *MockCSRRepository) EXPECT() *MockCSRRepository_Expecter { + return &MockCSRRepository_Expecter{mock: &_m.Mock} +} + +// CreateCSR provides a mock function with given fields: _a0, _a1 +func (_m *MockCSRRepository) CreateCSR(_a0 context.Context, _a1 certs.CSR) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for CreateCSR") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, certs.CSR) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockCSRRepository_CreateCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCSR' +type MockCSRRepository_CreateCSR_Call struct { + *mock.Call +} + +// CreateCSR is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 certs.CSR +func (_e *MockCSRRepository_Expecter) CreateCSR(_a0 interface{}, _a1 interface{}) *MockCSRRepository_CreateCSR_Call { + return &MockCSRRepository_CreateCSR_Call{Call: _e.mock.On("CreateCSR", _a0, _a1)} +} + +func (_c *MockCSRRepository_CreateCSR_Call) Run(run func(_a0 context.Context, _a1 certs.CSR)) *MockCSRRepository_CreateCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(certs.CSR)) + }) + return _c +} + +func (_c *MockCSRRepository_CreateCSR_Call) Return(_a0 error) *MockCSRRepository_CreateCSR_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCSRRepository_CreateCSR_Call) RunAndReturn(run func(context.Context, certs.CSR) error) *MockCSRRepository_CreateCSR_Call { + _c.Call.Return(run) + return _c +} + +// ListCSRs provides a mock function with given fields: _a0, _a1 +func (_m *MockCSRRepository) ListCSRs(_a0 context.Context, _a1 certs.PageMetadata) (certs.CSRPage, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for ListCSRs") + } + + var r0 certs.CSRPage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, certs.PageMetadata) (certs.CSRPage, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, certs.PageMetadata) certs.CSRPage); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Get(0).(certs.CSRPage) + } + + if rf, ok := ret.Get(1).(func(context.Context, certs.PageMetadata) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCSRRepository_ListCSRs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCSRs' +type MockCSRRepository_ListCSRs_Call struct { + *mock.Call +} + +// ListCSRs is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 certs.PageMetadata +func (_e *MockCSRRepository_Expecter) ListCSRs(_a0 interface{}, _a1 interface{}) *MockCSRRepository_ListCSRs_Call { + return &MockCSRRepository_ListCSRs_Call{Call: _e.mock.On("ListCSRs", _a0, _a1)} +} + +func (_c *MockCSRRepository_ListCSRs_Call) Run(run func(_a0 context.Context, _a1 certs.PageMetadata)) *MockCSRRepository_ListCSRs_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(certs.PageMetadata)) + }) + return _c +} + +func (_c *MockCSRRepository_ListCSRs_Call) Return(_a0 certs.CSRPage, _a1 error) *MockCSRRepository_ListCSRs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCSRRepository_ListCSRs_Call) RunAndReturn(run func(context.Context, certs.PageMetadata) (certs.CSRPage, error)) *MockCSRRepository_ListCSRs_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveCSR provides a mock function with given fields: _a0, _a1 +func (_m *MockCSRRepository) RetrieveCSR(_a0 context.Context, _a1 string) (certs.CSR, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for RetrieveCSR") + } + + var r0 certs.CSR + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (certs.CSR, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, string) certs.CSR); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Get(0).(certs.CSR) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCSRRepository_RetrieveCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveCSR' +type MockCSRRepository_RetrieveCSR_Call struct { + *mock.Call +} + +// RetrieveCSR is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 string +func (_e *MockCSRRepository_Expecter) RetrieveCSR(_a0 interface{}, _a1 interface{}) *MockCSRRepository_RetrieveCSR_Call { + return &MockCSRRepository_RetrieveCSR_Call{Call: _e.mock.On("RetrieveCSR", _a0, _a1)} +} + +func (_c *MockCSRRepository_RetrieveCSR_Call) Run(run func(_a0 context.Context, _a1 string)) *MockCSRRepository_RetrieveCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockCSRRepository_RetrieveCSR_Call) Return(_a0 certs.CSR, _a1 error) *MockCSRRepository_RetrieveCSR_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCSRRepository_RetrieveCSR_Call) RunAndReturn(run func(context.Context, string) (certs.CSR, error)) *MockCSRRepository_RetrieveCSR_Call { + _c.Call.Return(run) + return _c +} + +// UpdateCSR provides a mock function with given fields: _a0, _a1 +func (_m *MockCSRRepository) UpdateCSR(_a0 context.Context, _a1 certs.CSR) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for UpdateCSR") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, certs.CSR) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockCSRRepository_UpdateCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCSR' +type MockCSRRepository_UpdateCSR_Call struct { + *mock.Call +} + +// UpdateCSR is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 certs.CSR +func (_e *MockCSRRepository_Expecter) UpdateCSR(_a0 interface{}, _a1 interface{}) *MockCSRRepository_UpdateCSR_Call { + return &MockCSRRepository_UpdateCSR_Call{Call: _e.mock.On("UpdateCSR", _a0, _a1)} +} + +func (_c *MockCSRRepository_UpdateCSR_Call) Run(run func(_a0 context.Context, _a1 certs.CSR)) *MockCSRRepository_UpdateCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(certs.CSR)) + }) + return _c +} + +func (_c *MockCSRRepository_UpdateCSR_Call) Return(_a0 error) *MockCSRRepository_UpdateCSR_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCSRRepository_UpdateCSR_Call) RunAndReturn(run func(context.Context, certs.CSR) error) *MockCSRRepository_UpdateCSR_Call { + _c.Call.Return(run) + return _c +} + +// NewMockCSRRepository creates a new instance of MockCSRRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockCSRRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *MockCSRRepository { + mock := &MockCSRRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/mocks/service.go b/mocks/service.go index deb6ff8..fd28c58 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -541,54 +541,6 @@ func (_c *MockService_OCSP_Call) RunAndReturn(run func(context.Context, string) return _c } -// ProcessCSR provides a mock function with given fields: ctx, csrID, approve -func (_m *MockService) ProcessCSR(ctx context.Context, csrID string, approve bool) error { - ret := _m.Called(ctx, csrID, approve) - - if len(ret) == 0 { - panic("no return value specified for ProcessCSR") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, bool) error); ok { - r0 = rf(ctx, csrID, approve) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// MockService_ProcessCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ProcessCSR' -type MockService_ProcessCSR_Call struct { - *mock.Call -} - -// ProcessCSR is a helper method to define mock.On call -// - ctx context.Context -// - csrID string -// - approve bool -func (_e *MockService_Expecter) ProcessCSR(ctx interface{}, csrID interface{}, approve interface{}) *MockService_ProcessCSR_Call { - return &MockService_ProcessCSR_Call{Call: _e.mock.On("ProcessCSR", ctx, csrID, approve)} -} - -func (_c *MockService_ProcessCSR_Call) Run(run func(ctx context.Context, csrID string, approve bool)) *MockService_ProcessCSR_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(bool)) - }) - return _c -} - -func (_c *MockService_ProcessCSR_Call) Return(_a0 error) *MockService_ProcessCSR_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockService_ProcessCSR_Call) RunAndReturn(run func(context.Context, string, bool) error) *MockService_ProcessCSR_Call { - _c.Call.Return(run) - return _c -} - // RemoveCert provides a mock function with given fields: ctx, entityId func (_m *MockService) RemoveCert(ctx context.Context, entityId string) error { ret := _m.Called(ctx, entityId) @@ -967,6 +919,54 @@ func (_c *MockService_RevokeCert_Call) RunAndReturn(run func(context.Context, st return _c } +// SignCSR provides a mock function with given fields: ctx, csrID, approve +func (_m *MockService) SignCSR(ctx context.Context, csrID string, approve bool) error { + ret := _m.Called(ctx, csrID, approve) + + if len(ret) == 0 { + panic("no return value specified for SignCSR") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, bool) error); ok { + r0 = rf(ctx, csrID, approve) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockService_SignCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignCSR' +type MockService_SignCSR_Call struct { + *mock.Call +} + +// SignCSR is a helper method to define mock.On call +// - ctx context.Context +// - csrID string +// - approve bool +func (_e *MockService_Expecter) SignCSR(ctx interface{}, csrID interface{}, approve interface{}) *MockService_SignCSR_Call { + return &MockService_SignCSR_Call{Call: _e.mock.On("SignCSR", ctx, csrID, approve)} +} + +func (_c *MockService_SignCSR_Call) Run(run func(ctx context.Context, csrID string, approve bool)) *MockService_SignCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(bool)) + }) + return _c +} + +func (_c *MockService_SignCSR_Call) Return(_a0 error) *MockService_SignCSR_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockService_SignCSR_Call) RunAndReturn(run func(context.Context, string, bool) error) *MockService_SignCSR_Call { + _c.Call.Return(run) + return _c +} + // ViewCert provides a mock function with given fields: ctx, serialNumber func (_m *MockService) ViewCert(ctx context.Context, serialNumber string) (certs.Certificate, error) { ret := _m.Called(ctx, serialNumber) diff --git a/mocks/uuid.go b/mocks/uuid.go new file mode 100644 index 0000000..065daba --- /dev/null +++ b/mocks/uuid.go @@ -0,0 +1,35 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + "fmt" + "sync" + + "github.com/absmach/certs/internal/uuid" +) + +// Prefix represents the prefix used to generate UUID mocks. +const Prefix = "123e4567-e89b-12d3-a456-" + +var _ uuid.IDProvider = (*uuidProviderMock)(nil) + +type uuidProviderMock struct { + mu sync.Mutex + counter int +} + +func (up *uuidProviderMock) ID() (string, error) { + up.mu.Lock() + defer up.mu.Unlock() + + up.counter++ + return fmt.Sprintf("%s%012d", Prefix, up.counter), nil +} + +// NewMock creates "mirror" uuid provider, i.e. generated +// token will hold value provided by the caller. +func NewMock() uuid.IDProvider { + return &uuidProviderMock{} +} diff --git a/sdk/mocks/sdk.go b/sdk/mocks/sdk.go index 4410524..5ccd73e 100644 --- a/sdk/mocks/sdk.go +++ b/sdk/mocks/sdk.go @@ -25,6 +25,65 @@ func (_m *MockSDK) EXPECT() *MockSDK_Expecter { return &MockSDK_Expecter{mock: &_m.Mock} } +// CreateCSR provides a mock function with given fields: pm, privKeyPath +func (_m *MockSDK) CreateCSR(pm sdk.PageMetadata, privKeyPath string) (sdk.CSR, errors.SDKError) { + ret := _m.Called(pm, privKeyPath) + + if len(ret) == 0 { + panic("no return value specified for CreateCSR") + } + + var r0 sdk.CSR + var r1 errors.SDKError + if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) (sdk.CSR, errors.SDKError)); ok { + return rf(pm, privKeyPath) + } + if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) sdk.CSR); ok { + r0 = rf(pm, privKeyPath) + } else { + r0 = ret.Get(0).(sdk.CSR) + } + + if rf, ok := ret.Get(1).(func(sdk.PageMetadata, string) errors.SDKError); ok { + r1 = rf(pm, privKeyPath) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + + return r0, r1 +} + +// MockSDK_CreateCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCSR' +type MockSDK_CreateCSR_Call struct { + *mock.Call +} + +// CreateCSR is a helper method to define mock.On call +// - pm sdk.PageMetadata +// - privKeyPath string +func (_e *MockSDK_Expecter) CreateCSR(pm interface{}, privKeyPath interface{}) *MockSDK_CreateCSR_Call { + return &MockSDK_CreateCSR_Call{Call: _e.mock.On("CreateCSR", pm, privKeyPath)} +} + +func (_c *MockSDK_CreateCSR_Call) Run(run func(pm sdk.PageMetadata, privKeyPath string)) *MockSDK_CreateCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(sdk.PageMetadata), args[1].(string)) + }) + return _c +} + +func (_c *MockSDK_CreateCSR_Call) Return(_a0 sdk.CSR, _a1 errors.SDKError) *MockSDK_CreateCSR_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSDK_CreateCSR_Call) RunAndReturn(run func(sdk.PageMetadata, string) (sdk.CSR, errors.SDKError)) *MockSDK_CreateCSR_Call { + _c.Call.Return(run) + return _c +} + // DeleteCert provides a mock function with given fields: entityID func (_m *MockSDK) DeleteCert(entityID string) errors.SDKError { ret := _m.Called(entityID) @@ -308,6 +367,64 @@ func (_c *MockSDK_IssueCert_Call) RunAndReturn(run func(string, string, []string return _c } +// ListCSRs provides a mock function with given fields: pm +func (_m *MockSDK) ListCSRs(pm sdk.PageMetadata) (sdk.CSRPage, errors.SDKError) { + ret := _m.Called(pm) + + if len(ret) == 0 { + panic("no return value specified for ListCSRs") + } + + var r0 sdk.CSRPage + var r1 errors.SDKError + if rf, ok := ret.Get(0).(func(sdk.PageMetadata) (sdk.CSRPage, errors.SDKError)); ok { + return rf(pm) + } + if rf, ok := ret.Get(0).(func(sdk.PageMetadata) sdk.CSRPage); ok { + r0 = rf(pm) + } else { + r0 = ret.Get(0).(sdk.CSRPage) + } + + if rf, ok := ret.Get(1).(func(sdk.PageMetadata) errors.SDKError); ok { + r1 = rf(pm) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + + return r0, r1 +} + +// MockSDK_ListCSRs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCSRs' +type MockSDK_ListCSRs_Call struct { + *mock.Call +} + +// ListCSRs is a helper method to define mock.On call +// - pm sdk.PageMetadata +func (_e *MockSDK_Expecter) ListCSRs(pm interface{}) *MockSDK_ListCSRs_Call { + return &MockSDK_ListCSRs_Call{Call: _e.mock.On("ListCSRs", pm)} +} + +func (_c *MockSDK_ListCSRs_Call) Run(run func(pm sdk.PageMetadata)) *MockSDK_ListCSRs_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(sdk.PageMetadata)) + }) + return _c +} + +func (_c *MockSDK_ListCSRs_Call) Return(_a0 sdk.CSRPage, _a1 errors.SDKError) *MockSDK_ListCSRs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSDK_ListCSRs_Call) RunAndReturn(run func(sdk.PageMetadata) (sdk.CSRPage, errors.SDKError)) *MockSDK_ListCSRs_Call { + _c.Call.Return(run) + return _c +} + // ListCerts provides a mock function with given fields: pm func (_m *MockSDK) ListCerts(pm sdk.PageMetadata) (sdk.CertificatePage, errors.SDKError) { ret := _m.Called(pm) @@ -473,6 +590,64 @@ func (_c *MockSDK_RenewCert_Call) RunAndReturn(run func(string) errors.SDKError) return _c } +// RetrieveCSR provides a mock function with given fields: csrID +func (_m *MockSDK) RetrieveCSR(csrID string) (sdk.CSR, errors.SDKError) { + ret := _m.Called(csrID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveCSR") + } + + var r0 sdk.CSR + var r1 errors.SDKError + if rf, ok := ret.Get(0).(func(string) (sdk.CSR, errors.SDKError)); ok { + return rf(csrID) + } + if rf, ok := ret.Get(0).(func(string) sdk.CSR); ok { + r0 = rf(csrID) + } else { + r0 = ret.Get(0).(sdk.CSR) + } + + if rf, ok := ret.Get(1).(func(string) errors.SDKError); ok { + r1 = rf(csrID) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + + return r0, r1 +} + +// MockSDK_RetrieveCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveCSR' +type MockSDK_RetrieveCSR_Call struct { + *mock.Call +} + +// RetrieveCSR is a helper method to define mock.On call +// - csrID string +func (_e *MockSDK_Expecter) RetrieveCSR(csrID interface{}) *MockSDK_RetrieveCSR_Call { + return &MockSDK_RetrieveCSR_Call{Call: _e.mock.On("RetrieveCSR", csrID)} +} + +func (_c *MockSDK_RetrieveCSR_Call) Run(run func(csrID string)) *MockSDK_RetrieveCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockSDK_RetrieveCSR_Call) Return(_a0 sdk.CSR, _a1 errors.SDKError) *MockSDK_RetrieveCSR_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSDK_RetrieveCSR_Call) RunAndReturn(run func(string) (sdk.CSR, errors.SDKError)) *MockSDK_RetrieveCSR_Call { + _c.Call.Return(run) + return _c +} + // RetrieveCertDownloadToken provides a mock function with given fields: serialNumber func (_m *MockSDK) RetrieveCertDownloadToken(serialNumber string) (sdk.Token, errors.SDKError) { ret := _m.Called(serialNumber) @@ -579,6 +754,55 @@ func (_c *MockSDK_RevokeCert_Call) RunAndReturn(run func(string) errors.SDKError return _c } +// SignCSR provides a mock function with given fields: csrID, sign +func (_m *MockSDK) SignCSR(csrID string, sign bool) errors.SDKError { + ret := _m.Called(csrID, sign) + + if len(ret) == 0 { + panic("no return value specified for SignCSR") + } + + var r0 errors.SDKError + if rf, ok := ret.Get(0).(func(string, bool) errors.SDKError); ok { + r0 = rf(csrID, sign) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + + return r0 +} + +// MockSDK_SignCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignCSR' +type MockSDK_SignCSR_Call struct { + *mock.Call +} + +// SignCSR is a helper method to define mock.On call +// - csrID string +// - sign bool +func (_e *MockSDK_Expecter) SignCSR(csrID interface{}, sign interface{}) *MockSDK_SignCSR_Call { + return &MockSDK_SignCSR_Call{Call: _e.mock.On("SignCSR", csrID, sign)} +} + +func (_c *MockSDK_SignCSR_Call) Run(run func(csrID string, sign bool)) *MockSDK_SignCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].(bool)) + }) + return _c +} + +func (_c *MockSDK_SignCSR_Call) Return(_a0 errors.SDKError) *MockSDK_SignCSR_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSDK_SignCSR_Call) RunAndReturn(run func(string, bool) errors.SDKError) *MockSDK_SignCSR_Call { + _c.Call.Return(run) + return _c +} + // ViewCA provides a mock function with given fields: token func (_m *MockSDK) ViewCA(token string) (sdk.Certificate, errors.SDKError) { ret := _m.Called(token) diff --git a/sdk/sdk.go b/sdk/sdk.go index ee30fb1..5af513a 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -92,6 +92,8 @@ type PageMetadata struct { DNSNames []string `json:"dns_names,omitempty"` IPAddresses []string `json:"ip_addresses,omitempty"` EmailAddresses []string `json:"email_addresses,omitempty"` + Status string `json:"status,omitempty"` + Sign bool `json:"sign,omitempty"` } type Options struct { @@ -280,7 +282,7 @@ type SDK interface { // pm = sdk.CSRMetadata{CommonName: "common_name", "entity_id" } // reponse, _ := sdk.CreateCSR(pm, "privKeyPath") // fmt.Println(response) - CreateCSR(metadata CSRMetadata, entityID string, privKeyPath string) (CSR, errors.SDKError) + CreateCSR(pm PageMetadata, privKeyPath string) (CSR, errors.SDKError) // SignCSR processes a pending CSR and either signs or rejects it // @@ -291,9 +293,9 @@ type SDK interface { // ListCSRs returns a list of CSRs based on filter criteria // - // reponse, _ := sdk.ListCSRs("entity_id", "pending") + // reponse, _ := sdk.ListCSRs(sdk.PageMetadata{EntityID: "entity_id", Status: "pending"}) // fmt.Println(response) - ListCSRs(entityID string, status string) (CSRPage, errors.SDKError) + ListCSRs(pm PageMetadata) (CSRPage, errors.SDKError) // RetrieveCSR retrieves a specific CSR by ID // @@ -585,9 +587,8 @@ func (sdk mgSDK) CreateCSR(pm PageMetadata, privKeyPath string) (CSR, errors.SDK if err != nil { return CSR{}, errors.NewSDKError(err) } - url := fmt.Sprintf("%s/%s", issueCertEndpoint, entityID) - - _, body, sdkerr := sdk.processRequest(http.MethodPost, url, nil, nil, http.StatusOK) + url := fmt.Sprintf("%s/%s", sdk.certsURL, csrEndpoint) + _, body, sdkerr := sdk.processRequest(http.MethodPost, url, d, nil, http.StatusOK) if sdkerr != nil { return CSR{}, sdkerr } @@ -600,7 +601,10 @@ func (sdk mgSDK) CreateCSR(pm PageMetadata, privKeyPath string) (CSR, errors.SDK } func (sdk mgSDK) SignCSR(csrID string, sign bool) errors.SDKError { - url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/download-ca", certsEndpoint), pm) + pm := PageMetadata{ + Sign: sign, + } + url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s", certsEndpoint, csrID), pm) if err != nil { return errors.NewSDKError(err) } @@ -612,7 +616,7 @@ func (sdk mgSDK) SignCSR(csrID string, sign bool) errors.SDKError { return nil } -func (sdk mgSDK) ListCSRs(entityID string, status string) (CSRPage, errors.SDKError) { +func (sdk mgSDK) ListCSRs(pm PageMetadata) (CSRPage, errors.SDKError) { url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/list", csrEndpoint), pm) if err != nil { return CSRPage{}, errors.NewSDKError(err) @@ -731,6 +735,12 @@ func (pm PageMetadata) query() (string, error) { if pm.CommonName != "" { q.Add("common_name", pm.CommonName) } + if pm.Sign { + q.Add("status", "true") + } + if pm.Status != "" { + q.Add("status", pm.Status) + } return q.Encode(), nil } diff --git a/service.go b/service.go index bd9f25e..da2da54 100644 --- a/service.go +++ b/service.go @@ -93,13 +93,14 @@ func NewService(ctx context.Context, repo Repository, csrRepo CSRRepository, con // If the root CA is not found, it returns an error. func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key ...*rsa.PrivateKey) (Certificate, error) { var err error - privKey := &rsa.PrivateKey{} + privKey := rsa.PrivateKey{} if len(key) == 0 { - privKey, err = rsa.GenerateKey(rand.Reader, PrivateKeyBytes) + pKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) + privKey = *pKey if err != nil { return Certificate{}, err } else { - privKey = key[0] + privKey = *key[0] } } serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) @@ -139,7 +140,7 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [ return Certificate{}, err } dbCert := Certificate{ - Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privKey)}), + Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(&privKey)}), Certificate: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}), SerialNumber: template.SerialNumber.String(), EntityID: entityID, @@ -478,7 +479,7 @@ func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, entityID return csr, nil } -func (s *service) ProcessCSR(ctx context.Context, csrID string, approve bool) error { +func (s *service) SignCSR(ctx context.Context, csrID string, approve bool) error { csr, err := s.csrRepo.RetrieveCSR(ctx, csrID) if err != nil { return errors.Wrap(ErrViewEntity, err) diff --git a/tracing/certs.go b/tracing/certs.go index 9a10765..71ee502 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -108,10 +108,10 @@ func (tm *tracingMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetada return tm.svc.CreateCSR(ctx, meta, entityID, key...) } -func (tm *tracingMiddleware) ProcessCSR(ctx context.Context, csrID string, approve bool) error { +func (tm *tracingMiddleware) SignCSR(ctx context.Context, csrID string, approve bool) error { ctx, span := tm.tracer.Start(ctx, "process_csr") defer span.End() - return tm.svc.ProcessCSR(ctx, csrID, approve) + return tm.svc.SignCSR(ctx, csrID, approve) } func (tm *tracingMiddleware) ListCSRs(ctx context.Context, entityID string, status string) (certs.CSRPage, error) { From 4d7e99d5209e7d526931cb2d4939ef9739150894 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 27 Nov 2024 14:02:28 +0300 Subject: [PATCH 06/10] fix postgres varibles Signed-off-by: nyagamunene --- postgres/csr/csr.go | 14 +++++++------- postgres/csr/init.go | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/postgres/csr/csr.go b/postgres/csr/csr.go index 5f92c8e..58194ae 100644 --- a/postgres/csr/csr.go +++ b/postgres/csr/csr.go @@ -41,11 +41,11 @@ func NewRepository(db postgres.Database) certs.CSRRepository { } } -func (repo CSRRepo) CreateCSR(ctx context.Context, cert certs.CSR) error { +func (repo CSRRepo) CreateCSR(ctx context.Context, csr certs.CSR) error { q := ` - INSERT INTO certs (id, serial_number, csr, private_key, entity_id, status, submitted_at, processed_at) + INSERT INTO csr (id, serial_number, csr, private_key, entity_id, status, submitted_at, processed_at) VALUES (:id, :serial_number, :csr, :private_key, :entity_id, :status, :submitted_at, :processed_at)` - _, err := repo.db.NamedExecContext(ctx, q, cert) + _, err := repo.db.NamedExecContext(ctx, q, csr) if err != nil { return handleError(certs.ErrCreateEntity, err) } @@ -53,7 +53,7 @@ func (repo CSRRepo) CreateCSR(ctx context.Context, cert certs.CSR) error { } func (repo CSRRepo) UpdateCSR(ctx context.Context, cert certs.CSR) error { - q := `UPDATE certs SET certificate = :certificate, key = :key, revoked = :revoked, expiry_time = :expiry_time WHERE serial_number = :serial_number` + q := `UPDATE csr SET certificate = :certificate, key = :key, revoked = :revoked, expiry_time = :expiry_time WHERE serial_number = :serial_number` res, err := repo.db.NamedExecContext(ctx, q, cert) if err != nil { return handleError(certs.ErrUpdateEntity, err) @@ -69,7 +69,7 @@ func (repo CSRRepo) UpdateCSR(ctx context.Context, cert certs.CSR) error { } func (repo CSRRepo) RetrieveCSR(ctx context.Context,id string) (certs.CSR, error) { - q := `SELECT serial_number, certificate, key, entity_id, revoked, expiry_time FROM certs WHERE id = $1` + q := `SELECT serial_number, certificate, key, entity_id, revoked, expiry_time FROM csr WHERE id = $1` var csr certs.CSR if err := repo.db.QueryRowxContext(ctx, q, id).StructScan(&csr); err != nil { if err == sql.ErrNoRows { @@ -81,7 +81,7 @@ func (repo CSRRepo) RetrieveCSR(ctx context.Context,id string) (certs.CSR, error } func (repo CSRRepo) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { - q := `SELECT serial_number, status, submitted_at, processed_at, entity_id FROM certs %s LIMIT :limit OFFSET :offset` + q := `SELECT serial_number, status, submitted_at, processed_at, entity_id FROM csr %s LIMIT :limit OFFSET :offset` var condition string if pm.EntityID != "" { condition = `WHERE entity_id = :entity_id` @@ -111,7 +111,7 @@ func (repo CSRRepo) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs. csrs = append(csrs, *csr) } - q = fmt.Sprintf(`SELECT COUNT(*) FROM certs %s LIMIT :limit OFFSET :offset`, condition) + q = fmt.Sprintf(`SELECT COUNT(*) FROM csr %s LIMIT :limit OFFSET :offset`, condition) pm.Total, err = repo.total(ctx, q, params) if err != nil { return certs.CSRPage{}, errors.Wrap(certs.ErrViewEntity, err) diff --git a/postgres/csr/init.go b/postgres/csr/init.go index 0f1534b..b31c055 100644 --- a/postgres/csr/init.go +++ b/postgres/csr/init.go @@ -20,7 +20,7 @@ func Migration() *migrate.MemoryMigrationSource { csr TEXT, private_key TEXT, entity_id VARCHAR(36), - status BOOLEAN, + status TEXT, submitted_at TIMESTAMP, processed_at TIMESTAMP )`, From 4e808a1f91610c1dcde8dce0c7084d72f012d2c0 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Thu, 28 Nov 2024 02:45:34 +0300 Subject: [PATCH 07/10] fix sdk url paths Signed-off-by: nyagamunene --- api/http/endpoint.go | 8 ++--- api/http/errors.go | 3 ++ api/http/requests.go | 13 ++++---- api/http/responses.go | 4 +-- api/http/transport.go | 44 +++++++++++++++++++++---- api/logging.go | 4 +-- api/metrics.go | 4 +-- certs.go | 75 ++++++++++++++++++++++++++++++++++++------- cli/certs.go | 12 +++++-- go.mod | 2 +- go.sum | 4 +-- mocks/service.go | 31 +++++++++--------- postgres/csr/csr.go | 73 ++++++++++++++++++++++++++--------------- postgres/csr/init.go | 2 +- sdk/certs_test.go | 12 +++---- sdk/mocks/sdk.go | 30 ++++++++--------- sdk/sdk.go | 21 ++++++------ service.go | 23 ++++++------- tracing/certs.go | 4 +-- 19 files changed, 244 insertions(+), 125 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index 29919a9..4907d4a 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -333,16 +333,16 @@ func signCSREndpoint(svc certs.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { req := request.(SignCSRReq) if err := req.validate(); err != nil { - return signCSRRes{processed: false}, err + return signCSRRes{signed: false}, err } err = svc.SignCSR(ctx, req.csrID, req.approve) if err != nil { - return signCSRRes{processed: false}, err + return signCSRRes{signed: false}, err } return signCSRRes{ - processed: true, + signed: true, }, nil } } @@ -354,7 +354,7 @@ func listCSRsEndpoint(svc certs.Service) endpoint.Endpoint { return listCSRsRes{}, err } - cp, err := svc.ListCSRs(ctx, req.entityID, req.status) + cp, err := svc.ListCSRs(ctx, req.pm) if err != nil { return listCSRsRes{}, err } diff --git a/api/http/errors.go b/api/http/errors.go index e4123dc..dcbfb14 100644 --- a/api/http/errors.go +++ b/api/http/errors.go @@ -32,4 +32,7 @@ var ( // ErrMissingCN indicates missing common name. ErrMissingCN = errors.New("missing common name") + + // ErrMissingStatus indicates missing status. + ErrMissingStatus = errors.New("missing status") ) diff --git a/api/http/requests.go b/api/http/requests.go index 2af4a96..2fbb7de 100644 --- a/api/http/requests.go +++ b/api/http/requests.go @@ -91,8 +91,9 @@ func (req ocspReq) validate() error { } type createCSRReq struct { - Metadata certs.CSRMetadata `json:"metadata"` - privKey *rsa.PrivateKey + Metadata certs.CSRMetadata `json:"metadata"` + PrivateKey []byte `json:"private_Key"` + privKey *rsa.PrivateKey } func (req createCSRReq) validate() error { @@ -111,17 +112,17 @@ func (req SignCSRReq) validate() error { if req.csrID == "" { return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) } + return nil } type listCSRsReq struct { - entityID string - status string + pm certs.PageMetadata } func (req listCSRsReq) validate() error { - if req.entityID == "" { - return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) + if req.pm.Status.String() == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingStatus) } return nil } diff --git a/api/http/responses.go b/api/http/responses.go index f743a83..d24709f 100644 --- a/api/http/responses.go +++ b/api/http/responses.go @@ -213,7 +213,7 @@ func (res createCSRRes) Code() int { return http.StatusCreated } - return http.StatusOK + return http.StatusNoContent } func (res createCSRRes) Headers() map[string]string { @@ -225,7 +225,7 @@ func (res createCSRRes) Empty() bool { } type signCSRRes struct { - processed bool + signed bool } func (res signCSRRes) Code() int { diff --git a/api/http/transport.go b/api/http/transport.go index 52622a8..8293eee 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -7,8 +7,10 @@ import ( "archive/zip" "bytes" "context" + "crypto/x509" "encoding/asn1" "encoding/json" + "encoding/pem" "fmt" "io" "log/slog" @@ -18,7 +20,7 @@ import ( "github.com/absmach/certs" "github.com/absmach/certs/errors" - "github.com/go-chi/chi" + "github.com/go-chi/chi/v5" kithttp "github.com/go-kit/kit/transport/http" "github.com/prometheus/client_golang/prometheus/promhttp" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" @@ -140,7 +142,7 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http opts..., ), "download_ca").ServeHTTP) r.Route("/csr", func(r chi.Router) { - r.Post("/", otelhttp.NewHandler(kithttp.NewServer( + r.Post("/{entityID}", otelhttp.NewHandler(kithttp.NewServer( createCSREndpoint(svc), decodeCreateCSR, EncodeResponse, @@ -291,10 +293,22 @@ func decodeListCerts(_ context.Context, r *http.Request) (interface{}, error) { func decodeCreateCSR(_ context.Context, r *http.Request) (interface{}, error) { req := createCSRReq{} + req.Metadata.EntityID = chi.URLParam(r, "entityID") if err := json.NewDecoder(r.Body).Decode(&req); err != nil { return nil, err } + if len(req.PrivateKey) > 0 { + block, _ := pem.Decode(req.PrivateKey) + if block != nil { + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, errors.Wrap(ErrInvalidRequest, err) + } + req.privKey = privateKey + } + } + return req, nil } @@ -321,7 +335,17 @@ func decodeRetrieveCSR(_ context.Context, r *http.Request) (interface{}, error) } func decodeListCSR(_ context.Context, r *http.Request) (interface{}, error) { - s, err := readStringQuery(r, status, "all") + o, err := readNumQuery(r, offsetKey, defOffset) + if err != nil { + return nil, err + } + + l, err := readNumQuery(r, limitKey, defLimit) + if err != nil { + return nil, err + } + + s, err := readStringQuery(r, status, "") if err != nil { return nil, err } @@ -330,11 +354,19 @@ func decodeListCSR(_ context.Context, r *http.Request) (interface{}, error) { return nil, err } - req := listCSRsReq{ - entityID: e, - status: s, + stat, err := certs.ParseCSRStatus(strings.ToLower(s)) + if err != nil { + return nil, err } + req := listCSRsReq{ + pm: certs.PageMetadata{ + Offset: o, + Limit: l, + EntityID: e, + Status: stat, + }, + } return req, nil } diff --git a/api/logging.go b/api/logging.go index 905ed28..893f521 100644 --- a/api/logging.go +++ b/api/logging.go @@ -206,7 +206,7 @@ func (lm *loggingMiddleware) SignCSR(ctx context.Context, csrID string, approve return lm.svc.SignCSR(ctx, csrID, approve) } -func (lm *loggingMiddleware) ListCSRs(ctx context.Context, entityID string, status string) (cp certs.CSRPage, err error) { +func (lm *loggingMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (cp certs.CSRPage, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method list_csrs took %s to complete", time.Since(begin)) if err != nil { @@ -215,7 +215,7 @@ func (lm *loggingMiddleware) ListCSRs(ctx context.Context, entityID string, stat } lm.logger.Info(message) }(time.Now()) - return lm.svc.ListCSRs(ctx, entityID, status) + return lm.svc.ListCSRs(ctx, pm) } func (lm *loggingMiddleware) RetrieveCSR(ctx context.Context, csrID string) (csr certs.CSR, err error) { diff --git a/api/metrics.go b/api/metrics.go index 9747e10..8610458 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -153,12 +153,12 @@ func (mm *metricsMiddleware) SignCSR(ctx context.Context, csrID string, approve return mm.svc.SignCSR(ctx, csrID, approve) } -func (mm *metricsMiddleware) ListCSRs(ctx context.Context, entityID string, status string) (certs.CSRPage, error) { +func (mm *metricsMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { defer func(begin time.Time) { mm.counter.With("method", "list_csrs").Add(1) mm.latency.With("method", "list_csrs").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.ListCSRs(ctx, entityID, status) + return mm.svc.ListCSRs(ctx, pm) } func (mm *metricsMiddleware) RetrieveCSR(ctx context.Context, csrID string) (certs.CSR, error) { diff --git a/certs.go b/certs.go index 59ae29a..f17ad50 100644 --- a/certs.go +++ b/certs.go @@ -7,6 +7,7 @@ import ( "context" "crypto/rsa" "crypto/x509" + "encoding/json" "net" "time" @@ -54,6 +55,56 @@ func CertTypeFromString(s string) (CertType, error) { } } +type CSRStatus int + +const ( + Pending CSRStatus = iota + Signed + Rejected + All +) + +const ( + pending = "pending" + signed = "signed" + rejected = "rejected" + all = "all" +) + +func (c CSRStatus) String() string { + switch c { + case Pending: + return pending + case Signed: + return signed + case Rejected: + return rejected + case All: + return all + default: + return Unknown + } +} + +func ParseCSRStatus(s string) (CSRStatus, error) { + switch s { + case pending: + return Pending, nil + case signed: + return Signed, nil + case rejected: + return Rejected, nil + case all: + return All, nil + default: + return -1, errors.New("unknown CSR status") + } +} + +func (c CSRStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(c.String()) +} + type CA struct { Type CertType Certificate *x509.Certificate @@ -78,16 +129,16 @@ type CertificatePage struct { } type PageMetadata struct { - Total uint64 `json:"total,omitempty" db:"total"` - Offset uint64 `json:"offset,omitempty" db:"offset"` - Limit uint64 `json:"limit,omitempty" db:"limit"` - EntityID string `json:"entity_id,omitempty" db:"entity_id"` - Status string `json:"status,omitempty" db:"status"` + Total uint64 `json:"total" db:"total"` + Offset uint64 `json:"offset,omitempty" db:"offset"` + Limit uint64 `json:"limit" db:"limit"` + EntityID string `json:"entity_id,omitempty" db:"entity_id"` + Status CSRStatus `json:"status,omitempty" db:"status"` } type CSRMetadata struct { + EntityID string CommonName string `json:"common_name"` - EntityID string `json:"entity_id"` Organization []string `json:"organization"` OrganizationalUnit []string `json:"organizational_unit"` Country []string `json:"country"` @@ -102,18 +153,18 @@ type CSRMetadata struct { type CSR struct { ID string `json:"id" db:"id"` - CSR []byte `json:"csr" db:"csr"` - PrivateKey []byte `json:"private_key" db:"private_key"` + CSR []byte `json:"csr,omitempty" db:"csr"` + PrivateKey []byte `json:"private_key,omitempty" db:"private_key"` EntityID string `json:"entity_id" db:"entity_id"` - Status string `json:"status" db:"status"` + Status CSRStatus `json:"status" db:"status"` SubmittedAt time.Time `json:"submitted_at" db:"submitted_at"` - ProcessedAt time.Time `json:"processed_at" db:"processed_at"` + ProcessedAt time.Time `json:"processed_at,omitempty" db:"processed_at"` SerialNumber string `json:"serial_number" db:"serial_number"` } type CSRPage struct { PageMetadata - CSRs []CSR + CSRs []CSR `json:"csrs,omitempty"` } type SubjectOptions struct { @@ -190,7 +241,7 @@ type Service interface { SignCSR(ctx context.Context, csrID string, approve bool) error // ListCSRs returns a list of CSRs based on filter criteria - ListCSRs(ctx context.Context, entityID string, status string) (CSRPage, error) + ListCSRs(ctx context.Context, pm PageMetadata) (CSRPage, error) // RetrieveCSR retrieves a specific CSR by ID RetrieveCSR(ctx context.Context, csrID string) (CSR, error) diff --git a/cli/certs.go b/cli/certs.go index 4159236..2cf131b 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -255,10 +255,11 @@ var cmdCerts = []cobra.Command{ logErrorCmd(*cmd, err) return } + var csr ctxsdk.CSR var err error if len(args) == 1 { - csr, err = sdk.CreateCSR(pm, "") + csr, err = sdk.CreateCSR(pm, []byte{}) if err != nil { logErrorCmd(*cmd, err) return @@ -266,7 +267,14 @@ var cmdCerts = []cobra.Command{ logJSONCmd(*cmd, csr) return } - csr, err = sdk.CreateCSR(pm, args[1]) + + data, err := os.ReadFile(args[1]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + csr, err = sdk.CreateCSR(pm, data) if err != nil { logErrorCmd(*cmd, err) return diff --git a/go.mod b/go.mod index a579d5c..478cc3b 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.23.0 require ( github.com/caarlos0/env/v10 v10.0.0 github.com/fatih/color v1.18.0 - github.com/go-chi/chi v4.1.2+incompatible + github.com/go-chi/chi/v5 v5.1.0 github.com/go-kit/kit v0.13.0 github.com/gofrs/uuid v4.4.0+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible diff --git a/go.sum b/go.sum index 393756d..8edef24 100644 --- a/go.sum +++ b/go.sum @@ -39,8 +39,8 @@ github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/go-chi/chi v4.1.2+incompatible h1:fGFk2Gmi/YKXk0OmGfBh0WgmN3XB8lVnEyNz34tQRec= -github.com/go-chi/chi v4.1.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= +github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= +github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-gorp/gorp/v3 v3.1.0 h1:ItKF/Vbuj31dmV4jxA1qblpSwkl9g1typ24xoe70IGs= github.com/go-gorp/gorp/v3 v3.1.0/go.mod h1:dLEjIyyRNiXvNZ8PSmzpt1GsWAUK8kjVhEpjH8TixEw= github.com/go-kit/kit v0.13.0 h1:OoneCcHKHQ03LfBpoQCUfCluwd2Vt3ohz+kvbJneZAU= diff --git a/mocks/service.go b/mocks/service.go index fd28c58..b7f2c1f 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -351,9 +351,9 @@ func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, str return _c } -// ListCSRs provides a mock function with given fields: ctx, entityID, status -func (_m *MockService) ListCSRs(ctx context.Context, entityID string, status string) (certs.CSRPage, error) { - ret := _m.Called(ctx, entityID, status) +// ListCSRs provides a mock function with given fields: ctx, pm +func (_m *MockService) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { + ret := _m.Called(ctx, pm) if len(ret) == 0 { panic("no return value specified for ListCSRs") @@ -361,17 +361,17 @@ func (_m *MockService) ListCSRs(ctx context.Context, entityID string, status str var r0 certs.CSRPage var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (certs.CSRPage, error)); ok { - return rf(ctx, entityID, status) + if rf, ok := ret.Get(0).(func(context.Context, certs.PageMetadata) (certs.CSRPage, error)); ok { + return rf(ctx, pm) } - if rf, ok := ret.Get(0).(func(context.Context, string, string) certs.CSRPage); ok { - r0 = rf(ctx, entityID, status) + if rf, ok := ret.Get(0).(func(context.Context, certs.PageMetadata) certs.CSRPage); ok { + r0 = rf(ctx, pm) } else { r0 = ret.Get(0).(certs.CSRPage) } - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, entityID, status) + if rf, ok := ret.Get(1).(func(context.Context, certs.PageMetadata) error); ok { + r1 = rf(ctx, pm) } else { r1 = ret.Error(1) } @@ -386,15 +386,14 @@ type MockService_ListCSRs_Call struct { // ListCSRs is a helper method to define mock.On call // - ctx context.Context -// - entityID string -// - status string -func (_e *MockService_Expecter) ListCSRs(ctx interface{}, entityID interface{}, status interface{}) *MockService_ListCSRs_Call { - return &MockService_ListCSRs_Call{Call: _e.mock.On("ListCSRs", ctx, entityID, status)} +// - pm certs.PageMetadata +func (_e *MockService_Expecter) ListCSRs(ctx interface{}, pm interface{}) *MockService_ListCSRs_Call { + return &MockService_ListCSRs_Call{Call: _e.mock.On("ListCSRs", ctx, pm)} } -func (_c *MockService_ListCSRs_Call) Run(run func(ctx context.Context, entityID string, status string)) *MockService_ListCSRs_Call { +func (_c *MockService_ListCSRs_Call) Run(run func(ctx context.Context, pm certs.PageMetadata)) *MockService_ListCSRs_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(certs.PageMetadata)) }) return _c } @@ -404,7 +403,7 @@ func (_c *MockService_ListCSRs_Call) Return(_a0 certs.CSRPage, _a1 error) *MockS return _c } -func (_c *MockService_ListCSRs_Call) RunAndReturn(run func(context.Context, string, string) (certs.CSRPage, error)) *MockService_ListCSRs_Call { +func (_c *MockService_ListCSRs_Call) RunAndReturn(run func(context.Context, certs.PageMetadata) (certs.CSRPage, error)) *MockService_ListCSRs_Call { _c.Call.Return(run) return _c } diff --git a/postgres/csr/csr.go b/postgres/csr/csr.go index 58194ae..54d6c24 100644 --- a/postgres/csr/csr.go +++ b/postgres/csr/csr.go @@ -7,6 +7,8 @@ import ( "context" "database/sql" "fmt" + "log" + "strings" "github.com/absmach/certs" "github.com/absmach/certs/errors" @@ -52,9 +54,9 @@ func (repo CSRRepo) CreateCSR(ctx context.Context, csr certs.CSR) error { return nil } -func (repo CSRRepo) UpdateCSR(ctx context.Context, cert certs.CSR) error { - q := `UPDATE csr SET certificate = :certificate, key = :key, revoked = :revoked, expiry_time = :expiry_time WHERE serial_number = :serial_number` - res, err := repo.db.NamedExecContext(ctx, q, cert) +func (repo CSRRepo) UpdateCSR(ctx context.Context, csr certs.CSR) error { + q := `UPDATE csr SET serial_number = :serial_number, status = :status, private_key = :private_key, submitted_at = :submitted_at, processed_at = :processed_at WHERE id = :id` + res, err := repo.db.NamedExecContext(ctx, q, csr) if err != nil { return handleError(certs.ErrUpdateEntity, err) } @@ -68,8 +70,8 @@ func (repo CSRRepo) UpdateCSR(ctx context.Context, cert certs.CSR) error { return nil } -func (repo CSRRepo) RetrieveCSR(ctx context.Context,id string) (certs.CSR, error) { - q := `SELECT serial_number, certificate, key, entity_id, revoked, expiry_time FROM csr WHERE id = $1` +func (repo CSRRepo) RetrieveCSR(ctx context.Context, id string) (certs.CSR, error) { + q := `SELECT id, serial_number, csr, private_key, entity_id, status, submitted_at, processed_at FROM csr WHERE id = $1` var csr certs.CSR if err := repo.db.QueryRowxContext(ctx, q, id).StructScan(&csr); err != nil { if err == sql.ErrNoRows { @@ -81,44 +83,65 @@ func (repo CSRRepo) RetrieveCSR(ctx context.Context,id string) (certs.CSR, error } func (repo CSRRepo) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { - q := `SELECT serial_number, status, submitted_at, processed_at, entity_id FROM csr %s LIMIT :limit OFFSET :offset` - var condition string + var query []string + params := map[string]interface{}{ + "limit": pm.Limit, + "offset": pm.Offset, + } if pm.EntityID != "" { - condition = `WHERE entity_id = :entity_id` - } else { - condition = `` + query = append(query, `c.entity_id = :entity_id`) + params["entity_id"] = pm.EntityID + } + if pm.Status != certs.All { + query = append(query, `c.status = :status`) + params["status"] = pm.Status } - q = fmt.Sprintf(q, condition) - var csrs []certs.CSR - params := map[string]interface{}{ - "limit": pm.Limit, - "offset": pm.Offset, - "entity_id": pm.EntityID, + var str string + if len(query) > 0 { + str = fmt.Sprintf(`WHERE %s`, strings.Join(query, ` AND `)) } - rows, err := repo.db.NamedQueryContext(ctx, q, params) + + q := fmt.Sprintf(` + SELECT + c.id, + c.serial_number, + c.submitted_at, + c.processed_at, + c.entity_id + FROM csr c %s LIMIT :limit OFFSET :offset;`, str) + + log.Printf("Query: %s", q) + log.Printf("Parameters: %+v", pm) + rows, err := repo.db.NamedQueryContext(ctx, q, pm) if err != nil { return certs.CSRPage{}, handleError(certs.ErrViewEntity, err) } defer rows.Close() - + log.Printf("row : %+v", rows) + var csrs []certs.CSR for rows.Next() { - csr := &certs.CSR{} - if err := rows.StructScan(csr); err != nil { + csr := certs.CSR{} + if err := rows.StructScan(&csr); err != nil { + log.Printf("StructScan error: %v", err) return certs.CSRPage{}, errors.Wrap(certs.ErrViewEntity, err) } - - csrs = append(csrs, *csr) + log.Printf("Scanned CSR: %+v", csr) + csrs = append(csrs, csr) } - q = fmt.Sprintf(`SELECT COUNT(*) FROM csr %s LIMIT :limit OFFSET :offset`, condition) - pm.Total, err = repo.total(ctx, q, params) + if len(csrs) == 0 { + log.Println("No CSRs found matching the query") + } + + cq := fmt.Sprintf(`SELECT COUNT(*) FROM csr c %s;`, str) + pm.Total, err = repo.total(ctx, cq, pm) if err != nil { return certs.CSRPage{}, errors.Wrap(certs.ErrViewEntity, err) } return certs.CSRPage{ PageMetadata: pm, - CSRs: csrs, + CSRs: csrs, }, nil } diff --git a/postgres/csr/init.go b/postgres/csr/init.go index b31c055..b95067c 100644 --- a/postgres/csr/init.go +++ b/postgres/csr/init.go @@ -20,7 +20,7 @@ func Migration() *migrate.MemoryMigrationSource { csr TEXT, private_key TEXT, entity_id VARCHAR(36), - status TEXT, + status TEXT CHECK (status IN ('pending', 'signed', 'rejected')), submitted_at TIMESTAMP, processed_at TIMESTAMP )`, diff --git a/sdk/certs_test.go b/sdk/certs_test.go index b905a1a..54f2f14 100644 --- a/sdk/certs_test.go +++ b/sdk/certs_test.go @@ -649,12 +649,12 @@ func TestDownloadCACert(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - svcCall := svc.On("GetSigningCA", mock.Anything, tc.token).Return(tc.svcresp, tc.svcerr) + svcCall := svc.On("GetChainCA", mock.Anything, tc.token).Return(tc.svcresp, tc.svcerr) _, err := ctsdk.DownloadCA(tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { - ok := svcCall.Parent.AssertCalled(t, "GetSigningCA", mock.Anything, tc.token) + ok := svcCall.Parent.AssertCalled(t, "GetChainCA", mock.Anything, tc.token) assert.True(t, ok) } svcCall.Unset() @@ -709,12 +709,12 @@ func TestViewCA(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - svcCall := svc.On("GetSigningCA", mock.Anything, tc.token).Return(tc.svcresp, tc.svcerr) + svcCall := svc.On("GetChainCA", mock.Anything, tc.token).Return(tc.svcresp, tc.svcerr) c, err := ctsdk.ViewCA(tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { - ok := svcCall.Parent.AssertCalled(t, "GetSigningCA", mock.Anything, tc.token) + ok := svcCall.Parent.AssertCalled(t, "GetChainCA", mock.Anything, tc.token) assert.True(t, ok) } assert.Equal(t, tc.sdkCert.Certificate, c.Certificate, fmt.Sprintf("expected: %v, got: %v", tc.sdkCert.Certificate, c.Certificate)) @@ -765,13 +765,13 @@ func TestGetCAToken(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - svcCall := svc.On("RetrieveCertDownloadToken", mock.Anything).Return(tc.svcresp, tc.svcerr) + svcCall := svc.On("RetrieveCAToken", mock.Anything).Return(tc.svcresp, tc.svcerr) resp, err := ctsdk.GetCAToken() assert.Equal(t, tc.err, err) if tc.err == nil { assert.Equal(t, tc.svcresp, resp.Token) - ok := svcCall.Parent.AssertCalled(t, "RetrieveCertDownloadToken", mock.Anything) + ok := svcCall.Parent.AssertCalled(t, "RetrieveCAToken", mock.Anything) assert.True(t, ok) } svcCall.Unset() diff --git a/sdk/mocks/sdk.go b/sdk/mocks/sdk.go index 5ccd73e..23be435 100644 --- a/sdk/mocks/sdk.go +++ b/sdk/mocks/sdk.go @@ -25,9 +25,9 @@ func (_m *MockSDK) EXPECT() *MockSDK_Expecter { return &MockSDK_Expecter{mock: &_m.Mock} } -// CreateCSR provides a mock function with given fields: pm, privKeyPath -func (_m *MockSDK) CreateCSR(pm sdk.PageMetadata, privKeyPath string) (sdk.CSR, errors.SDKError) { - ret := _m.Called(pm, privKeyPath) +// CreateCSR provides a mock function with given fields: pm, privKey +func (_m *MockSDK) CreateCSR(pm sdk.PageMetadata, privKey []byte) (sdk.CSR, errors.SDKError) { + ret := _m.Called(pm, privKey) if len(ret) == 0 { panic("no return value specified for CreateCSR") @@ -35,17 +35,17 @@ func (_m *MockSDK) CreateCSR(pm sdk.PageMetadata, privKeyPath string) (sdk.CSR, var r0 sdk.CSR var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) (sdk.CSR, errors.SDKError)); ok { - return rf(pm, privKeyPath) + if rf, ok := ret.Get(0).(func(sdk.PageMetadata, []byte) (sdk.CSR, errors.SDKError)); ok { + return rf(pm, privKey) } - if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) sdk.CSR); ok { - r0 = rf(pm, privKeyPath) + if rf, ok := ret.Get(0).(func(sdk.PageMetadata, []byte) sdk.CSR); ok { + r0 = rf(pm, privKey) } else { r0 = ret.Get(0).(sdk.CSR) } - if rf, ok := ret.Get(1).(func(sdk.PageMetadata, string) errors.SDKError); ok { - r1 = rf(pm, privKeyPath) + if rf, ok := ret.Get(1).(func(sdk.PageMetadata, []byte) errors.SDKError); ok { + r1 = rf(pm, privKey) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -62,14 +62,14 @@ type MockSDK_CreateCSR_Call struct { // CreateCSR is a helper method to define mock.On call // - pm sdk.PageMetadata -// - privKeyPath string -func (_e *MockSDK_Expecter) CreateCSR(pm interface{}, privKeyPath interface{}) *MockSDK_CreateCSR_Call { - return &MockSDK_CreateCSR_Call{Call: _e.mock.On("CreateCSR", pm, privKeyPath)} +// - privKey []byte +func (_e *MockSDK_Expecter) CreateCSR(pm interface{}, privKey interface{}) *MockSDK_CreateCSR_Call { + return &MockSDK_CreateCSR_Call{Call: _e.mock.On("CreateCSR", pm, privKey)} } -func (_c *MockSDK_CreateCSR_Call) Run(run func(pm sdk.PageMetadata, privKeyPath string)) *MockSDK_CreateCSR_Call { +func (_c *MockSDK_CreateCSR_Call) Run(run func(pm sdk.PageMetadata, privKey []byte)) *MockSDK_CreateCSR_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(sdk.PageMetadata), args[1].(string)) + run(args[0].(sdk.PageMetadata), args[1].([]byte)) }) return _c } @@ -79,7 +79,7 @@ func (_c *MockSDK_CreateCSR_Call) Return(_a0 sdk.CSR, _a1 errors.SDKError) *Mock return _c } -func (_c *MockSDK_CreateCSR_Call) RunAndReturn(run func(sdk.PageMetadata, string) (sdk.CSR, errors.SDKError)) *MockSDK_CreateCSR_Call { +func (_c *MockSDK_CreateCSR_Call) RunAndReturn(run func(sdk.PageMetadata, []byte) (sdk.CSR, errors.SDKError)) *MockSDK_CreateCSR_Call { _c.Call.Return(run) return _c } diff --git a/sdk/sdk.go b/sdk/sdk.go index 5af513a..86d0228 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -279,10 +279,10 @@ type SDK interface { // CreateCSR creates a new Certificate Signing Request // // example: - // pm = sdk.CSRMetadata{CommonName: "common_name", "entity_id" } - // reponse, _ := sdk.CreateCSR(pm, "privKeyPath") + // pm = sdk.CSRMetadata{CommonName: "common_name", EntityID: "entity_id" } + // reponse, _ := sdk.CreateCSR(pm, []bytes("privKey")) // fmt.Println(response) - CreateCSR(pm PageMetadata, privKeyPath string) (CSR, errors.SDKError) + CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKError) // SignCSR processes a pending CSR and either signs or rejects it // @@ -570,7 +570,7 @@ func (sdk mgSDK) GetCAToken() (Token, errors.SDKError) { return tk, nil } -func (sdk mgSDK) CreateCSR(pm PageMetadata, privKeyPath string) (CSR, errors.SDKError) { +func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKError) { r := csrReq{ Organization: pm.Organization, OrganizationalUnit: pm.OrganizationalUnit, @@ -582,12 +582,13 @@ func (sdk mgSDK) CreateCSR(pm PageMetadata, privKeyPath string) (CSR, errors.SDK DNSNames: pm.DNSNames, IPAddresses: pm.IPAddresses, EmailAddresses: pm.EmailAddresses, + PrivateKey: privKey, } d, err := json.Marshal(r) if err != nil { return CSR{}, errors.NewSDKError(err) } - url := fmt.Sprintf("%s/%s", sdk.certsURL, csrEndpoint) + url := fmt.Sprintf("%s/%s/%s/%s", sdk.certsURL, certsEndpoint, csrEndpoint, pm.EntityID) _, body, sdkerr := sdk.processRequest(http.MethodPost, url, d, nil, http.StatusOK) if sdkerr != nil { return CSR{}, sdkerr @@ -604,7 +605,7 @@ func (sdk mgSDK) SignCSR(csrID string, sign bool) errors.SDKError { pm := PageMetadata{ Sign: sign, } - url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s", certsEndpoint, csrID), pm) + url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s/%s", certsEndpoint, csrEndpoint, csrID), pm) if err != nil { return errors.NewSDKError(err) } @@ -617,11 +618,10 @@ func (sdk mgSDK) SignCSR(csrID string, sign bool) errors.SDKError { } func (sdk mgSDK) ListCSRs(pm PageMetadata) (CSRPage, errors.SDKError) { - url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/list", csrEndpoint), pm) + url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s/list", certsEndpoint, csrEndpoint), pm) if err != nil { return CSRPage{}, errors.NewSDKError(err) } - _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusOK) if sdkerr != nil { return CSRPage{}, sdkerr @@ -635,9 +635,9 @@ func (sdk mgSDK) ListCSRs(pm PageMetadata) (CSRPage, errors.SDKError) { } func (sdk mgSDK) RetrieveCSR(csrID string) (CSR, errors.SDKError) { - url := fmt.Sprintf("%s/%s/%s", sdk.certsURL, csrEndpoint, csrID) + url := fmt.Sprintf("%s/%s/%s/%s", sdk.certsURL, certsEndpoint, csrEndpoint, csrID) - _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusOK) + _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusCreated) if sdkerr != nil { return CSR{}, sdkerr } @@ -771,4 +771,5 @@ type csrReq struct { DNSNames []string `json:"dns_names"` IPAddresses []string `json:"ip_addresses"` EmailAddresses []string `json:"email_addresses"` + PrivateKey []byte `json:"private_key"` } diff --git a/service.go b/service.go index da2da54..3867230 100644 --- a/service.go +++ b/service.go @@ -92,16 +92,16 @@ func NewService(ctx context.Context, repo Repository, csrRepo CSRRepository, con // The certificate is then stored in the repository using the CreateCert method. // If the root CA is not found, it returns an error. func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key ...*rsa.PrivateKey) (Certificate, error) { + var privKey rsa.PrivateKey var err error - privKey := rsa.PrivateKey{} if len(key) == 0 { pKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) privKey = *pKey if err != nil { return Certificate{}, err - } else { - privKey = *key[0] } + } else { + privKey = *key[0] } serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { @@ -468,7 +468,7 @@ func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, entityID CSR: csrPEM, PrivateKey: privKeyPEM, EntityID: entityID, - Status: "pending", + Status: Pending, SubmittedAt: time.Now(), } @@ -486,7 +486,7 @@ func (s *service) SignCSR(ctx context.Context, csrID string, approve bool) error } if !approve { - csr.Status = "rejected" + csr.Status = Rejected csr.ProcessedAt = time.Now() return s.csrRepo.UpdateCSR(ctx, csr) } @@ -524,19 +524,20 @@ func (s *service) SignCSR(ctx context.Context, csrID string, approve bool) error return errors.Wrap(ErrCreateEntity, err) } - csr.Status = "approved" + csr.Status = Signed csr.ProcessedAt = time.Now() csr.SerialNumber = cert.SerialNumber return s.csrRepo.UpdateCSR(ctx, csr) } -func (s *service) ListCSRs(ctx context.Context, entityID string, status string) (CSRPage, error) { - pm := PageMetadata{ - EntityID: entityID, - Status: status, +func (s *service) ListCSRs(ctx context.Context, pm PageMetadata) (CSRPage, error) { + cp, err := s.csrRepo.ListCSRs(ctx, pm) + if err != nil { + return CSRPage{}, errors.Wrap(ErrViewEntity, err) } - return s.csrRepo.ListCSRs(ctx, pm) + + return cp, nil } func (s *service) RetrieveCSR(ctx context.Context, csrID string) (CSR, error) { diff --git a/tracing/certs.go b/tracing/certs.go index 71ee502..ec9c699 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -114,10 +114,10 @@ func (tm *tracingMiddleware) SignCSR(ctx context.Context, csrID string, approve return tm.svc.SignCSR(ctx, csrID, approve) } -func (tm *tracingMiddleware) ListCSRs(ctx context.Context, entityID string, status string) (certs.CSRPage, error) { +func (tm *tracingMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { ctx, span := tm.tracer.Start(ctx, "list_csrs") defer span.End() - return tm.svc.ListCSRs(ctx, entityID, status) + return tm.svc.ListCSRs(ctx, pm) } func (tm *tracingMiddleware) RetrieveCSR(ctx context.Context, csrID string) (certs.CSR, error) { From fda3c9ea4106a8da90c9c863db5e0da94eabdc02 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Thu, 28 Nov 2024 03:47:55 +0300 Subject: [PATCH 08/10] fix failing linter Signed-off-by: nyagamunene --- certs.go | 2 +- postgres/csr/csr.go | 59 ++++++++++++++++++++++++++++++++------------- sdk/sdk.go | 6 ++--- 3 files changed, 46 insertions(+), 21 deletions(-) diff --git a/certs.go b/certs.go index f17ad50..a8f9366 100644 --- a/certs.go +++ b/certs.go @@ -159,7 +159,7 @@ type CSR struct { Status CSRStatus `json:"status" db:"status"` SubmittedAt time.Time `json:"submitted_at" db:"submitted_at"` ProcessedAt time.Time `json:"processed_at,omitempty" db:"processed_at"` - SerialNumber string `json:"serial_number" db:"serial_number"` + SerialNumber string `json:"serial_number,omitempty" db:"serial_number"` } type CSRPage struct { diff --git a/postgres/csr/csr.go b/postgres/csr/csr.go index 54d6c24..580cde1 100644 --- a/postgres/csr/csr.go +++ b/postgres/csr/csr.go @@ -7,8 +7,8 @@ import ( "context" "database/sql" "fmt" - "log" "strings" + "time" "github.com/absmach/certs" "github.com/absmach/certs/errors" @@ -55,8 +55,17 @@ func (repo CSRRepo) CreateCSR(ctx context.Context, csr certs.CSR) error { } func (repo CSRRepo) UpdateCSR(ctx context.Context, csr certs.CSR) error { + updateData := rawCSR{ + ID: csr.ID, + SerialNumber: csr.SerialNumber, + Status: csr.Status.String(), + PrivateKey: csr.PrivateKey, + SubmittedAt: csr.SubmittedAt, + ProcessedAt: csr.ProcessedAt, + } + q := `UPDATE csr SET serial_number = :serial_number, status = :status, private_key = :private_key, submitted_at = :submitted_at, processed_at = :processed_at WHERE id = :id` - res, err := repo.db.NamedExecContext(ctx, q, csr) + res, err := repo.db.NamedExecContext(ctx, q, updateData) if err != nil { return handleError(certs.ErrUpdateEntity, err) } @@ -72,22 +81,36 @@ func (repo CSRRepo) UpdateCSR(ctx context.Context, csr certs.CSR) error { func (repo CSRRepo) RetrieveCSR(ctx context.Context, id string) (certs.CSR, error) { q := `SELECT id, serial_number, csr, private_key, entity_id, status, submitted_at, processed_at FROM csr WHERE id = $1` - var csr certs.CSR - if err := repo.db.QueryRowxContext(ctx, q, id).StructScan(&csr); err != nil { + var csrRaw rawCSR + if err := repo.db.QueryRowxContext(ctx, q, id).StructScan(&csrRaw); err != nil { if err == sql.ErrNoRows { return certs.CSR{}, errors.Wrap(certs.ErrNotFound, err) } return certs.CSR{}, errors.Wrap(certs.ErrViewEntity, err) } - return csr, nil + + status, err := certs.ParseCSRStatus(csrRaw.Status) + if err != nil { + return certs.CSR{}, errors.Wrap(certs.ErrViewEntity, fmt.Errorf("invalid status: %s", csrRaw.Status)) + } + return certs.CSR{ + ID: csrRaw.ID, + SerialNumber: csrRaw.SerialNumber, + CSR: csrRaw.CSR, + PrivateKey: csrRaw.PrivateKey, + EntityID: csrRaw.EntityID, + Status: status, + SubmittedAt: csrRaw.SubmittedAt, + ProcessedAt: csrRaw.ProcessedAt, + }, nil } func (repo CSRRepo) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { var query []string params := map[string]interface{}{ - "limit": pm.Limit, - "offset": pm.Offset, - } + "limit": pm.Limit, + "offset": pm.Offset, + } if pm.EntityID != "" { query = append(query, `c.entity_id = :entity_id`) params["entity_id"] = pm.EntityID @@ -111,29 +134,20 @@ func (repo CSRRepo) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs. c.entity_id FROM csr c %s LIMIT :limit OFFSET :offset;`, str) - log.Printf("Query: %s", q) - log.Printf("Parameters: %+v", pm) rows, err := repo.db.NamedQueryContext(ctx, q, pm) if err != nil { return certs.CSRPage{}, handleError(certs.ErrViewEntity, err) } defer rows.Close() - log.Printf("row : %+v", rows) var csrs []certs.CSR for rows.Next() { csr := certs.CSR{} if err := rows.StructScan(&csr); err != nil { - log.Printf("StructScan error: %v", err) return certs.CSRPage{}, errors.Wrap(certs.ErrViewEntity, err) } - log.Printf("Scanned CSR: %+v", csr) csrs = append(csrs, csr) } - if len(csrs) == 0 { - log.Println("No CSRs found matching the query") - } - cq := fmt.Sprintf(`SELECT COUNT(*) FROM csr c %s;`, str) pm.Total, err = repo.total(ctx, cq, pm) if err != nil { @@ -175,3 +189,14 @@ func handleError(wrapper, err error) error { return errors.Wrap(wrapper, err) } + +type rawCSR struct { + ID string `db:"id"` + SerialNumber string `db:"serial_number"` + CSR []byte `db:"csr"` + PrivateKey []byte `db:"private_key"` + EntityID string `db:"entity_id"` + Status string `db:"status"` + SubmittedAt time.Time `db:"submitted_at"` + ProcessedAt time.Time `db:"processed_at"` +} diff --git a/sdk/sdk.go b/sdk/sdk.go index 86d0228..7199b49 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -280,7 +280,7 @@ type SDK interface { // // example: // pm = sdk.CSRMetadata{CommonName: "common_name", EntityID: "entity_id" } - // reponse, _ := sdk.CreateCSR(pm, []bytes("privKey")) + // response, _ := sdk.CreateCSR(pm, []bytes("privKey")) // fmt.Println(response) CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKError) @@ -293,13 +293,13 @@ type SDK interface { // ListCSRs returns a list of CSRs based on filter criteria // - // reponse, _ := sdk.ListCSRs(sdk.PageMetadata{EntityID: "entity_id", Status: "pending"}) + // response, _ := sdk.ListCSRs(sdk.PageMetadata{EntityID: "entity_id", Status: "pending"}) // fmt.Println(response) ListCSRs(pm PageMetadata) (CSRPage, errors.SDKError) // RetrieveCSR retrieves a specific CSR by ID // - // reponse, _ := sdk.RetrieveCSR("csr_id") + // response, _ := sdk.RetrieveCSR("csr_id") // fmt.Println(response) RetrieveCSR(csrID string) (CSR, errors.SDKError) } From 74d5236d3f803d55ea880a64432482e2b01e4c41 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Thu, 28 Nov 2024 12:58:59 +0300 Subject: [PATCH 09/10] address comments Signed-off-by: nyagamunene --- api/http/endpoint.go | 28 ++++++++++++++-------------- api/http/transport.go | 12 ++++++------ api/logging.go | 2 +- api/metrics.go | 20 ++++++++++---------- certs.go | 10 +++++----- postgres/csr/csr.go | 20 ++++++++++---------- postgres/csr/init.go | 6 +++--- sdk/sdk.go | 18 +++++++++--------- service.go | 4 ++-- tracing/certs.go | 2 +- 10 files changed, 61 insertions(+), 61 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index 4907d4a..16010c7 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -347,38 +347,38 @@ func signCSREndpoint(svc certs.Service) endpoint.Endpoint { } } -func listCSRsEndpoint(svc certs.Service) endpoint.Endpoint { +func retrieveCSREndpoint(svc certs.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { - req := request.(listCSRsReq) + req := request.(retrieveCSRReq) if err := req.validate(); err != nil { - return listCSRsRes{}, err + return retrieveCSRRes{}, err } - cp, err := svc.ListCSRs(ctx, req.pm) + csr, err := svc.RetrieveCSR(ctx, req.csrID) if err != nil { - return listCSRsRes{}, err + return retrieveCSRRes{}, err } - return listCSRsRes{ - cp, + return retrieveCSRRes{ + CSR: csr, }, nil } } -func retrieveCSREndpoint(svc certs.Service) endpoint.Endpoint { +func listCSRsEndpoint(svc certs.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { - req := request.(retrieveCSRReq) + req := request.(listCSRsReq) if err := req.validate(); err != nil { - return retrieveCSRRes{}, err + return listCSRsRes{}, err } - csr, err := svc.RetrieveCSR(ctx, req.csrID) + cp, err := svc.ListCSRs(ctx, req.pm) if err != nil { - return retrieveCSRRes{}, err + return listCSRsRes{}, err } - return retrieveCSRRes{ - CSR: csr, + return listCSRsRes{ + cp, }, nil } } diff --git a/api/http/transport.go b/api/http/transport.go index 8293eee..87e22be 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -141,31 +141,31 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http encodeCADownloadResponse, opts..., ), "download_ca").ServeHTTP) - r.Route("/csr", func(r chi.Router) { + r.Route("/csrs", func(r chi.Router) { r.Post("/{entityID}", otelhttp.NewHandler(kithttp.NewServer( createCSREndpoint(svc), decodeCreateCSR, EncodeResponse, opts..., - ), "").ServeHTTP) + ), "create_csr").ServeHTTP) r.Patch("/{id}", otelhttp.NewHandler(kithttp.NewServer( signCSREndpoint(svc), decodeUpdateCSR, EncodeResponse, opts..., - ), "").ServeHTTP) + ), "sign_csr").ServeHTTP) r.Get("/{id}", otelhttp.NewHandler(kithttp.NewServer( retrieveCSREndpoint(svc), decodeRetrieveCSR, EncodeResponse, opts..., - ), "").ServeHTTP) - r.Get("/list", otelhttp.NewHandler(kithttp.NewServer( + ), "view_csr").ServeHTTP) + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( listCSRsEndpoint(svc), decodeListCSR, EncodeResponse, opts..., - ), "").ServeHTTP) + ), "list_csrs").ServeHTTP) }) }) diff --git a/api/logging.go b/api/logging.go index 893f521..80144f6 100644 --- a/api/logging.go +++ b/api/logging.go @@ -196,7 +196,7 @@ func (lm *loggingMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetada func (lm *loggingMiddleware) SignCSR(ctx context.Context, csrID string, approve bool) (err error) { defer func(begin time.Time) { - message := fmt.Sprintf("Method process_csr took %s to complete", time.Since(begin)) + message := fmt.Sprintf("Method sign_csr took %s to complete", time.Since(begin)) if err != nil { lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) return diff --git a/api/metrics.go b/api/metrics.go index 8610458..b0c6f4e 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -147,20 +147,12 @@ func (mm *metricsMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetada func (mm *metricsMiddleware) SignCSR(ctx context.Context, csrID string, approve bool) error { defer func(begin time.Time) { - mm.counter.With("method", "process_csr").Add(1) - mm.latency.With("method", "process_csr").Observe(time.Since(begin).Seconds()) + mm.counter.With("method", "sign_csr").Add(1) + mm.latency.With("method", "sign_csr").Observe(time.Since(begin).Seconds()) }(time.Now()) return mm.svc.SignCSR(ctx, csrID, approve) } -func (mm *metricsMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { - defer func(begin time.Time) { - mm.counter.With("method", "list_csrs").Add(1) - mm.latency.With("method", "list_csrs").Observe(time.Since(begin).Seconds()) - }(time.Now()) - return mm.svc.ListCSRs(ctx, pm) -} - func (mm *metricsMiddleware) RetrieveCSR(ctx context.Context, csrID string) (certs.CSR, error) { defer func(begin time.Time) { mm.counter.With("method", "retrieve_csr").Add(1) @@ -168,3 +160,11 @@ func (mm *metricsMiddleware) RetrieveCSR(ctx context.Context, csrID string) (cer }(time.Now()) return mm.svc.RetrieveCSR(ctx, csrID) } + +func (mm *metricsMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { + defer func(begin time.Time) { + mm.counter.With("method", "list_csrs").Add(1) + mm.latency.With("method", "list_csrs").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.ListCSRs(ctx, pm) +} diff --git a/certs.go b/certs.go index a8f9366..6654c5b 100644 --- a/certs.go +++ b/certs.go @@ -158,7 +158,7 @@ type CSR struct { EntityID string `json:"entity_id" db:"entity_id"` Status CSRStatus `json:"status" db:"status"` SubmittedAt time.Time `json:"submitted_at" db:"submitted_at"` - ProcessedAt time.Time `json:"processed_at,omitempty" db:"processed_at"` + SignedAt time.Time `json:"signed_at,omitempty" db:"signed_at"` SerialNumber string `json:"serial_number,omitempty" db:"serial_number"` } @@ -240,11 +240,11 @@ type Service interface { // SignCSR processes a pending CSR and either approves or rejects it SignCSR(ctx context.Context, csrID string, approve bool) error - // ListCSRs returns a list of CSRs based on filter criteria - ListCSRs(ctx context.Context, pm PageMetadata) (CSRPage, error) - // RetrieveCSR retrieves a specific CSR by ID RetrieveCSR(ctx context.Context, csrID string) (CSR, error) + + // ListCSRs returns a list of CSRs based on filter criteria + ListCSRs(ctx context.Context, pm PageMetadata) (CSRPage, error) } type Repository interface { @@ -273,6 +273,6 @@ type Repository interface { type CSRRepository interface { CreateCSR(context.Context, CSR) error UpdateCSR(context.Context, CSR) error - ListCSRs(context.Context, PageMetadata) (CSRPage, error) RetrieveCSR(context.Context, string) (CSR, error) + ListCSRs(context.Context, PageMetadata) (CSRPage, error) } diff --git a/postgres/csr/csr.go b/postgres/csr/csr.go index 580cde1..1cc3e3e 100644 --- a/postgres/csr/csr.go +++ b/postgres/csr/csr.go @@ -45,8 +45,8 @@ func NewRepository(db postgres.Database) certs.CSRRepository { func (repo CSRRepo) CreateCSR(ctx context.Context, csr certs.CSR) error { q := ` - INSERT INTO csr (id, serial_number, csr, private_key, entity_id, status, submitted_at, processed_at) - VALUES (:id, :serial_number, :csr, :private_key, :entity_id, :status, :submitted_at, :processed_at)` + INSERT INTO csr (id, serial_number, csr, private_key, entity_id, status, submitted_at, signed_at) + VALUES (:id, :serial_number, :csr, :private_key, :entity_id, :status, :submitted_at, :signed_at)` _, err := repo.db.NamedExecContext(ctx, q, csr) if err != nil { return handleError(certs.ErrCreateEntity, err) @@ -61,10 +61,10 @@ func (repo CSRRepo) UpdateCSR(ctx context.Context, csr certs.CSR) error { Status: csr.Status.String(), PrivateKey: csr.PrivateKey, SubmittedAt: csr.SubmittedAt, - ProcessedAt: csr.ProcessedAt, + SignedAt: csr.SignedAt, } - q := `UPDATE csr SET serial_number = :serial_number, status = :status, private_key = :private_key, submitted_at = :submitted_at, processed_at = :processed_at WHERE id = :id` + q := `UPDATE csr SET serial_number = :serial_number, status = :status, private_key = :private_key, submitted_at = :submitted_at, signed_at = :signed_at WHERE id = :id` res, err := repo.db.NamedExecContext(ctx, q, updateData) if err != nil { return handleError(certs.ErrUpdateEntity, err) @@ -80,7 +80,7 @@ func (repo CSRRepo) UpdateCSR(ctx context.Context, csr certs.CSR) error { } func (repo CSRRepo) RetrieveCSR(ctx context.Context, id string) (certs.CSR, error) { - q := `SELECT id, serial_number, csr, private_key, entity_id, status, submitted_at, processed_at FROM csr WHERE id = $1` + q := `SELECT id, serial_number, csr, private_key, entity_id, status, submitted_at, signed_at FROM csrs WHERE id = $1` var csrRaw rawCSR if err := repo.db.QueryRowxContext(ctx, q, id).StructScan(&csrRaw); err != nil { if err == sql.ErrNoRows { @@ -101,7 +101,7 @@ func (repo CSRRepo) RetrieveCSR(ctx context.Context, id string) (certs.CSR, erro EntityID: csrRaw.EntityID, Status: status, SubmittedAt: csrRaw.SubmittedAt, - ProcessedAt: csrRaw.ProcessedAt, + SignedAt: csrRaw.SignedAt, }, nil } @@ -130,9 +130,9 @@ func (repo CSRRepo) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs. c.id, c.serial_number, c.submitted_at, - c.processed_at, + c.signed_at, c.entity_id - FROM csr c %s LIMIT :limit OFFSET :offset;`, str) + FROM csrs c %s LIMIT :limit OFFSET :offset;`, str) rows, err := repo.db.NamedQueryContext(ctx, q, pm) if err != nil { @@ -148,7 +148,7 @@ func (repo CSRRepo) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs. csrs = append(csrs, csr) } - cq := fmt.Sprintf(`SELECT COUNT(*) FROM csr c %s;`, str) + cq := fmt.Sprintf(`SELECT COUNT(*) FROM csrs c %s;`, str) pm.Total, err = repo.total(ctx, cq, pm) if err != nil { return certs.CSRPage{}, errors.Wrap(certs.ErrViewEntity, err) @@ -198,5 +198,5 @@ type rawCSR struct { EntityID string `db:"entity_id"` Status string `db:"status"` SubmittedAt time.Time `db:"submitted_at"` - ProcessedAt time.Time `db:"processed_at"` + SignedAt time.Time `db:"signed_at"` } diff --git a/postgres/csr/init.go b/postgres/csr/init.go index b95067c..9ef8156 100644 --- a/postgres/csr/init.go +++ b/postgres/csr/init.go @@ -12,9 +12,9 @@ func Migration() *migrate.MemoryMigrationSource { return &migrate.MemoryMigrationSource{ Migrations: []*migrate.Migration{ { - Id: "csr_1", + Id: "csrs_1", Up: []string{ - `CREATE TABLE IF NOT EXISTS csr ( + `CREATE TABLE IF NOT EXISTS csrs ( id VARCHAR(36) PRIMARY KEY, serial_number VARCHAR(40), csr TEXT, @@ -22,7 +22,7 @@ func Migration() *migrate.MemoryMigrationSource { entity_id VARCHAR(36), status TEXT CHECK (status IN ('pending', 'signed', 'rejected')), submitted_at TIMESTAMP, - processed_at TIMESTAMP + signed_at TIMESTAMP )`, }, Down: []string{ diff --git a/sdk/sdk.go b/sdk/sdk.go index 7199b49..5c73fbf 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -27,7 +27,7 @@ import ( const ( certsEndpoint = "certs" - csrEndpoint = "csr" + csrEndpoint = "csrs" issueCertEndpoint = "certs/issue" emptyOCSPbody = 22 ) @@ -182,7 +182,7 @@ type CSR struct { EntityID string `json:"entity_id,omitempty"` Status string `json:"status,omitempty"` SubmittedAt time.Time `json:"submitted_at,omitempty"` - ProcessedAt time.Time `json:"processed_at,omitempty"` + SignedAt time.Time `json:"signed_at,omitempty"` SerialNumber string `json:"serial_number,omitempty"` } @@ -291,17 +291,17 @@ type SDK interface { // fmt.Println(err) SignCSR(csrID string, sign bool) errors.SDKError - // ListCSRs returns a list of CSRs based on filter criteria - // - // response, _ := sdk.ListCSRs(sdk.PageMetadata{EntityID: "entity_id", Status: "pending"}) - // fmt.Println(response) - ListCSRs(pm PageMetadata) (CSRPage, errors.SDKError) - // RetrieveCSR retrieves a specific CSR by ID // // response, _ := sdk.RetrieveCSR("csr_id") // fmt.Println(response) RetrieveCSR(csrID string) (CSR, errors.SDKError) + + // ListCSRs returns a list of CSRs based on filter criteria + // + // response, _ := sdk.ListCSRs(sdk.PageMetadata{EntityID: "entity_id", Status: "pending"}) + // fmt.Println(response) + ListCSRs(pm PageMetadata) (CSRPage, errors.SDKError) } func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (Certificate, errors.SDKError) { @@ -618,7 +618,7 @@ func (sdk mgSDK) SignCSR(csrID string, sign bool) errors.SDKError { } func (sdk mgSDK) ListCSRs(pm PageMetadata) (CSRPage, errors.SDKError) { - url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s/list", certsEndpoint, csrEndpoint), pm) + url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s", certsEndpoint, csrEndpoint), pm) if err != nil { return CSRPage{}, errors.NewSDKError(err) } diff --git a/service.go b/service.go index 3867230..dc0f9f9 100644 --- a/service.go +++ b/service.go @@ -487,7 +487,7 @@ func (s *service) SignCSR(ctx context.Context, csrID string, approve bool) error if !approve { csr.Status = Rejected - csr.ProcessedAt = time.Now() + csr.SignedAt = time.Now() return s.csrRepo.UpdateCSR(ctx, csr) } @@ -525,7 +525,7 @@ func (s *service) SignCSR(ctx context.Context, csrID string, approve bool) error } csr.Status = Signed - csr.ProcessedAt = time.Now() + csr.SignedAt = time.Now() csr.SerialNumber = cert.SerialNumber return s.csrRepo.UpdateCSR(ctx, csr) diff --git a/tracing/certs.go b/tracing/certs.go index ec9c699..4faeeb5 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -109,7 +109,7 @@ func (tm *tracingMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetada } func (tm *tracingMiddleware) SignCSR(ctx context.Context, csrID string, approve bool) error { - ctx, span := tm.tracer.Start(ctx, "process_csr") + ctx, span := tm.tracer.Start(ctx, "sign_csr") defer span.End() return tm.svc.SignCSR(ctx, csrID, approve) } From 1a78e48b4a54d4d3435dc0aa7f06b19a1eb97d68 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Thu, 28 Nov 2024 13:13:36 +0300 Subject: [PATCH 10/10] Fix sign conflict Signed-off-by: nyagamunene --- postgres/csr/csr.go | 4 ++-- sdk/sdk.go | 6 +++--- service.go | 4 ++++ 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/postgres/csr/csr.go b/postgres/csr/csr.go index 1cc3e3e..f1a2008 100644 --- a/postgres/csr/csr.go +++ b/postgres/csr/csr.go @@ -45,7 +45,7 @@ func NewRepository(db postgres.Database) certs.CSRRepository { func (repo CSRRepo) CreateCSR(ctx context.Context, csr certs.CSR) error { q := ` - INSERT INTO csr (id, serial_number, csr, private_key, entity_id, status, submitted_at, signed_at) + INSERT INTO csrs (id, serial_number, csr, private_key, entity_id, status, submitted_at, signed_at) VALUES (:id, :serial_number, :csr, :private_key, :entity_id, :status, :submitted_at, :signed_at)` _, err := repo.db.NamedExecContext(ctx, q, csr) if err != nil { @@ -64,7 +64,7 @@ func (repo CSRRepo) UpdateCSR(ctx context.Context, csr certs.CSR) error { SignedAt: csr.SignedAt, } - q := `UPDATE csr SET serial_number = :serial_number, status = :status, private_key = :private_key, submitted_at = :submitted_at, signed_at = :signed_at WHERE id = :id` + q := `UPDATE csrs SET serial_number = :serial_number, status = :status, private_key = :private_key, submitted_at = :submitted_at, signed_at = :signed_at WHERE id = :id` res, err := repo.db.NamedExecContext(ctx, q, updateData) if err != nil { return handleError(certs.ErrUpdateEntity, err) diff --git a/sdk/sdk.go b/sdk/sdk.go index 5c73fbf..6ca2fca 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -76,9 +76,9 @@ func (c CertStatus) MarshalJSON() ([]byte, error) { } type PageMetadata struct { - Total uint64 `json:"total,omitempty"` + Total uint64 `json:"total"` Offset uint64 `json:"offset,omitempty"` - Limit uint64 `json:"limit,omitempty"` + Limit uint64 `json:"limit"` EntityID string `json:"entity_id,omitempty"` Token string `json:"token,omitempty"` CommonName string `json:"common_name,omitempty"` @@ -188,7 +188,7 @@ type CSR struct { type CSRPage struct { PageMetadata - CSRs []CSR + CSRs []CSR `json:"csrs,omitempty"` } type SDK interface { diff --git a/service.go b/service.go index dc0f9f9..a4ad34c 100644 --- a/service.go +++ b/service.go @@ -485,6 +485,10 @@ func (s *service) SignCSR(ctx context.Context, csrID string, approve bool) error return errors.Wrap(ErrViewEntity, err) } + if csr.Status != Pending { + return ErrConflict + } + if !approve { csr.Status = Rejected csr.SignedAt = time.Now()