diff --git a/api/http/endpoint.go b/api/http/endpoint.go index 2f43f13..16010c7 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -309,3 +309,76 @@ func viewCAEndpoint(svc certs.Service) endpoint.Endpoint { }, nil } } + +func createCSREndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(createCSRReq) + if err := req.validate(); err != nil { + return createCSRRes{created: false}, err + } + + csr, err := svc.CreateCSR(ctx, req.Metadata, req.Metadata.EntityID, req.privKey) + if err != nil { + return createCSRRes{created: false}, err + } + + return createCSRRes{ + created: true, + CSR: csr, + }, nil + } +} + +func signCSREndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(SignCSRReq) + if err := req.validate(); err != nil { + return signCSRRes{signed: false}, err + } + + err = svc.SignCSR(ctx, req.csrID, req.approve) + if err != nil { + return signCSRRes{signed: false}, err + } + + return signCSRRes{ + signed: true, + }, nil + } +} + +func retrieveCSREndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(retrieveCSRReq) + if err := req.validate(); err != nil { + return retrieveCSRRes{}, err + } + + csr, err := svc.RetrieveCSR(ctx, req.csrID) + if err != nil { + return retrieveCSRRes{}, err + } + + return retrieveCSRRes{ + CSR: csr, + }, nil + } +} + +func listCSRsEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(listCSRsReq) + if err := req.validate(); err != nil { + return listCSRsRes{}, err + } + + cp, err := svc.ListCSRs(ctx, req.pm) + if err != nil { + return listCSRsRes{}, err + } + + return listCSRsRes{ + cp, + }, nil + } +} diff --git a/api/http/errors.go b/api/http/errors.go index bdcff3d..dcbfb14 100644 --- a/api/http/errors.go +++ b/api/http/errors.go @@ -33,6 +33,6 @@ var ( // ErrMissingCN indicates missing common name. ErrMissingCN = errors.New("missing common name") - // ErrEmptyEntityID indicates that the entity id is empty. - ErrEmptyEntityID = errors.New("missing entity id") + // ErrMissingStatus indicates missing status. + ErrMissingStatus = errors.New("missing status") ) diff --git a/api/http/requests.go b/api/http/requests.go index 4031b30..2fbb7de 100644 --- a/api/http/requests.go +++ b/api/http/requests.go @@ -4,6 +4,8 @@ package http import ( + "crypto/rsa" + "github.com/absmach/certs" "github.com/absmach/certs/errors" "golang.org/x/crypto/ocsp" @@ -38,7 +40,7 @@ type deleteReq struct { func (req deleteReq) validate() error { if req.entityID == "" { - return errors.Wrap(certs.ErrMalformedEntity, ErrEmptyEntityID) + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) } return nil } @@ -87,3 +89,51 @@ func (req ocspReq) validate() error { } return nil } + +type createCSRReq struct { + Metadata certs.CSRMetadata `json:"metadata"` + PrivateKey []byte `json:"private_Key"` + privKey *rsa.PrivateKey +} + +func (req createCSRReq) validate() error { + if req.Metadata.EntityID == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) + } + return nil +} + +type SignCSRReq struct { + csrID string + approve bool +} + +func (req SignCSRReq) validate() error { + if req.csrID == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) + } + + return nil +} + +type listCSRsReq struct { + pm certs.PageMetadata +} + +func (req listCSRsReq) validate() error { + if req.pm.Status.String() == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingStatus) + } + return nil +} + +type retrieveCSRReq struct { + csrID string +} + +func (req retrieveCSRReq) validate() error { + if req.csrID == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) + } + return nil +} diff --git a/api/http/responses.go b/api/http/responses.go index f53ce6c..d24709f 100644 --- a/api/http/responses.go +++ b/api/http/responses.go @@ -9,6 +9,7 @@ import ( "net/http" "time" + "github.com/absmach/certs" "golang.org/x/crypto/ocsp" ) @@ -142,7 +143,7 @@ func (res listCertsRes) Empty() bool { type viewCertRes struct { SerialNumber string `json:"serial_number,omitempty"` Certificate string `json:"certificate,omitempty"` - Key string `json:"key,omitempty,omitempty"` + Key string `json:"key,omitempty"` Revoked bool `json:"revoked,omitempty"` ExpiryTime time.Time `json:"expiry_time,omitempty"` EntityID string `json:"entity_id,omitempty"` @@ -201,3 +202,72 @@ type fileDownloadRes struct { Filename string ContentType string } + +type createCSRRes struct { + certs.CSR + created bool +} + +func (res createCSRRes) Code() int { + if res.created { + return http.StatusCreated + } + + return http.StatusNoContent +} + +func (res createCSRRes) Headers() map[string]string { + return map[string]string{} +} + +func (res createCSRRes) Empty() bool { + return false +} + +type signCSRRes struct { + signed bool +} + +func (res signCSRRes) Code() int { + return http.StatusOK +} + +func (res signCSRRes) Headers() map[string]string { + return map[string]string{} +} + +func (res signCSRRes) Empty() bool { + return true +} + +type listCSRsRes struct { + certs.CSRPage +} + +func (res listCSRsRes) Code() int { + return http.StatusOK +} + +func (res listCSRsRes) Headers() map[string]string { + return map[string]string{} +} + +func (res listCSRsRes) Empty() bool { + return false +} + +type retrieveCSRRes struct { + certs.CSR +} + +func (res retrieveCSRRes) Code() int { + return http.StatusOK +} + +func (res retrieveCSRRes) Headers() map[string]string { + return map[string]string{} +} + +func (res retrieveCSRRes) Empty() bool { + return false +} diff --git a/api/http/transport.go b/api/http/transport.go index 1d0df55..87e22be 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -7,8 +7,10 @@ import ( "archive/zip" "bytes" "context" + "crypto/x509" "encoding/asn1" "encoding/json" + "encoding/pem" "fmt" "io" "log/slog" @@ -18,7 +20,7 @@ import ( "github.com/absmach/certs" "github.com/absmach/certs/errors" - "github.com/go-chi/chi" + "github.com/go-chi/chi/v5" kithttp "github.com/go-kit/kit/transport/http" "github.com/prometheus/client_golang/prometheus/promhttp" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" @@ -32,6 +34,8 @@ const ( limitKey = "limit" entityKey = "entity_id" commonName = "common_name" + approve = "approve" + status = "status" token = "token" ocspStatusParam = "force_status" entityIDParam = "entityID" @@ -137,6 +141,32 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http encodeCADownloadResponse, opts..., ), "download_ca").ServeHTTP) + r.Route("/csrs", func(r chi.Router) { + r.Post("/{entityID}", otelhttp.NewHandler(kithttp.NewServer( + createCSREndpoint(svc), + decodeCreateCSR, + EncodeResponse, + opts..., + ), "create_csr").ServeHTTP) + r.Patch("/{id}", otelhttp.NewHandler(kithttp.NewServer( + signCSREndpoint(svc), + decodeUpdateCSR, + EncodeResponse, + opts..., + ), "sign_csr").ServeHTTP) + r.Get("/{id}", otelhttp.NewHandler(kithttp.NewServer( + retrieveCSREndpoint(svc), + decodeRetrieveCSR, + EncodeResponse, + opts..., + ), "view_csr").ServeHTTP) + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + listCSRsEndpoint(svc), + decodeListCSR, + EncodeResponse, + opts..., + ), "list_csrs").ServeHTTP) + }) }) r.Get("/health", certs.Health("certs", instanceID)) @@ -261,6 +291,85 @@ func decodeListCerts(_ context.Context, r *http.Request) (interface{}, error) { return req, nil } +func decodeCreateCSR(_ context.Context, r *http.Request) (interface{}, error) { + req := createCSRReq{} + req.Metadata.EntityID = chi.URLParam(r, "entityID") + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, err + } + + if len(req.PrivateKey) > 0 { + block, _ := pem.Decode(req.PrivateKey) + if block != nil { + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, errors.Wrap(ErrInvalidRequest, err) + } + req.privKey = privateKey + } + } + + return req, nil +} + +func decodeUpdateCSR(_ context.Context, r *http.Request) (interface{}, error) { + app, err := readBoolQuery(r, approve, false) + if err != nil { + return nil, err + } + + req := SignCSRReq{ + csrID: chi.URLParam(r, "id"), + approve: app, + } + + return req, nil +} + +func decodeRetrieveCSR(_ context.Context, r *http.Request) (interface{}, error) { + req := retrieveCSRReq{ + csrID: chi.URLParam(r, "id"), + } + + return req, nil +} + +func decodeListCSR(_ context.Context, r *http.Request) (interface{}, error) { + o, err := readNumQuery(r, offsetKey, defOffset) + if err != nil { + return nil, err + } + + l, err := readNumQuery(r, limitKey, defLimit) + if err != nil { + return nil, err + } + + s, err := readStringQuery(r, status, "") + if err != nil { + return nil, err + } + e, err := readStringQuery(r, entityKey, "") + if err != nil { + return nil, err + } + + stat, err := certs.ParseCSRStatus(strings.ToLower(s)) + if err != nil { + return nil, err + } + + req := listCSRsReq{ + pm: certs.PageMetadata{ + Offset: o, + Limit: l, + EntityID: e, + Status: stat, + }, + } + return req, nil +} + // EncodeResponse encodes successful response. func EncodeResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { if ar, ok := response.(Response); ok { @@ -430,3 +539,21 @@ func readNumQuery(r *http.Request, key string, def uint64) (uint64, error) { } return v, nil } + +func readBoolQuery(r *http.Request, key string, def bool) (bool, error) { + vals := r.URL.Query()[key] + if len(vals) > 1 { + return false, ErrInvalidQueryParams + } + + if len(vals) == 0 { + return def, nil + } + + b, err := strconv.ParseBool(vals[0]) + if err != nil { + return false, errors.Wrap(ErrInvalidQueryParams, err) + } + + return b, nil +} diff --git a/api/logging.go b/api/logging.go index 731784d..80144f6 100644 --- a/api/logging.go +++ b/api/logging.go @@ -5,6 +5,7 @@ package api import ( "context" + "crypto/rsa" "crypto/x509" "fmt" "log/slog" @@ -85,7 +86,7 @@ func (lm *loggingMiddleware) RetrieveCAToken(ctx context.Context) (tokenString s return lm.svc.RetrieveCAToken(ctx) } -func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (cert certs.Certificate, err error) { +func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...*rsa.PrivateKey) (cert certs.Certificate, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method issue_cert for took %s to complete", time.Since(begin)) if err != nil { @@ -94,7 +95,7 @@ func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string } lm.logger.Info(message) }(time.Now()) - return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options) + return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey...) } func (lm *loggingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (cp certs.CertificatePage, err error) { @@ -180,3 +181,51 @@ func (lm *loggingMiddleware) GetChainCA(ctx context.Context, token string) (cert }(time.Now()) return lm.svc.GetChainCA(ctx, token) } + +func (lm *loggingMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetadata, entityID string, key ...*rsa.PrivateKey) (csr certs.CSR, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method create_csr took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.CreateCSR(ctx, meta, entityID, key...) +} + +func (lm *loggingMiddleware) SignCSR(ctx context.Context, csrID string, approve bool) (err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method sign_csr took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.SignCSR(ctx, csrID, approve) +} + +func (lm *loggingMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (cp certs.CSRPage, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method list_csrs took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.ListCSRs(ctx, pm) +} + +func (lm *loggingMiddleware) RetrieveCSR(ctx context.Context, csrID string) (csr certs.CSR, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method retrieve_csr took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.RetrieveCSR(ctx, csrID) +} diff --git a/api/metrics.go b/api/metrics.go index a5506c3..b0c6f4e 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -5,6 +5,7 @@ package api import ( "context" + "crypto/rsa" "crypto/x509" "time" @@ -71,12 +72,12 @@ func (mm *metricsMiddleware) RetrieveCAToken(ctx context.Context) (string, error return mm.svc.RetrieveCAToken(ctx) } -func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { +func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...*rsa.PrivateKey) (certs.Certificate, error) { defer func(begin time.Time) { mm.counter.With("method", "issue_certificate").Add(1) mm.latency.With("method", "issue_certificate").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options) + return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey...) } func (mm *metricsMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) { @@ -135,3 +136,35 @@ func (mm *metricsMiddleware) GetChainCA(ctx context.Context, token string) (cert }(time.Now()) return mm.svc.GetChainCA(ctx, token) } + +func (mm *metricsMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetadata, entityID string, key ...*rsa.PrivateKey) (certs.CSR, error) { + defer func(begin time.Time) { + mm.counter.With("method", "create_csr").Add(1) + mm.latency.With("method", "create_csr").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.CreateCSR(ctx, meta, entityID, key...) +} + +func (mm *metricsMiddleware) SignCSR(ctx context.Context, csrID string, approve bool) error { + defer func(begin time.Time) { + mm.counter.With("method", "sign_csr").Add(1) + mm.latency.With("method", "sign_csr").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.SignCSR(ctx, csrID, approve) +} + +func (mm *metricsMiddleware) RetrieveCSR(ctx context.Context, csrID string) (certs.CSR, error) { + defer func(begin time.Time) { + mm.counter.With("method", "retrieve_csr").Add(1) + mm.latency.With("method", "retrieve_csr").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.RetrieveCSR(ctx, csrID) +} + +func (mm *metricsMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { + defer func(begin time.Time) { + mm.counter.With("method", "list_csrs").Add(1) + mm.latency.With("method", "list_csrs").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.ListCSRs(ctx, pm) +} diff --git a/certs.go b/certs.go index 7a88bee..6654c5b 100644 --- a/certs.go +++ b/certs.go @@ -5,10 +5,113 @@ package certs import ( "context" + "crypto/rsa" "crypto/x509" + "encoding/json" + "net" "time" + + "github.com/absmach/certs/errors" +) + +type CertType int + +const ( + RootCA CertType = iota + IntermediateCA + ClientCert +) + +const ( + Root = "RootCA" + Inter = "IntermediateCA" + Client = "ClientCert" + Unknown = "Unknown" +) + +func (c CertType) String() string { + switch c { + case RootCA: + return Root + case IntermediateCA: + return Inter + case ClientCert: + return Client + default: + return Unknown + } +} + +func CertTypeFromString(s string) (CertType, error) { + switch s { + case Root: + return RootCA, nil + case Inter: + return IntermediateCA, nil + case Client: + return ClientCert, nil + default: + return -1, errors.New("unknown cert type") + } +} + +type CSRStatus int + +const ( + Pending CSRStatus = iota + Signed + Rejected + All +) + +const ( + pending = "pending" + signed = "signed" + rejected = "rejected" + all = "all" ) +func (c CSRStatus) String() string { + switch c { + case Pending: + return pending + case Signed: + return signed + case Rejected: + return rejected + case All: + return all + default: + return Unknown + } +} + +func ParseCSRStatus(s string) (CSRStatus, error) { + switch s { + case pending: + return Pending, nil + case signed: + return Signed, nil + case rejected: + return Rejected, nil + case all: + return All, nil + default: + return -1, errors.New("unknown CSR status") + } +} + +func (c CSRStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(c.String()) +} + +type CA struct { + Type CertType + Certificate *x509.Certificate + PrivateKey *rsa.PrivateKey + SerialNumber string +} + type Certificate struct { SerialNumber string `db:"serial_number"` Certificate []byte `db:"certificate"` @@ -26,10 +129,67 @@ type CertificatePage struct { } type PageMetadata struct { - Total uint64 `json:"total,omitempty" db:"total"` - Offset uint64 `json:"offset,omitempty" db:"offset"` - Limit uint64 `json:"limit,omitempty" db:"limit"` - EntityID string `json:"entity_id,omitempty" db:"entity_id"` + Total uint64 `json:"total" db:"total"` + Offset uint64 `json:"offset,omitempty" db:"offset"` + Limit uint64 `json:"limit" db:"limit"` + EntityID string `json:"entity_id,omitempty" db:"entity_id"` + Status CSRStatus `json:"status,omitempty" db:"status"` +} + +type CSRMetadata struct { + EntityID string + CommonName string `json:"common_name"` + Organization []string `json:"organization"` + OrganizationalUnit []string `json:"organizational_unit"` + Country []string `json:"country"` + Province []string `json:"province"` + Locality []string `json:"locality"` + StreetAddress []string `json:"street_address"` + PostalCode []string `json:"postal_code"` + DNSNames []string `json:"dns_names"` + IPAddresses []string `json:"ip_addresses"` + EmailAddresses []string `json:"email_addresses"` +} + +type CSR struct { + ID string `json:"id" db:"id"` + CSR []byte `json:"csr,omitempty" db:"csr"` + PrivateKey []byte `json:"private_key,omitempty" db:"private_key"` + EntityID string `json:"entity_id" db:"entity_id"` + Status CSRStatus `json:"status" db:"status"` + SubmittedAt time.Time `json:"submitted_at" db:"submitted_at"` + SignedAt time.Time `json:"signed_at,omitempty" db:"signed_at"` + SerialNumber string `json:"serial_number,omitempty" db:"serial_number"` +} + +type CSRPage struct { + PageMetadata + CSRs []CSR `json:"csrs,omitempty"` +} + +type SubjectOptions struct { + CommonName string + Organization []string `json:"organization"` + OrganizationalUnit []string `json:"organizational_unit"` + Country []string `json:"country"` + Province []string `json:"province"` + Locality []string `json:"locality"` + StreetAddress []string `json:"street_address"` + PostalCode []string `json:"postal_code"` +} + +type Config struct { + CommonName string `yaml:"common_name"` + Organization []string `yaml:"organization"` + OrganizationalUnit []string `yaml:"organizational_unit"` + Country []string `yaml:"country"` + Province []string `yaml:"province"` + Locality []string `yaml:"locality"` + StreetAddress []string `yaml:"street_address"` + PostalCode []string `yaml:"postal_code"` + DNSNames []string `yaml:"dns_names"` + IPAddresses []net.IP `yaml:"ip_addresses"` + ValidityPeriod string `yaml:"validity_period"` } type Service interface { @@ -57,7 +217,7 @@ type Service interface { RetrieveCAToken(ctx context.Context) (string, error) // IssueCert issues a certificate from the database. - IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions) (Certificate, error) + IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions, privKey ...*rsa.PrivateKey) (Certificate, error) // OCSP retrieves the OCSP response for a certificate. OCSP(ctx context.Context, serialNumber string) (*Certificate, int, *x509.Certificate, error) @@ -73,6 +233,18 @@ type Service interface { // RemoveCert deletes a cert for a provided entityID. RemoveCert(ctx context.Context, entityId string) error + + // CreateCSR creates a new Certificate Signing Request + CreateCSR(ctx context.Context, metadata CSRMetadata, entityID string, privKey ...*rsa.PrivateKey) (CSR, error) + + // SignCSR processes a pending CSR and either approves or rejects it + SignCSR(ctx context.Context, csrID string, approve bool) error + + // RetrieveCSR retrieves a specific CSR by ID + RetrieveCSR(ctx context.Context, csrID string) (CSR, error) + + // ListCSRs returns a list of CSRs based on filter criteria + ListCSRs(ctx context.Context, pm PageMetadata) (CSRPage, error) } type Repository interface { @@ -97,3 +269,10 @@ type Repository interface { // RemoveCert deletes cert from database. RemoveCert(ctx context.Context, entityId string) error } + +type CSRRepository interface { + CreateCSR(context.Context, CSR) error + UpdateCSR(context.Context, CSR) error + RetrieveCSR(context.Context, string) (CSR, error) + ListCSRs(context.Context, PageMetadata) (CSRPage, error) +} diff --git a/certs_test.go b/certs_test.go index 28616e2..184773f 100644 --- a/certs_test.go +++ b/certs_test.go @@ -34,10 +34,12 @@ var ( func TestIssueCert(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -76,12 +78,14 @@ func TestIssueCert(t *testing.T) { func TestRevokeCert(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() invalidSerialNumber := "invalid serial number" repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -130,10 +134,12 @@ func TestRevokeCert(t *testing.T) { func TestGetCertDownloadToken(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -160,6 +166,8 @@ func TestGetCertDownloadToken(t *testing.T) { func TestGetCert(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{ExpiresAt: time.Now().Add(time.Minute * 5).Unix(), Issuer: certs.Organization, Subject: "certs"}) validToken, err := jwtToken.SignedString([]byte(serialNumber)) @@ -167,7 +175,7 @@ func TestGetCert(t *testing.T) { repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -211,6 +219,8 @@ func TestGetCert(t *testing.T) { func TestRenewCert(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() serialNumber := big.NewInt(1) expiredSerialNumber := big.NewInt(2) @@ -262,7 +272,7 @@ func TestRenewCert(t *testing.T) { repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -351,10 +361,12 @@ func TestRenewCert(t *testing.T) { func TestGetEntityID(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -382,10 +394,12 @@ func TestGetEntityID(t *testing.T) { func TestListCerts(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -419,6 +433,8 @@ func TestListCerts(t *testing.T) { func TestGenerateCRL(t *testing.T) { cRepo := new(mocks.MockRepository) + csrRepo := new(mocks.MockCSRRepository) + idProvider := mocks.NewMock() privateKey, _ := rsa.GenerateKey(rand.Reader, 2048) template := &x509.Certificate{ @@ -440,7 +456,7 @@ func TestGenerateCRL(t *testing.T) { {Type: certs.IntermediateCA, Certificate: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}), Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})}, }, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, &config) + svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() diff --git a/cli/certs.go b/cli/certs.go index c1bd3bf..2cf131b 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -5,8 +5,10 @@ package cli import ( "encoding/json" + "fmt" "os" + "github.com/absmach/certs/errors" ctxsdk "github.com/absmach/certs/sdk" "github.com/spf13/cobra" ) @@ -20,7 +22,7 @@ func SetSDK(s ctxsdk.SDK) { var cmdCerts = []cobra.Command{ { - Use: "get [all | ]", + Use: "get [all | ]", Short: "Get certificate", Long: `Gets a certificate for a given entity ID or all certificates.`, Run: func(cmd *cobra.Command, args []string) { @@ -238,6 +240,131 @@ var cmdCerts = []cobra.Command{ logJSONCmd(*cmd, token) }, }, + { + Use: "csr ", + Short: "Create CSR", + Long: `Creates a CSR.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) > 3 || len(args) == 0 { + logUsageCmd(*cmd, cmd.Use) + return + } + + var pm ctxsdk.PageMetadata + if err := json.Unmarshal([]byte(args[0]), &pm); err != nil { + logErrorCmd(*cmd, err) + return + } + + var csr ctxsdk.CSR + var err error + if len(args) == 1 { + csr, err = sdk.CreateCSR(pm, []byte{}) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, csr) + return + } + + data, err := os.ReadFile(args[1]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + csr, err = sdk.CreateCSR(pm, data) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + logJSONCmd(*cmd, csr) + }, + }, + { + Use: "sign ", + Short: "Sign CSR", + Long: `Signs a CSR for a given csr id.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 2 { + logUsageCmd(*cmd, cmd.Use) + return + } + var sign bool + switch args[1] { + case "true": + sign = true + case "false": + sign = false + default: + logErrorCmd(*cmd, errors.NewSDKError(fmt.Errorf("unknown type"))) + return + } + + err := sdk.SignCSR(args[0], sign) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logOKCmd(*cmd) + }, + }, + { + Use: "get-csr [all | ] ", + Short: "Get csr", + Long: `Gets CSRs for a given entity ID or all CSR.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 2 { + logUsageCmd(*cmd, cmd.Use) + return + } + if args[0] == "all" { + pm := ctxsdk.PageMetadata{ + Limit: Limit, + Offset: Offset, + Status: args[1], + } + page, err := sdk.ListCSRs(pm) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, page) + return + } + pm := ctxsdk.PageMetadata{ + EntityID: args[0], + Limit: Limit, + Offset: Offset, + Status: args[1], + } + page, err := sdk.ListCSRs(pm) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, page) + }, + }, + { + Use: "view-csr ", + Short: "View CSR", + Long: `Views a CSR for a given csr id.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 1 { + logUsageCmd(*cmd, cmd.Use) + return + } + cert, err := sdk.RetrieveCSR(args[0]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, cert) + }, + }, } // NewCertsCmd returns certificate command. diff --git a/cmd/certs/main.go b/cmd/certs/main.go index e596483..c174da5 100644 --- a/cmd/certs/main.go +++ b/cmd/certs/main.go @@ -24,7 +24,8 @@ import ( grpcserver "github.com/absmach/certs/internal/server/grpc" httpserver "github.com/absmach/certs/internal/server/http" "github.com/absmach/certs/internal/uuid" - cpostgres "github.com/absmach/certs/postgres" + cpostgres "github.com/absmach/certs/postgres/certs" + csrpostgres "github.com/absmach/certs/postgres/csr" "github.com/absmach/certs/tracing" "github.com/caarlos0/env/v10" "github.com/jmoiron/sqlx" @@ -78,7 +79,10 @@ func main() { if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefix}); err != nil { logger.Error(err.Error()) } - db, err := pgClient.Setup(dbConfig, *cpostgres.Migration()) + cm := cpostgres.Migration() + sm := csrpostgres.Migration() + cm.Migrations = append(cm.Migrations, sm.Migrations...) + db, err := pgClient.Setup(dbConfig, *cm) if err != nil { log.Fatalf(fmt.Sprintf("Failed to connect to %s database: %s", svcName, err)) } @@ -146,7 +150,9 @@ func main() { func newService(ctx context.Context, db *sqlx.DB, tracer trace.Tracer, logger *slog.Logger, dbConfig pgClient.Config, config *certs.Config) (certs.Service, error) { database := postgres.NewDatabase(db, dbConfig, tracer) repo := cpostgres.NewRepository(database) - svc, err := certs.NewService(ctx, repo, config) + csrRepo := csrpostgres.NewRepository(database) + idp := uuid.New() + svc, err := certs.NewService(ctx, repo, csrRepo, config, idp) if err != nil { return nil, err } diff --git a/go.mod b/go.mod index a579d5c..478cc3b 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.23.0 require ( github.com/caarlos0/env/v10 v10.0.0 github.com/fatih/color v1.18.0 - github.com/go-chi/chi v4.1.2+incompatible + github.com/go-chi/chi/v5 v5.1.0 github.com/go-kit/kit v0.13.0 github.com/gofrs/uuid v4.4.0+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible diff --git a/go.sum b/go.sum index 393756d..8edef24 100644 --- a/go.sum +++ b/go.sum @@ -39,8 +39,8 @@ github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/go-chi/chi v4.1.2+incompatible h1:fGFk2Gmi/YKXk0OmGfBh0WgmN3XB8lVnEyNz34tQRec= -github.com/go-chi/chi v4.1.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= +github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= +github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-gorp/gorp/v3 v3.1.0 h1:ItKF/Vbuj31dmV4jxA1qblpSwkl9g1typ24xoe70IGs= github.com/go-gorp/gorp/v3 v3.1.0/go.mod h1:dLEjIyyRNiXvNZ8PSmzpt1GsWAUK8kjVhEpjH8TixEw= github.com/go-kit/kit v0.13.0 h1:OoneCcHKHQ03LfBpoQCUfCluwd2Vt3ohz+kvbJneZAU= diff --git a/mockery.yaml b/mockery.yaml index 8db5b4a..dfbf09a 100644 --- a/mockery.yaml +++ b/mockery.yaml @@ -16,6 +16,10 @@ packages: config: dir: "{{.InterfaceDir}}/mocks" filename: "repository.go" + CSRRepository: + config: + dir: "{{.InterfaceDir}}/mocks" + filename: "csr.go" github.com/absmach/certs/sdk: interfaces: SDK: diff --git a/mocks/csr.go b/mocks/csr.go new file mode 100644 index 0000000..79340c1 --- /dev/null +++ b/mocks/csr.go @@ -0,0 +1,249 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + certs "github.com/absmach/certs" + + mock "github.com/stretchr/testify/mock" +) + +// MockCSRRepository is an autogenerated mock type for the CSRRepository type +type MockCSRRepository struct { + mock.Mock +} + +type MockCSRRepository_Expecter struct { + mock *mock.Mock +} + +func (_m *MockCSRRepository) EXPECT() *MockCSRRepository_Expecter { + return &MockCSRRepository_Expecter{mock: &_m.Mock} +} + +// CreateCSR provides a mock function with given fields: _a0, _a1 +func (_m *MockCSRRepository) CreateCSR(_a0 context.Context, _a1 certs.CSR) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for CreateCSR") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, certs.CSR) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockCSRRepository_CreateCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCSR' +type MockCSRRepository_CreateCSR_Call struct { + *mock.Call +} + +// CreateCSR is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 certs.CSR +func (_e *MockCSRRepository_Expecter) CreateCSR(_a0 interface{}, _a1 interface{}) *MockCSRRepository_CreateCSR_Call { + return &MockCSRRepository_CreateCSR_Call{Call: _e.mock.On("CreateCSR", _a0, _a1)} +} + +func (_c *MockCSRRepository_CreateCSR_Call) Run(run func(_a0 context.Context, _a1 certs.CSR)) *MockCSRRepository_CreateCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(certs.CSR)) + }) + return _c +} + +func (_c *MockCSRRepository_CreateCSR_Call) Return(_a0 error) *MockCSRRepository_CreateCSR_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCSRRepository_CreateCSR_Call) RunAndReturn(run func(context.Context, certs.CSR) error) *MockCSRRepository_CreateCSR_Call { + _c.Call.Return(run) + return _c +} + +// ListCSRs provides a mock function with given fields: _a0, _a1 +func (_m *MockCSRRepository) ListCSRs(_a0 context.Context, _a1 certs.PageMetadata) (certs.CSRPage, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for ListCSRs") + } + + var r0 certs.CSRPage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, certs.PageMetadata) (certs.CSRPage, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, certs.PageMetadata) certs.CSRPage); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Get(0).(certs.CSRPage) + } + + if rf, ok := ret.Get(1).(func(context.Context, certs.PageMetadata) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCSRRepository_ListCSRs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCSRs' +type MockCSRRepository_ListCSRs_Call struct { + *mock.Call +} + +// ListCSRs is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 certs.PageMetadata +func (_e *MockCSRRepository_Expecter) ListCSRs(_a0 interface{}, _a1 interface{}) *MockCSRRepository_ListCSRs_Call { + return &MockCSRRepository_ListCSRs_Call{Call: _e.mock.On("ListCSRs", _a0, _a1)} +} + +func (_c *MockCSRRepository_ListCSRs_Call) Run(run func(_a0 context.Context, _a1 certs.PageMetadata)) *MockCSRRepository_ListCSRs_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(certs.PageMetadata)) + }) + return _c +} + +func (_c *MockCSRRepository_ListCSRs_Call) Return(_a0 certs.CSRPage, _a1 error) *MockCSRRepository_ListCSRs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCSRRepository_ListCSRs_Call) RunAndReturn(run func(context.Context, certs.PageMetadata) (certs.CSRPage, error)) *MockCSRRepository_ListCSRs_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveCSR provides a mock function with given fields: _a0, _a1 +func (_m *MockCSRRepository) RetrieveCSR(_a0 context.Context, _a1 string) (certs.CSR, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for RetrieveCSR") + } + + var r0 certs.CSR + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (certs.CSR, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, string) certs.CSR); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Get(0).(certs.CSR) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCSRRepository_RetrieveCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveCSR' +type MockCSRRepository_RetrieveCSR_Call struct { + *mock.Call +} + +// RetrieveCSR is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 string +func (_e *MockCSRRepository_Expecter) RetrieveCSR(_a0 interface{}, _a1 interface{}) *MockCSRRepository_RetrieveCSR_Call { + return &MockCSRRepository_RetrieveCSR_Call{Call: _e.mock.On("RetrieveCSR", _a0, _a1)} +} + +func (_c *MockCSRRepository_RetrieveCSR_Call) Run(run func(_a0 context.Context, _a1 string)) *MockCSRRepository_RetrieveCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockCSRRepository_RetrieveCSR_Call) Return(_a0 certs.CSR, _a1 error) *MockCSRRepository_RetrieveCSR_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCSRRepository_RetrieveCSR_Call) RunAndReturn(run func(context.Context, string) (certs.CSR, error)) *MockCSRRepository_RetrieveCSR_Call { + _c.Call.Return(run) + return _c +} + +// UpdateCSR provides a mock function with given fields: _a0, _a1 +func (_m *MockCSRRepository) UpdateCSR(_a0 context.Context, _a1 certs.CSR) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for UpdateCSR") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, certs.CSR) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockCSRRepository_UpdateCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCSR' +type MockCSRRepository_UpdateCSR_Call struct { + *mock.Call +} + +// UpdateCSR is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 certs.CSR +func (_e *MockCSRRepository_Expecter) UpdateCSR(_a0 interface{}, _a1 interface{}) *MockCSRRepository_UpdateCSR_Call { + return &MockCSRRepository_UpdateCSR_Call{Call: _e.mock.On("UpdateCSR", _a0, _a1)} +} + +func (_c *MockCSRRepository_UpdateCSR_Call) Run(run func(_a0 context.Context, _a1 certs.CSR)) *MockCSRRepository_UpdateCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(certs.CSR)) + }) + return _c +} + +func (_c *MockCSRRepository_UpdateCSR_Call) Return(_a0 error) *MockCSRRepository_UpdateCSR_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCSRRepository_UpdateCSR_Call) RunAndReturn(run func(context.Context, certs.CSR) error) *MockCSRRepository_UpdateCSR_Call { + _c.Call.Return(run) + return _c +} + +// NewMockCSRRepository creates a new instance of MockCSRRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockCSRRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *MockCSRRepository { + mock := &MockCSRRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/mocks/service.go b/mocks/service.go index aa91db8..b7f2c1f 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -12,6 +12,8 @@ import ( mock "github.com/stretchr/testify/mock" + rsa "crypto/rsa" + x509 "crypto/x509" ) @@ -28,6 +30,79 @@ func (_m *MockService) EXPECT() *MockService_Expecter { return &MockService_Expecter{mock: &_m.Mock} } +// CreateCSR provides a mock function with given fields: ctx, metadata, entityID, privKey +func (_m *MockService) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, entityID string, privKey ...*rsa.PrivateKey) (certs.CSR, error) { + _va := make([]interface{}, len(privKey)) + for _i := range privKey { + _va[_i] = privKey[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, metadata, entityID) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for CreateCSR") + } + + var r0 certs.CSR + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, string, ...*rsa.PrivateKey) (certs.CSR, error)); ok { + return rf(ctx, metadata, entityID, privKey...) + } + if rf, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, string, ...*rsa.PrivateKey) certs.CSR); ok { + r0 = rf(ctx, metadata, entityID, privKey...) + } else { + r0 = ret.Get(0).(certs.CSR) + } + + if rf, ok := ret.Get(1).(func(context.Context, certs.CSRMetadata, string, ...*rsa.PrivateKey) error); ok { + r1 = rf(ctx, metadata, entityID, privKey...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockService_CreateCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCSR' +type MockService_CreateCSR_Call struct { + *mock.Call +} + +// CreateCSR is a helper method to define mock.On call +// - ctx context.Context +// - metadata certs.CSRMetadata +// - entityID string +// - privKey ...*rsa.PrivateKey +func (_e *MockService_Expecter) CreateCSR(ctx interface{}, metadata interface{}, entityID interface{}, privKey ...interface{}) *MockService_CreateCSR_Call { + return &MockService_CreateCSR_Call{Call: _e.mock.On("CreateCSR", + append([]interface{}{ctx, metadata, entityID}, privKey...)...)} +} + +func (_c *MockService_CreateCSR_Call) Run(run func(ctx context.Context, metadata certs.CSRMetadata, entityID string, privKey ...*rsa.PrivateKey)) *MockService_CreateCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]*rsa.PrivateKey, len(args)-3) + for i, a := range args[3:] { + if a != nil { + variadicArgs[i] = a.(*rsa.PrivateKey) + } + } + run(args[0].(context.Context), args[1].(certs.CSRMetadata), args[2].(string), variadicArgs...) + }) + return _c +} + +func (_c *MockService_CreateCSR_Call) Return(_a0 certs.CSR, _a1 error) *MockService_CreateCSR_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockService_CreateCSR_Call) RunAndReturn(run func(context.Context, certs.CSRMetadata, string, ...*rsa.PrivateKey) (certs.CSR, error)) *MockService_CreateCSR_Call { + _c.Call.Return(run) + return _c +} + // GenerateCRL provides a mock function with given fields: ctx, caType func (_m *MockService) GenerateCRL(ctx context.Context, caType certs.CertType) ([]byte, error) { ret := _m.Called(ctx, caType) @@ -201,9 +276,16 @@ func (_c *MockService_GetEntityID_Call) RunAndReturn(run func(context.Context, s return _c } -// IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs, option -func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions) (certs.Certificate, error) { - ret := _m.Called(ctx, entityID, ttl, ipAddrs, option) +// IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs, option, privKey +func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...*rsa.PrivateKey) (certs.Certificate, error) { + _va := make([]interface{}, len(privKey)) + for _i := range privKey { + _va[_i] = privKey[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, entityID, ttl, ipAddrs, option) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) if len(ret) == 0 { panic("no return value specified for IssueCert") @@ -211,17 +293,17 @@ func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl strin var r0 certs.Certificate var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions) (certs.Certificate, error)); ok { - return rf(ctx, entityID, ttl, ipAddrs, option) + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) (certs.Certificate, error)); ok { + return rf(ctx, entityID, ttl, ipAddrs, option, privKey...) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions) certs.Certificate); ok { - r0 = rf(ctx, entityID, ttl, ipAddrs, option) + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) certs.Certificate); ok { + r0 = rf(ctx, entityID, ttl, ipAddrs, option, privKey...) } else { r0 = ret.Get(0).(certs.Certificate) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions) error); ok { - r1 = rf(ctx, entityID, ttl, ipAddrs, option) + if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) error); ok { + r1 = rf(ctx, entityID, ttl, ipAddrs, option, privKey...) } else { r1 = ret.Error(1) } @@ -240,13 +322,21 @@ type MockService_IssueCert_Call struct { // - ttl string // - ipAddrs []string // - option certs.SubjectOptions -func (_e *MockService_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, option interface{}) *MockService_IssueCert_Call { - return &MockService_IssueCert_Call{Call: _e.mock.On("IssueCert", ctx, entityID, ttl, ipAddrs, option)} +// - privKey ...*rsa.PrivateKey +func (_e *MockService_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, option interface{}, privKey ...interface{}) *MockService_IssueCert_Call { + return &MockService_IssueCert_Call{Call: _e.mock.On("IssueCert", + append([]interface{}{ctx, entityID, ttl, ipAddrs, option}, privKey...)...)} } -func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...*rsa.PrivateKey)) *MockService_IssueCert_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]string), args[4].(certs.SubjectOptions)) + variadicArgs := make([]*rsa.PrivateKey, len(args)-5) + for i, a := range args[5:] { + if a != nil { + variadicArgs[i] = a.(*rsa.PrivateKey) + } + } + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]string), args[4].(certs.SubjectOptions), variadicArgs...) }) return _c } @@ -256,7 +346,64 @@ func (_c *MockService_IssueCert_Call) Return(_a0 certs.Certificate, _a1 error) * return _c } -func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions) (certs.Certificate, error)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) (certs.Certificate, error)) *MockService_IssueCert_Call { + _c.Call.Return(run) + return _c +} + +// ListCSRs provides a mock function with given fields: ctx, pm +func (_m *MockService) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { + ret := _m.Called(ctx, pm) + + if len(ret) == 0 { + panic("no return value specified for ListCSRs") + } + + var r0 certs.CSRPage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, certs.PageMetadata) (certs.CSRPage, error)); ok { + return rf(ctx, pm) + } + if rf, ok := ret.Get(0).(func(context.Context, certs.PageMetadata) certs.CSRPage); ok { + r0 = rf(ctx, pm) + } else { + r0 = ret.Get(0).(certs.CSRPage) + } + + if rf, ok := ret.Get(1).(func(context.Context, certs.PageMetadata) error); ok { + r1 = rf(ctx, pm) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockService_ListCSRs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCSRs' +type MockService_ListCSRs_Call struct { + *mock.Call +} + +// ListCSRs is a helper method to define mock.On call +// - ctx context.Context +// - pm certs.PageMetadata +func (_e *MockService_Expecter) ListCSRs(ctx interface{}, pm interface{}) *MockService_ListCSRs_Call { + return &MockService_ListCSRs_Call{Call: _e.mock.On("ListCSRs", ctx, pm)} +} + +func (_c *MockService_ListCSRs_Call) Run(run func(ctx context.Context, pm certs.PageMetadata)) *MockService_ListCSRs_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(certs.PageMetadata)) + }) + return _c +} + +func (_c *MockService_ListCSRs_Call) Return(_a0 certs.CSRPage, _a1 error) *MockService_ListCSRs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockService_ListCSRs_Call) RunAndReturn(run func(context.Context, certs.PageMetadata) (certs.CSRPage, error)) *MockService_ListCSRs_Call { _c.Call.Return(run) return _c } @@ -543,6 +690,63 @@ func (_c *MockService_RetrieveCAToken_Call) RunAndReturn(run func(context.Contex return _c } +// RetrieveCSR provides a mock function with given fields: ctx, csrID +func (_m *MockService) RetrieveCSR(ctx context.Context, csrID string) (certs.CSR, error) { + ret := _m.Called(ctx, csrID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveCSR") + } + + var r0 certs.CSR + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (certs.CSR, error)); ok { + return rf(ctx, csrID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) certs.CSR); ok { + r0 = rf(ctx, csrID) + } else { + r0 = ret.Get(0).(certs.CSR) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, csrID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockService_RetrieveCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveCSR' +type MockService_RetrieveCSR_Call struct { + *mock.Call +} + +// RetrieveCSR is a helper method to define mock.On call +// - ctx context.Context +// - csrID string +func (_e *MockService_Expecter) RetrieveCSR(ctx interface{}, csrID interface{}) *MockService_RetrieveCSR_Call { + return &MockService_RetrieveCSR_Call{Call: _e.mock.On("RetrieveCSR", ctx, csrID)} +} + +func (_c *MockService_RetrieveCSR_Call) Run(run func(ctx context.Context, csrID string)) *MockService_RetrieveCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockService_RetrieveCSR_Call) Return(_a0 certs.CSR, _a1 error) *MockService_RetrieveCSR_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockService_RetrieveCSR_Call) RunAndReturn(run func(context.Context, string) (certs.CSR, error)) *MockService_RetrieveCSR_Call { + _c.Call.Return(run) + return _c +} + // RetrieveCert provides a mock function with given fields: ctx, token, serialNumber func (_m *MockService) RetrieveCert(ctx context.Context, token string, serialNumber string) (certs.Certificate, []byte, error) { ret := _m.Called(ctx, token, serialNumber) @@ -714,6 +918,54 @@ func (_c *MockService_RevokeCert_Call) RunAndReturn(run func(context.Context, st return _c } +// SignCSR provides a mock function with given fields: ctx, csrID, approve +func (_m *MockService) SignCSR(ctx context.Context, csrID string, approve bool) error { + ret := _m.Called(ctx, csrID, approve) + + if len(ret) == 0 { + panic("no return value specified for SignCSR") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, bool) error); ok { + r0 = rf(ctx, csrID, approve) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockService_SignCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignCSR' +type MockService_SignCSR_Call struct { + *mock.Call +} + +// SignCSR is a helper method to define mock.On call +// - ctx context.Context +// - csrID string +// - approve bool +func (_e *MockService_Expecter) SignCSR(ctx interface{}, csrID interface{}, approve interface{}) *MockService_SignCSR_Call { + return &MockService_SignCSR_Call{Call: _e.mock.On("SignCSR", ctx, csrID, approve)} +} + +func (_c *MockService_SignCSR_Call) Run(run func(ctx context.Context, csrID string, approve bool)) *MockService_SignCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(bool)) + }) + return _c +} + +func (_c *MockService_SignCSR_Call) Return(_a0 error) *MockService_SignCSR_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockService_SignCSR_Call) RunAndReturn(run func(context.Context, string, bool) error) *MockService_SignCSR_Call { + _c.Call.Return(run) + return _c +} + // ViewCert provides a mock function with given fields: ctx, serialNumber func (_m *MockService) ViewCert(ctx context.Context, serialNumber string) (certs.Certificate, error) { ret := _m.Called(ctx, serialNumber) diff --git a/mocks/uuid.go b/mocks/uuid.go new file mode 100644 index 0000000..065daba --- /dev/null +++ b/mocks/uuid.go @@ -0,0 +1,35 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + "fmt" + "sync" + + "github.com/absmach/certs/internal/uuid" +) + +// Prefix represents the prefix used to generate UUID mocks. +const Prefix = "123e4567-e89b-12d3-a456-" + +var _ uuid.IDProvider = (*uuidProviderMock)(nil) + +type uuidProviderMock struct { + mu sync.Mutex + counter int +} + +func (up *uuidProviderMock) ID() (string, error) { + up.mu.Lock() + defer up.mu.Unlock() + + up.counter++ + return fmt.Sprintf("%s%012d", Prefix, up.counter), nil +} + +// NewMock creates "mirror" uuid provider, i.e. generated +// token will hold value provided by the caller. +func NewMock() uuid.IDProvider { + return &uuidProviderMock{} +} diff --git a/postgres/certs.go b/postgres/certs/certs.go similarity index 100% rename from postgres/certs.go rename to postgres/certs/certs.go diff --git a/postgres/certs_test.go b/postgres/certs/certs_test.go similarity index 100% rename from postgres/certs_test.go rename to postgres/certs/certs_test.go diff --git a/postgres/init.go b/postgres/certs/init.go similarity index 100% rename from postgres/init.go rename to postgres/certs/init.go diff --git a/postgres/setup_test.go b/postgres/certs/setup_test.go similarity index 100% rename from postgres/setup_test.go rename to postgres/certs/setup_test.go diff --git a/postgres/csr/csr.go b/postgres/csr/csr.go new file mode 100644 index 0000000..f1a2008 --- /dev/null +++ b/postgres/csr/csr.go @@ -0,0 +1,202 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + "github.com/absmach/certs" + "github.com/absmach/certs/errors" + "github.com/absmach/certs/internal/postgres" + "github.com/jackc/pgx/v5/pgconn" +) + +// Postgres error codes: +// https://www.postgresql.org/docs/current/errcodes-appendix.html +const ( + errDuplicate = "23505" // unique_violation + errTruncation = "22001" // string_data_right_truncation + errFK = "23503" // foreign_key_violation + errInvalid = "22P02" // invalid_text_representation + errUntranslatable = "22P05" // untranslatable_character + errInvalidChar = "22021" // character_not_in_repertoire +) + +var ( + ErrConflict = errors.New("entity already exists") + ErrMalformedEntity = errors.New("malformed entity") + ErrCreateEntity = errors.New("failed to create entity") +) + +type CSRRepo struct { + db postgres.Database +} + +func NewRepository(db postgres.Database) certs.CSRRepository { + return CSRRepo{ + db: db, + } +} + +func (repo CSRRepo) CreateCSR(ctx context.Context, csr certs.CSR) error { + q := ` + INSERT INTO csrs (id, serial_number, csr, private_key, entity_id, status, submitted_at, signed_at) + VALUES (:id, :serial_number, :csr, :private_key, :entity_id, :status, :submitted_at, :signed_at)` + _, err := repo.db.NamedExecContext(ctx, q, csr) + if err != nil { + return handleError(certs.ErrCreateEntity, err) + } + return nil +} + +func (repo CSRRepo) UpdateCSR(ctx context.Context, csr certs.CSR) error { + updateData := rawCSR{ + ID: csr.ID, + SerialNumber: csr.SerialNumber, + Status: csr.Status.String(), + PrivateKey: csr.PrivateKey, + SubmittedAt: csr.SubmittedAt, + SignedAt: csr.SignedAt, + } + + q := `UPDATE csrs SET serial_number = :serial_number, status = :status, private_key = :private_key, submitted_at = :submitted_at, signed_at = :signed_at WHERE id = :id` + res, err := repo.db.NamedExecContext(ctx, q, updateData) + if err != nil { + return handleError(certs.ErrUpdateEntity, err) + } + count, err := res.RowsAffected() + if err != nil { + return errors.Wrap(certs.ErrUpdateEntity, err) + } + if count == 0 { + return certs.ErrNotFound + } + return nil +} + +func (repo CSRRepo) RetrieveCSR(ctx context.Context, id string) (certs.CSR, error) { + q := `SELECT id, serial_number, csr, private_key, entity_id, status, submitted_at, signed_at FROM csrs WHERE id = $1` + var csrRaw rawCSR + if err := repo.db.QueryRowxContext(ctx, q, id).StructScan(&csrRaw); err != nil { + if err == sql.ErrNoRows { + return certs.CSR{}, errors.Wrap(certs.ErrNotFound, err) + } + return certs.CSR{}, errors.Wrap(certs.ErrViewEntity, err) + } + + status, err := certs.ParseCSRStatus(csrRaw.Status) + if err != nil { + return certs.CSR{}, errors.Wrap(certs.ErrViewEntity, fmt.Errorf("invalid status: %s", csrRaw.Status)) + } + return certs.CSR{ + ID: csrRaw.ID, + SerialNumber: csrRaw.SerialNumber, + CSR: csrRaw.CSR, + PrivateKey: csrRaw.PrivateKey, + EntityID: csrRaw.EntityID, + Status: status, + SubmittedAt: csrRaw.SubmittedAt, + SignedAt: csrRaw.SignedAt, + }, nil +} + +func (repo CSRRepo) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { + var query []string + params := map[string]interface{}{ + "limit": pm.Limit, + "offset": pm.Offset, + } + if pm.EntityID != "" { + query = append(query, `c.entity_id = :entity_id`) + params["entity_id"] = pm.EntityID + } + if pm.Status != certs.All { + query = append(query, `c.status = :status`) + params["status"] = pm.Status + } + + var str string + if len(query) > 0 { + str = fmt.Sprintf(`WHERE %s`, strings.Join(query, ` AND `)) + } + + q := fmt.Sprintf(` + SELECT + c.id, + c.serial_number, + c.submitted_at, + c.signed_at, + c.entity_id + FROM csrs c %s LIMIT :limit OFFSET :offset;`, str) + + rows, err := repo.db.NamedQueryContext(ctx, q, pm) + if err != nil { + return certs.CSRPage{}, handleError(certs.ErrViewEntity, err) + } + defer rows.Close() + var csrs []certs.CSR + for rows.Next() { + csr := certs.CSR{} + if err := rows.StructScan(&csr); err != nil { + return certs.CSRPage{}, errors.Wrap(certs.ErrViewEntity, err) + } + csrs = append(csrs, csr) + } + + cq := fmt.Sprintf(`SELECT COUNT(*) FROM csrs c %s;`, str) + pm.Total, err = repo.total(ctx, cq, pm) + if err != nil { + return certs.CSRPage{}, errors.Wrap(certs.ErrViewEntity, err) + } + return certs.CSRPage{ + PageMetadata: pm, + CSRs: csrs, + }, nil +} + +func (repo CSRRepo) total(ctx context.Context, query string, params interface{}) (uint64, error) { + rows, err := repo.db.NamedQueryContext(ctx, query, params) + if err != nil { + return 0, err + } + defer rows.Close() + total := uint64(0) + if rows.Next() { + if err := rows.Scan(&total); err != nil { + return 0, err + } + } + return total, nil +} + +func handleError(wrapper, err error) error { + pqErr, ok := err.(*pgconn.PgError) + if ok { + switch pqErr.Code { + case errDuplicate: + return errors.Wrap(ErrConflict, err) + case errInvalid, errInvalidChar, errTruncation, errUntranslatable: + return errors.Wrap(ErrMalformedEntity, err) + case errFK: + return errors.Wrap(ErrCreateEntity, err) + } + } + + return errors.Wrap(wrapper, err) +} + +type rawCSR struct { + ID string `db:"id"` + SerialNumber string `db:"serial_number"` + CSR []byte `db:"csr"` + PrivateKey []byte `db:"private_key"` + EntityID string `db:"entity_id"` + Status string `db:"status"` + SubmittedAt time.Time `db:"submitted_at"` + SignedAt time.Time `db:"signed_at"` +} diff --git a/postgres/csr/init.go b/postgres/csr/init.go new file mode 100644 index 0000000..9ef8156 --- /dev/null +++ b/postgres/csr/init.go @@ -0,0 +1,34 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + _ "github.com/jackc/pgx/v5/stdlib" + migrate "github.com/rubenv/sql-migrate" +) + +func Migration() *migrate.MemoryMigrationSource { + return &migrate.MemoryMigrationSource{ + Migrations: []*migrate.Migration{ + { + Id: "csrs_1", + Up: []string{ + `CREATE TABLE IF NOT EXISTS csrs ( + id VARCHAR(36) PRIMARY KEY, + serial_number VARCHAR(40), + csr TEXT, + private_key TEXT, + entity_id VARCHAR(36), + status TEXT CHECK (status IN ('pending', 'signed', 'rejected')), + submitted_at TIMESTAMP, + signed_at TIMESTAMP + )`, + }, + Down: []string{ + "DROP TABLE csr", + }, + }, + }, + } +} diff --git a/sdk/certs_test.go b/sdk/certs_test.go index b905a1a..54f2f14 100644 --- a/sdk/certs_test.go +++ b/sdk/certs_test.go @@ -649,12 +649,12 @@ func TestDownloadCACert(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - svcCall := svc.On("GetSigningCA", mock.Anything, tc.token).Return(tc.svcresp, tc.svcerr) + svcCall := svc.On("GetChainCA", mock.Anything, tc.token).Return(tc.svcresp, tc.svcerr) _, err := ctsdk.DownloadCA(tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { - ok := svcCall.Parent.AssertCalled(t, "GetSigningCA", mock.Anything, tc.token) + ok := svcCall.Parent.AssertCalled(t, "GetChainCA", mock.Anything, tc.token) assert.True(t, ok) } svcCall.Unset() @@ -709,12 +709,12 @@ func TestViewCA(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - svcCall := svc.On("GetSigningCA", mock.Anything, tc.token).Return(tc.svcresp, tc.svcerr) + svcCall := svc.On("GetChainCA", mock.Anything, tc.token).Return(tc.svcresp, tc.svcerr) c, err := ctsdk.ViewCA(tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { - ok := svcCall.Parent.AssertCalled(t, "GetSigningCA", mock.Anything, tc.token) + ok := svcCall.Parent.AssertCalled(t, "GetChainCA", mock.Anything, tc.token) assert.True(t, ok) } assert.Equal(t, tc.sdkCert.Certificate, c.Certificate, fmt.Sprintf("expected: %v, got: %v", tc.sdkCert.Certificate, c.Certificate)) @@ -765,13 +765,13 @@ func TestGetCAToken(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - svcCall := svc.On("RetrieveCertDownloadToken", mock.Anything).Return(tc.svcresp, tc.svcerr) + svcCall := svc.On("RetrieveCAToken", mock.Anything).Return(tc.svcresp, tc.svcerr) resp, err := ctsdk.GetCAToken() assert.Equal(t, tc.err, err) if tc.err == nil { assert.Equal(t, tc.svcresp, resp.Token) - ok := svcCall.Parent.AssertCalled(t, "RetrieveCertDownloadToken", mock.Anything) + ok := svcCall.Parent.AssertCalled(t, "RetrieveCAToken", mock.Anything) assert.True(t, ok) } svcCall.Unset() diff --git a/sdk/mocks/sdk.go b/sdk/mocks/sdk.go index 4410524..23be435 100644 --- a/sdk/mocks/sdk.go +++ b/sdk/mocks/sdk.go @@ -25,6 +25,65 @@ func (_m *MockSDK) EXPECT() *MockSDK_Expecter { return &MockSDK_Expecter{mock: &_m.Mock} } +// CreateCSR provides a mock function with given fields: pm, privKey +func (_m *MockSDK) CreateCSR(pm sdk.PageMetadata, privKey []byte) (sdk.CSR, errors.SDKError) { + ret := _m.Called(pm, privKey) + + if len(ret) == 0 { + panic("no return value specified for CreateCSR") + } + + var r0 sdk.CSR + var r1 errors.SDKError + if rf, ok := ret.Get(0).(func(sdk.PageMetadata, []byte) (sdk.CSR, errors.SDKError)); ok { + return rf(pm, privKey) + } + if rf, ok := ret.Get(0).(func(sdk.PageMetadata, []byte) sdk.CSR); ok { + r0 = rf(pm, privKey) + } else { + r0 = ret.Get(0).(sdk.CSR) + } + + if rf, ok := ret.Get(1).(func(sdk.PageMetadata, []byte) errors.SDKError); ok { + r1 = rf(pm, privKey) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + + return r0, r1 +} + +// MockSDK_CreateCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCSR' +type MockSDK_CreateCSR_Call struct { + *mock.Call +} + +// CreateCSR is a helper method to define mock.On call +// - pm sdk.PageMetadata +// - privKey []byte +func (_e *MockSDK_Expecter) CreateCSR(pm interface{}, privKey interface{}) *MockSDK_CreateCSR_Call { + return &MockSDK_CreateCSR_Call{Call: _e.mock.On("CreateCSR", pm, privKey)} +} + +func (_c *MockSDK_CreateCSR_Call) Run(run func(pm sdk.PageMetadata, privKey []byte)) *MockSDK_CreateCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(sdk.PageMetadata), args[1].([]byte)) + }) + return _c +} + +func (_c *MockSDK_CreateCSR_Call) Return(_a0 sdk.CSR, _a1 errors.SDKError) *MockSDK_CreateCSR_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSDK_CreateCSR_Call) RunAndReturn(run func(sdk.PageMetadata, []byte) (sdk.CSR, errors.SDKError)) *MockSDK_CreateCSR_Call { + _c.Call.Return(run) + return _c +} + // DeleteCert provides a mock function with given fields: entityID func (_m *MockSDK) DeleteCert(entityID string) errors.SDKError { ret := _m.Called(entityID) @@ -308,6 +367,64 @@ func (_c *MockSDK_IssueCert_Call) RunAndReturn(run func(string, string, []string return _c } +// ListCSRs provides a mock function with given fields: pm +func (_m *MockSDK) ListCSRs(pm sdk.PageMetadata) (sdk.CSRPage, errors.SDKError) { + ret := _m.Called(pm) + + if len(ret) == 0 { + panic("no return value specified for ListCSRs") + } + + var r0 sdk.CSRPage + var r1 errors.SDKError + if rf, ok := ret.Get(0).(func(sdk.PageMetadata) (sdk.CSRPage, errors.SDKError)); ok { + return rf(pm) + } + if rf, ok := ret.Get(0).(func(sdk.PageMetadata) sdk.CSRPage); ok { + r0 = rf(pm) + } else { + r0 = ret.Get(0).(sdk.CSRPage) + } + + if rf, ok := ret.Get(1).(func(sdk.PageMetadata) errors.SDKError); ok { + r1 = rf(pm) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + + return r0, r1 +} + +// MockSDK_ListCSRs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCSRs' +type MockSDK_ListCSRs_Call struct { + *mock.Call +} + +// ListCSRs is a helper method to define mock.On call +// - pm sdk.PageMetadata +func (_e *MockSDK_Expecter) ListCSRs(pm interface{}) *MockSDK_ListCSRs_Call { + return &MockSDK_ListCSRs_Call{Call: _e.mock.On("ListCSRs", pm)} +} + +func (_c *MockSDK_ListCSRs_Call) Run(run func(pm sdk.PageMetadata)) *MockSDK_ListCSRs_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(sdk.PageMetadata)) + }) + return _c +} + +func (_c *MockSDK_ListCSRs_Call) Return(_a0 sdk.CSRPage, _a1 errors.SDKError) *MockSDK_ListCSRs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSDK_ListCSRs_Call) RunAndReturn(run func(sdk.PageMetadata) (sdk.CSRPage, errors.SDKError)) *MockSDK_ListCSRs_Call { + _c.Call.Return(run) + return _c +} + // ListCerts provides a mock function with given fields: pm func (_m *MockSDK) ListCerts(pm sdk.PageMetadata) (sdk.CertificatePage, errors.SDKError) { ret := _m.Called(pm) @@ -473,6 +590,64 @@ func (_c *MockSDK_RenewCert_Call) RunAndReturn(run func(string) errors.SDKError) return _c } +// RetrieveCSR provides a mock function with given fields: csrID +func (_m *MockSDK) RetrieveCSR(csrID string) (sdk.CSR, errors.SDKError) { + ret := _m.Called(csrID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveCSR") + } + + var r0 sdk.CSR + var r1 errors.SDKError + if rf, ok := ret.Get(0).(func(string) (sdk.CSR, errors.SDKError)); ok { + return rf(csrID) + } + if rf, ok := ret.Get(0).(func(string) sdk.CSR); ok { + r0 = rf(csrID) + } else { + r0 = ret.Get(0).(sdk.CSR) + } + + if rf, ok := ret.Get(1).(func(string) errors.SDKError); ok { + r1 = rf(csrID) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + + return r0, r1 +} + +// MockSDK_RetrieveCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveCSR' +type MockSDK_RetrieveCSR_Call struct { + *mock.Call +} + +// RetrieveCSR is a helper method to define mock.On call +// - csrID string +func (_e *MockSDK_Expecter) RetrieveCSR(csrID interface{}) *MockSDK_RetrieveCSR_Call { + return &MockSDK_RetrieveCSR_Call{Call: _e.mock.On("RetrieveCSR", csrID)} +} + +func (_c *MockSDK_RetrieveCSR_Call) Run(run func(csrID string)) *MockSDK_RetrieveCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockSDK_RetrieveCSR_Call) Return(_a0 sdk.CSR, _a1 errors.SDKError) *MockSDK_RetrieveCSR_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSDK_RetrieveCSR_Call) RunAndReturn(run func(string) (sdk.CSR, errors.SDKError)) *MockSDK_RetrieveCSR_Call { + _c.Call.Return(run) + return _c +} + // RetrieveCertDownloadToken provides a mock function with given fields: serialNumber func (_m *MockSDK) RetrieveCertDownloadToken(serialNumber string) (sdk.Token, errors.SDKError) { ret := _m.Called(serialNumber) @@ -579,6 +754,55 @@ func (_c *MockSDK_RevokeCert_Call) RunAndReturn(run func(string) errors.SDKError return _c } +// SignCSR provides a mock function with given fields: csrID, sign +func (_m *MockSDK) SignCSR(csrID string, sign bool) errors.SDKError { + ret := _m.Called(csrID, sign) + + if len(ret) == 0 { + panic("no return value specified for SignCSR") + } + + var r0 errors.SDKError + if rf, ok := ret.Get(0).(func(string, bool) errors.SDKError); ok { + r0 = rf(csrID, sign) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + + return r0 +} + +// MockSDK_SignCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignCSR' +type MockSDK_SignCSR_Call struct { + *mock.Call +} + +// SignCSR is a helper method to define mock.On call +// - csrID string +// - sign bool +func (_e *MockSDK_Expecter) SignCSR(csrID interface{}, sign interface{}) *MockSDK_SignCSR_Call { + return &MockSDK_SignCSR_Call{Call: _e.mock.On("SignCSR", csrID, sign)} +} + +func (_c *MockSDK_SignCSR_Call) Run(run func(csrID string, sign bool)) *MockSDK_SignCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].(bool)) + }) + return _c +} + +func (_c *MockSDK_SignCSR_Call) Return(_a0 errors.SDKError) *MockSDK_SignCSR_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSDK_SignCSR_Call) RunAndReturn(run func(string, bool) errors.SDKError) *MockSDK_SignCSR_Call { + _c.Call.Return(run) + return _c +} + // ViewCA provides a mock function with given fields: token func (_m *MockSDK) ViewCA(token string) (sdk.Certificate, errors.SDKError) { ret := _m.Called(token) diff --git a/sdk/sdk.go b/sdk/sdk.go index 9bf3ff3..6ca2fca 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -27,6 +27,7 @@ import ( const ( certsEndpoint = "certs" + csrEndpoint = "csrs" issueCertEndpoint = "certs/issue" emptyOCSPbody = 22 ) @@ -75,12 +76,24 @@ func (c CertStatus) MarshalJSON() ([]byte, error) { } type PageMetadata struct { - Total uint64 `json:"total,omitempty"` - Offset uint64 `json:"offset,omitempty"` - Limit uint64 `json:"limit,omitempty"` - EntityID string `json:"entity_id,omitempty"` - Token string `json:"token,omitempty"` - CommonName string `json:"common_name,omitempty"` + Total uint64 `json:"total"` + Offset uint64 `json:"offset,omitempty"` + Limit uint64 `json:"limit"` + EntityID string `json:"entity_id,omitempty"` + Token string `json:"token,omitempty"` + CommonName string `json:"common_name,omitempty"` + Organization []string `json:"organization,omitempty"` + OrganizationalUnit []string `json:"organizational_unit,omitempty"` + Country []string `json:"country,omitempty"` + Province []string `json:"province,omitempty"` + Locality []string `json:"locality,omitempty"` + StreetAddress []string `json:"street_address,omitempty"` + PostalCode []string `json:"postal_code,omitempty"` + DNSNames []string `json:"dns_names,omitempty"` + IPAddresses []string `json:"ip_addresses,omitempty"` + EmailAddresses []string `json:"email_addresses,omitempty"` + Status string `json:"status,omitempty"` + Sign bool `json:"sign,omitempty"` } type Options struct { @@ -148,6 +161,36 @@ type OCSPResponse struct { IssuerHash string `json:"issuer_hash,omitempty"` } +type CSRMetadata struct { + CommonName string `json:"common_name"` + Organization []string `json:"organization"` + OrganizationalUnit []string `json:"organizational_unit"` + Country []string `json:"country"` + Province []string `json:"province"` + Locality []string `json:"locality"` + StreetAddress []string `json:"street_address"` + PostalCode []string `json:"postal_code"` + DNSNames []string `json:"dns_names"` + IPAddresses []string `json:"ip_addresses"` + EmailAddresses []string `json:"email_addresses"` +} + +type CSR struct { + ID string `json:"id,omitempty"` + CSR []byte `json:"csr,omitempty"` + PrivateKey []byte `json:"private_key,omitempty"` + EntityID string `json:"entity_id,omitempty"` + Status string `json:"status,omitempty"` + SubmittedAt time.Time `json:"submitted_at,omitempty"` + SignedAt time.Time `json:"signed_at,omitempty"` + SerialNumber string `json:"serial_number,omitempty"` +} + +type CSRPage struct { + PageMetadata + CSRs []CSR `json:"csrs,omitempty"` +} + type SDK interface { // IssueCert issues a certificate for a thing required for mTLS. // @@ -232,6 +275,33 @@ type SDK interface { // response, _ := sdk.GetCAToken() // fmt.Println(response) GetCAToken() (Token, errors.SDKError) + + // CreateCSR creates a new Certificate Signing Request + // + // example: + // pm = sdk.CSRMetadata{CommonName: "common_name", EntityID: "entity_id" } + // response, _ := sdk.CreateCSR(pm, []bytes("privKey")) + // fmt.Println(response) + CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKError) + + // SignCSR processes a pending CSR and either signs or rejects it + // + // example: + // err := sdk.SignCSR( "csr_id", "privKeyPath") + // fmt.Println(err) + SignCSR(csrID string, sign bool) errors.SDKError + + // RetrieveCSR retrieves a specific CSR by ID + // + // response, _ := sdk.RetrieveCSR("csr_id") + // fmt.Println(response) + RetrieveCSR(csrID string) (CSR, errors.SDKError) + + // ListCSRs returns a list of CSRs based on filter criteria + // + // response, _ := sdk.ListCSRs(sdk.PageMetadata{EntityID: "entity_id", Status: "pending"}) + // fmt.Println(response) + ListCSRs(pm PageMetadata) (CSRPage, errors.SDKError) } func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (Certificate, errors.SDKError) { @@ -500,6 +570,85 @@ func (sdk mgSDK) GetCAToken() (Token, errors.SDKError) { return tk, nil } +func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKError) { + r := csrReq{ + Organization: pm.Organization, + OrganizationalUnit: pm.OrganizationalUnit, + Country: pm.Country, + Province: pm.Province, + Locality: pm.Locality, + StreetAddress: pm.StreetAddress, + PostalCode: pm.PostalCode, + DNSNames: pm.DNSNames, + IPAddresses: pm.IPAddresses, + EmailAddresses: pm.EmailAddresses, + PrivateKey: privKey, + } + d, err := json.Marshal(r) + if err != nil { + return CSR{}, errors.NewSDKError(err) + } + url := fmt.Sprintf("%s/%s/%s/%s", sdk.certsURL, certsEndpoint, csrEndpoint, pm.EntityID) + _, body, sdkerr := sdk.processRequest(http.MethodPost, url, d, nil, http.StatusOK) + if sdkerr != nil { + return CSR{}, sdkerr + } + + var csr CSR + if err := json.Unmarshal(body, &csr); err != nil { + return CSR{}, errors.NewSDKError(err) + } + return csr, nil +} + +func (sdk mgSDK) SignCSR(csrID string, sign bool) errors.SDKError { + pm := PageMetadata{ + Sign: sign, + } + url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s/%s", certsEndpoint, csrEndpoint, csrID), pm) + if err != nil { + return errors.NewSDKError(err) + } + + _, _, sdkerr := sdk.processRequest(http.MethodPatch, url, nil, nil, http.StatusOK) + if sdkerr != nil { + return sdkerr + } + return nil +} + +func (sdk mgSDK) ListCSRs(pm PageMetadata) (CSRPage, errors.SDKError) { + url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s", certsEndpoint, csrEndpoint), pm) + if err != nil { + return CSRPage{}, errors.NewSDKError(err) + } + _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusOK) + if sdkerr != nil { + return CSRPage{}, sdkerr + } + + var cp CSRPage + if err := json.Unmarshal(body, &cp); err != nil { + return CSRPage{}, errors.NewSDKError(err) + } + return cp, nil +} + +func (sdk mgSDK) RetrieveCSR(csrID string) (CSR, errors.SDKError) { + url := fmt.Sprintf("%s/%s/%s/%s", sdk.certsURL, certsEndpoint, csrEndpoint, csrID) + + _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusCreated) + if sdkerr != nil { + return CSR{}, sdkerr + } + + var csr CSR + if err := json.Unmarshal(body, &csr); err != nil { + return CSR{}, errors.NewSDKError(err) + } + return csr, nil +} + func NewSDK(conf Config) SDK { return &mgSDK{ certsURL: conf.CertsURL, @@ -586,6 +735,12 @@ func (pm PageMetadata) query() (string, error) { if pm.CommonName != "" { q.Add("common_name", pm.CommonName) } + if pm.Sign { + q.Add("status", "true") + } + if pm.Status != "" { + q.Add("status", pm.Status) + } return q.Encode(), nil } @@ -604,3 +759,17 @@ type certReq struct { TTL string `json:"ttl"` Options Options `json:"options"` } + +type csrReq struct { + Organization []string `json:"organization"` + OrganizationalUnit []string `json:"organizational_unit"` + Country []string `json:"country"` + Province []string `json:"province"` + Locality []string `json:"locality"` + StreetAddress []string `json:"street_address"` + PostalCode []string `json:"postal_code"` + DNSNames []string `json:"dns_names"` + IPAddresses []string `json:"ip_addresses"` + EmailAddresses []string `json:"email_addresses"` + PrivateKey []byte `json:"private_key"` +} diff --git a/service.go b/service.go index c1dc667..a4ad34c 100644 --- a/service.go +++ b/service.go @@ -16,6 +16,7 @@ import ( "time" "github.com/absmach/certs/errors" + "github.com/absmach/certs/internal/uuid" "github.com/golang-jwt/jwt" "golang.org/x/crypto/ocsp" ) @@ -32,68 +33,6 @@ const ( downloadTokenExpiry = time.Minute * 5 ) -type CertType int - -const ( - RootCA CertType = iota - IntermediateCA - ClientCert -) - -const ( - Root = "RootCA" - Inter = "IntermediateCA" - Client = "ClientCert" - Unknown = "Unknown" -) - -func (c CertType) String() string { - switch c { - case RootCA: - return Root - case IntermediateCA: - return Inter - case ClientCert: - return Client - default: - return Unknown - } -} - -func CertTypeFromString(s string) (CertType, error) { - switch s { - case Root: - return RootCA, nil - case Inter: - return IntermediateCA, nil - case Client: - return ClientCert, nil - default: - return -1, errors.New("unknown cert type") - } -} - -type CA struct { - Type CertType - Certificate *x509.Certificate - PrivateKey *rsa.PrivateKey - SerialNumber string -} - -type Config struct { - CommonName string `yaml:"common_name"` - Organization []string `yaml:"organization"` - OrganizationalUnit []string `yaml:"organizational_unit"` - Country []string `yaml:"country"` - Province []string `yaml:"province"` - Locality []string `yaml:"locality"` - StreetAddress []string `yaml:"street_address"` - PostalCode []string `yaml:"postal_code"` - DNSNames []string `yaml:"dns_names"` - IPAddresses []net.IP `yaml:"ip_addresses"` - ValidityPeriod string `yaml:"validity_period"` -} - var ( serialNumberLimit = new(big.Int).Lsh(big.NewInt(1), 128) ErrNotFound = errors.New("entity not found") @@ -111,29 +50,22 @@ var ( ErrInvalidLength = errors.New("invalid length of serial numbers") ) -type SubjectOptions struct { - CommonName string - Organization []string `json:"organization"` - OrganizationalUnit []string `json:"organizational_unit"` - Country []string `json:"country"` - Province []string `json:"province"` - Locality []string `json:"locality"` - StreetAddress []string `json:"street_address"` - PostalCode []string `json:"postal_code"` -} - type service struct { repo Repository + csrRepo CSRRepository rootCA *CA intermediateCA *CA + idProvider uuid.IDProvider } var _ Service = (*service)(nil) -func NewService(ctx context.Context, repo Repository, config *Config) (Service, error) { +func NewService(ctx context.Context, repo Repository, csrRepo CSRRepository, config *Config, idp uuid.IDProvider) (Service, error) { var svc service svc.repo = repo + svc.csrRepo = csrRepo + svc.idProvider = idp if err := svc.loadCACerts(ctx); err != nil { return &svc, err } @@ -159,12 +91,18 @@ func NewService(ctx context.Context, repo Repository, config *Config) (Service, // using the provided template and the generated private key. // The certificate is then stored in the repository using the CreateCert method. // If the root CA is not found, it returns an error. -func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions) (Certificate, error) { - privKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) - if err != nil { - return Certificate{}, err +func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key ...*rsa.PrivateKey) (Certificate, error) { + var privKey rsa.PrivateKey + var err error + if len(key) == 0 { + pKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) + privKey = *pKey + if err != nil { + return Certificate{}, err + } + } else { + privKey = *key[0] } - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { return Certificate{}, err @@ -202,7 +140,7 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [ return Certificate{}, err } dbCert := Certificate{ - Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privKey)}), + Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(&privKey)}), Certificate: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}), SerialNumber: template.SerialNumber.String(), EntityID: entityID, @@ -469,6 +407,147 @@ func (s *service) GetChainCA(ctx context.Context, token string) (Certificate, er return s.getConcatCAs(ctx) } +func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, entityID string, privateKey ...*rsa.PrivateKey) (CSR, error) { + var privKey *rsa.PrivateKey + var err error + + // Check if a private key is provided else generate a new private key. + if len(privateKey) > 0 && privateKey[0] != nil { + privKey = privateKey[0] + } else { + privKey, err = rsa.GenerateKey(rand.Reader, PrivateKeyBytes) + if err != nil { + return CSR{}, errors.Wrap(ErrCreateEntity, err) + } + } + + csrID, err := s.idProvider.ID() + if err != nil { + return CSR{}, err + } + + template := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: metadata.CommonName, + Organization: metadata.Organization, + OrganizationalUnit: metadata.OrganizationalUnit, + Country: metadata.Country, + Province: metadata.Province, + Locality: metadata.Locality, + StreetAddress: metadata.StreetAddress, + PostalCode: metadata.PostalCode, + }, + EmailAddresses: metadata.EmailAddresses, + DNSNames: metadata.DNSNames, + } + + for _, ip := range metadata.IPAddresses { + parsedIP := net.ParseIP(ip) + if parsedIP != nil { + template.IPAddresses = append(template.IPAddresses, parsedIP) + } + } + + csrBytes, err := x509.CreateCertificateRequest(rand.Reader, template, privKey) + if err != nil { + return CSR{}, errors.Wrap(ErrCreateEntity, err) + } + + csrPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: csrBytes, + }) + + privKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privKey), + }) + + csr := CSR{ + ID: csrID, + CSR: csrPEM, + PrivateKey: privKeyPEM, + EntityID: entityID, + Status: Pending, + SubmittedAt: time.Now(), + } + + if err := s.csrRepo.CreateCSR(ctx, csr); err != nil { + return CSR{}, errors.Wrap(ErrCreateEntity, err) + } + + return csr, nil +} + +func (s *service) SignCSR(ctx context.Context, csrID string, approve bool) error { + csr, err := s.csrRepo.RetrieveCSR(ctx, csrID) + if err != nil { + return errors.Wrap(ErrViewEntity, err) + } + + if csr.Status != Pending { + return ErrConflict + } + + if !approve { + csr.Status = Rejected + csr.SignedAt = time.Now() + return s.csrRepo.UpdateCSR(ctx, csr) + } + + block, _ := pem.Decode(csr.CSR) + if block == nil { + return errors.New("failed to parse CSR PEM") + } + + parsedCSR, err := x509.ParseCertificateRequest(block.Bytes) + if err != nil { + return errors.Wrap(ErrMalformedEntity, err) + } + + if err := parsedCSR.CheckSignature(); err != nil { + return errors.Wrap(ErrMalformedEntity, err) + } + + privKey, err := extractPrivateKey(csr.PrivateKey) + if err != nil { + return errors.Wrap(ErrMalformedEntity, err) + } + + cert, err := s.IssueCert(ctx, csr.EntityID, "", nil, SubjectOptions{ + CommonName: parsedCSR.Subject.CommonName, + Organization: parsedCSR.Subject.Organization, + OrganizationalUnit: parsedCSR.Subject.OrganizationalUnit, + Country: parsedCSR.Subject.Country, + Province: parsedCSR.Subject.Province, + Locality: parsedCSR.Subject.Locality, + StreetAddress: parsedCSR.Subject.StreetAddress, + PostalCode: parsedCSR.Subject.PostalCode, + }, privKey) + if err != nil { + return errors.Wrap(ErrCreateEntity, err) + } + + csr.Status = Signed + csr.SignedAt = time.Now() + csr.SerialNumber = cert.SerialNumber + + return s.csrRepo.UpdateCSR(ctx, csr) +} + +func (s *service) ListCSRs(ctx context.Context, pm PageMetadata) (CSRPage, error) { + cp, err := s.csrRepo.ListCSRs(ctx, pm) + if err != nil { + return CSRPage{}, errors.Wrap(ErrViewEntity, err) + } + + return cp, nil +} + +func (s *service) RetrieveCSR(ctx context.Context, csrID string) (CSR, error) { + return s.csrRepo.RetrieveCSR(ctx, csrID) +} + func (s *service) getConcatCAs(ctx context.Context) (Certificate, error) { intermediateCert, err := s.repo.RetrieveCert(ctx, s.intermediateCA.SerialNumber) if err != nil { @@ -791,3 +870,17 @@ func (s *service) loadCACerts(ctx context.Context) error { } return nil } + +func extractPrivateKey(pemKey []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(pemKey) + if block == nil { + return nil, errors.New("failed to parse private key PEM") + } + + privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + + return privKey, nil +} diff --git a/tracing/certs.go b/tracing/certs.go index efac814..4faeeb5 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -5,6 +5,7 @@ package tracing import ( "context" + "crypto/rsa" "crypto/x509" "github.com/absmach/certs" @@ -53,10 +54,10 @@ func (tm *tracingMiddleware) RetrieveCAToken(ctx context.Context) (string, error return tm.svc.RetrieveCAToken(ctx) } -func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { +func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...*rsa.PrivateKey) (certs.Certificate, error) { ctx, span := tm.tracer.Start(ctx, "issue_cert") defer span.End() - return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options) + return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey...) } func (tm *tracingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) { @@ -100,3 +101,27 @@ func (tm *tracingMiddleware) GetChainCA(ctx context.Context, token string) (cert defer span.End() return tm.svc.GetChainCA(ctx, token) } + +func (tm *tracingMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetadata, entityID string, key ...*rsa.PrivateKey) (certs.CSR, error) { + ctx, span := tm.tracer.Start(ctx, "create_csr") + defer span.End() + return tm.svc.CreateCSR(ctx, meta, entityID, key...) +} + +func (tm *tracingMiddleware) SignCSR(ctx context.Context, csrID string, approve bool) error { + ctx, span := tm.tracer.Start(ctx, "sign_csr") + defer span.End() + return tm.svc.SignCSR(ctx, csrID, approve) +} + +func (tm *tracingMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { + ctx, span := tm.tracer.Start(ctx, "list_csrs") + defer span.End() + return tm.svc.ListCSRs(ctx, pm) +} + +func (tm *tracingMiddleware) RetrieveCSR(ctx context.Context, csrID string) (certs.CSR, error) { + ctx, span := tm.tracer.Start(ctx, "retrieve_csr") + defer span.End() + return tm.svc.RetrieveCSR(ctx, csrID) +}