From cbaee01ed536d4ab5d317eadeb2ce742cb2be936 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 20 Nov 2024 18:22:40 +0300 Subject: [PATCH] Add repository for CSR Signed-off-by: nyagamunene --- certs.go | 26 ++++++-- cmd/certs/main.go | 6 +- postgres/csr/csr.go | 154 +++++++++++++++++++++++++++++++++++++++++++ postgres/csr/init.go | 22 +++---- service.go | 142 +++++++++++++++++++++++++++++++++++++-- 5 files changed, 324 insertions(+), 26 deletions(-) diff --git a/certs.go b/certs.go index 673b6bb..c1c221f 100644 --- a/certs.go +++ b/certs.go @@ -47,13 +47,18 @@ type CSRMetadata struct { } type CSR struct { - CSR []byte `json:"csr"` - PrivateKey []byte `json:"private_key"` - EntityID string `json:"entity_id"` - Status string `json:"status"` - SubmittedAt time.Time `json:"submitted_at"` - ProcessedAt time.Time `json:"processed_at"` - SerialNumber string `json:"serial_number"` + CSR []byte `json:"csr" db:"csr"` + PrivateKey []byte `json:"private_key" db:"private_key"` + EntityID string `json:"entity_id" db:"entity_id"` + Status string `json:"status" db:"status"` + SubmittedAt time.Time `json:"submitted_at" db:"submitted_at"` + ProcessedAt time.Time `json:"processed_at" db:"processed_at"` + SerialNumber string `json:"serial_number" db:"serial_number"` +} + +type CSRPage struct { + PageMetadata + CSRs []CSR } type Service interface { @@ -133,3 +138,10 @@ type Repository interface { // RemoveCert deletes cert from database. RemoveCert(ctx context.Context, entityId string) error } + +type CSRRepository interface { + CreateCSR(context.Context, CSR) error + UpdateCSR(context.Context, CSR) error + ListCSRs(context.Context, PageMetadata) (CSRPage, error) + RetrieveCSR(context.Context, string) (CSR, error) +} diff --git a/cmd/certs/main.go b/cmd/certs/main.go index e596483..6047ac1 100644 --- a/cmd/certs/main.go +++ b/cmd/certs/main.go @@ -24,7 +24,8 @@ import ( grpcserver "github.com/absmach/certs/internal/server/grpc" httpserver "github.com/absmach/certs/internal/server/http" "github.com/absmach/certs/internal/uuid" - cpostgres "github.com/absmach/certs/postgres" + cpostgres "github.com/absmach/certs/postgres/certs" + csrpostgres "github.com/absmach/certs/postgres/csr" "github.com/absmach/certs/tracing" "github.com/caarlos0/env/v10" "github.com/jmoiron/sqlx" @@ -146,7 +147,8 @@ func main() { func newService(ctx context.Context, db *sqlx.DB, tracer trace.Tracer, logger *slog.Logger, dbConfig pgClient.Config, config *certs.Config) (certs.Service, error) { database := postgres.NewDatabase(db, dbConfig, tracer) repo := cpostgres.NewRepository(database) - svc, err := certs.NewService(ctx, repo, config) + csrRepo := csrpostgres.NewRepository(database) + svc, err := certs.NewService(ctx, repo, csrRepo, config) if err != nil { return nil, err } diff --git a/postgres/csr/csr.go b/postgres/csr/csr.go index e69de29..f4d52dc 100644 --- a/postgres/csr/csr.go +++ b/postgres/csr/csr.go @@ -0,0 +1,154 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/absmach/certs" + "github.com/absmach/certs/errors" + "github.com/absmach/certs/internal/postgres" + "github.com/jackc/pgx/v5/pgconn" +) + +// Postgres error codes: +// https://www.postgresql.org/docs/current/errcodes-appendix.html +const ( + errDuplicate = "23505" // unique_violation + errTruncation = "22001" // string_data_right_truncation + errFK = "23503" // foreign_key_violation + errInvalid = "22P02" // invalid_text_representation + errUntranslatable = "22P05" // untranslatable_character + errInvalidChar = "22021" // character_not_in_repertoire +) + +var ( + ErrConflict = errors.New("entity already exists") + ErrMalformedEntity = errors.New("malformed entity") + ErrCreateEntity = errors.New("failed to create entity") +) + +type CSRRepo struct { + db postgres.Database +} + +func NewRepository(db postgres.Database) certs.CSRRepository { + return CSRRepo{ + db: db, + } +} + +func (repo CSRRepo) CreateCSR(ctx context.Context, cert certs.CSR) error { + q := ` + INSERT INTO certs (serial_number, csr, private_key, entity_id, status, submitted_at, processed_at) + VALUES (:serial_number, :csr, :private_key, :entity_id, :status, :submitted_at, :processed_at)` + _, err := repo.db.NamedExecContext(ctx, q, cert) + if err != nil { + return handleError(certs.ErrCreateEntity, err) + } + return nil +} + +func (repo CSRRepo) UpdateCSR(ctx context.Context, cert certs.CSR) error { + q := `UPDATE certs SET certificate = :certificate, key = :key, revoked = :revoked, expiry_time = :expiry_time WHERE serial_number = :serial_number` + res, err := repo.db.NamedExecContext(ctx, q, cert) + if err != nil { + return handleError(certs.ErrUpdateEntity, err) + } + count, err := res.RowsAffected() + if err != nil { + return errors.Wrap(certs.ErrUpdateEntity, err) + } + if count == 0 { + return certs.ErrNotFound + } + return nil +} + +func (repo CSRRepo) RetrieveCSR(ctx context.Context,id string) (certs.CSR, error) { + q := `SELECT serial_number, certificate, key, entity_id, revoked, expiry_time FROM certs WHERE serial_number = $1` + var csr certs.CSR + if err := repo.db.QueryRowxContext(ctx, q, id).StructScan(&csr); err != nil { + if err == sql.ErrNoRows { + return certs.CSR{}, errors.Wrap(certs.ErrNotFound, err) + } + return certs.CSR{}, errors.Wrap(certs.ErrViewEntity, err) + } + return csr, nil +} + +func (repo CSRRepo) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { + q := `SELECT serial_number, status, submitted_at, processed_at, entity_id FROM certs %s LIMIT :limit OFFSET :offset` + var condition string + if pm.EntityID != "" { + condition = `WHERE entity_id = :entity_id` + } else { + condition = `` + } + q = fmt.Sprintf(q, condition) + var csrs []certs.CSR + + params := map[string]interface{}{ + "limit": pm.Limit, + "offset": pm.Offset, + "entity_id": pm.EntityID, + } + rows, err := repo.db.NamedQueryContext(ctx, q, params) + if err != nil { + return certs.CSRPage{}, handleError(certs.ErrViewEntity, err) + } + defer rows.Close() + + for rows.Next() { + csr := &certs.CSR{} + if err := rows.StructScan(csr); err != nil { + return certs.CSRPage{}, errors.Wrap(certs.ErrViewEntity, err) + } + + csrs = append(csrs, *csr) + } + + q = fmt.Sprintf(`SELECT COUNT(*) FROM certs %s LIMIT :limit OFFSET :offset`, condition) + pm.Total, err = repo.total(ctx, q, params) + if err != nil { + return certs.CSRPage{}, errors.Wrap(certs.ErrViewEntity, err) + } + return certs.CSRPage{ + PageMetadata: pm, + CSRs: csrs, + }, nil +} + +func (repo CSRRepo) total(ctx context.Context, query string, params interface{}) (uint64, error) { + rows, err := repo.db.NamedQueryContext(ctx, query, params) + if err != nil { + return 0, err + } + defer rows.Close() + total := uint64(0) + if rows.Next() { + if err := rows.Scan(&total); err != nil { + return 0, err + } + } + return total, nil +} + +func handleError(wrapper, err error) error { + pqErr, ok := err.(*pgconn.PgError) + if ok { + switch pqErr.Code { + case errDuplicate: + return errors.Wrap(ErrConflict, err) + case errInvalid, errInvalidChar, errTruncation, errUntranslatable: + return errors.Wrap(ErrMalformedEntity, err) + case errFK: + return errors.Wrap(ErrCreateEntity, err) + } + } + + return errors.Wrap(wrapper, err) +} diff --git a/postgres/csr/init.go b/postgres/csr/init.go index 9f52b7a..ab5f113 100644 --- a/postgres/csr/init.go +++ b/postgres/csr/init.go @@ -12,21 +12,21 @@ func Migration() *migrate.MemoryMigrationSource { return &migrate.MemoryMigrationSource{ Migrations: []*migrate.Migration{ { - Id: "certs_1", + Id: "csr_1", Up: []string{ - `CREATE TABLE IF NOT EXISTS certs ( - serial_number VARCHAR(40) UNIQUE NOT NULL, - certificate TEXT, - key TEXT, - revoked BOOLEAN, - expiry_time TIMESTAMP, - entity_id VARCHAR(36), - type TEXT CHECK (type IN ('RootCA', 'IntermediateCA', 'ClientCert')), - PRIMARY KEY (serial_number) + `CREATE TABLE IF NOT EXISTS csr ( + serial_number VARCHAR(40), + csr TEXT, + private_key TEXT, + entity_id VARCHAR(36), + status BOOLEAN, + submitted_at TIMESTAMP, + processed_at TIMESTAMP, + PRIMARY KEY (entity_id) )`, }, Down: []string{ - "DROP TABLE certs", + "DROP TABLE csr", }, }, }, diff --git a/service.go b/service.go index c1dc667..91b81c2 100644 --- a/service.go +++ b/service.go @@ -124,16 +124,18 @@ type SubjectOptions struct { type service struct { repo Repository + csrRepo CSRRepository rootCA *CA intermediateCA *CA } var _ Service = (*service)(nil) -func NewService(ctx context.Context, repo Repository, config *Config) (Service, error) { +func NewService(ctx context.Context, repo Repository, csrRepo CSRRepository, config *Config) (Service, error) { var svc service svc.repo = repo + svc.csrRepo = csrRepo if err := svc.loadCACerts(ctx); err != nil { return &svc, err } @@ -159,12 +161,17 @@ func NewService(ctx context.Context, repo Repository, config *Config) (Service, // using the provided template and the generated private key. // The certificate is then stored in the repository using the CreateCert method. // If the root CA is not found, it returns an error. -func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions) (Certificate, error) { - privKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) - if err != nil { - return Certificate{}, err +func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key ...*rsa.PrivateKey) (Certificate, error) { + var err error + privKey := &rsa.PrivateKey{} + if len(key) == 0 { + privKey, err = rsa.GenerateKey(rand.Reader, PrivateKeyBytes) + if err != nil { + return Certificate{}, err + } else { + privKey = key[0] + } } - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { return Certificate{}, err @@ -469,6 +476,115 @@ func (s *service) GetChainCA(ctx context.Context, token string) (Certificate, er return s.getConcatCAs(ctx) } +func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, entityID string) (CSR, error) { + privKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) + if err != nil { + return CSR{}, errors.Wrap(ErrCreateEntity, err) + } + + template := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: metadata.CommonName, + Organization: metadata.Organization, + OrganizationalUnit: metadata.OrganizationalUnit, + Country: metadata.Country, + Province: metadata.Province, + Locality: metadata.Locality, + StreetAddress: metadata.StreetAddress, + PostalCode: metadata.PostalCode, + }, + DNSNames: metadata.DNSNames, + } + + for _, ip := range metadata.IPAddresses { + parsedIP := net.ParseIP(ip) + if parsedIP != nil { + template.IPAddresses = append(template.IPAddresses, parsedIP) + } + } + + csrBytes, err := x509.CreateCertificateRequest(rand.Reader, template, privKey) + if err != nil { + return CSR{}, errors.Wrap(ErrCreateEntity, err) + } + + csrPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: csrBytes, + }) + + privKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privKey), + }) + + csr := CSR{ + CSR: csrPEM, + PrivateKey: privKeyPEM, + EntityID: entityID, + Status: "pending", + SubmittedAt: time.Now(), + } + + if err := s.csrRepo.CreateCSR(ctx, csr); err != nil { + return CSR{}, errors.Wrap(ErrCreateEntity, err) + } + + return csr, nil +} + +func (s *service) ProcessCSR(ctx context.Context, csrID string, approve bool) error { + csr, err := s.csrRepo.RetrieveCSR(ctx, csrID) + if err != nil { + return errors.Wrap(ErrViewEntity, err) + } + + if !approve { + csr.Status = "rejected" + csr.ProcessedAt = time.Now() + return s.csrRepo.UpdateCSR(ctx, csr) + } + + block, _ := pem.Decode(csr.CSR) + if block == nil { + return errors.New("failed to parse CSR PEM") + } + + parsedCSR, err := x509.ParseCertificateRequest(block.Bytes) + if err != nil { + return errors.Wrap(ErrMalformedEntity, err) + } + + if err := parsedCSR.CheckSignature(); err != nil { + return errors.Wrap(ErrMalformedEntity, err) + } + + privKey, err := extractPrivateKey(csr.PrivateKey) + if err != nil { + return errors.Wrap(ErrMalformedEntity, err) + } + + cert, err := s.IssueCert(ctx, csr.EntityID, "", nil, SubjectOptions{ + CommonName: parsedCSR.Subject.CommonName, + Organization: parsedCSR.Subject.Organization, + OrganizationalUnit: parsedCSR.Subject.OrganizationalUnit, + Country: parsedCSR.Subject.Country, + Province: parsedCSR.Subject.Province, + Locality: parsedCSR.Subject.Locality, + StreetAddress: parsedCSR.Subject.StreetAddress, + PostalCode: parsedCSR.Subject.PostalCode, + }, privKey) + if err != nil { + return errors.Wrap(ErrCreateEntity, err) + } + + csr.Status = "approved" + csr.ProcessedAt = time.Now() + csr.SerialNumber = cert.SerialNumber + + return s.csrRepo.UpdateCSR(ctx, csr) +} + func (s *service) getConcatCAs(ctx context.Context) (Certificate, error) { intermediateCert, err := s.repo.RetrieveCert(ctx, s.intermediateCA.SerialNumber) if err != nil { @@ -791,3 +907,17 @@ func (s *service) loadCACerts(ctx context.Context) error { } return nil } + +func extractPrivateKey(pemKey []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(pemKey) + if block == nil { + return nil, errors.New("failed to parse private key PEM") + } + + privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + + return privKey, nil +}