Skip to content

Commit

Permalink
NOISSUE - Add CA retrieve option (#24)
Browse files Browse the repository at this point in the history
* Add retrieve CA option

Signed-off-by: nyagamunene <[email protected]>

* Fix failing linter

Signed-off-by: nyagamunene <[email protected]>

* Fix failing linter

Signed-off-by: nyagamunene <[email protected]>

* Address comments

Signed-off-by: nyagamunene <[email protected]>

---------

Signed-off-by: nyagamunene <[email protected]>
  • Loading branch information
nyagamunene authored Oct 9, 2024
1 parent 0ea242d commit 6509013
Show file tree
Hide file tree
Showing 16 changed files with 1,075 additions and 46 deletions.
53 changes: 52 additions & 1 deletion api/http/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func downloadCertEndpoint(svc certs.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
req := request.(downloadReq)
if err := req.validate(); err != nil {
return downloadCertRes{}, err
return fileDownloadRes{}, err
}
cert, ca, err := svc.RetrieveCert(ctx, req.token, req.id)
if err != nil {
Expand Down Expand Up @@ -243,3 +243,54 @@ func generateCRLEndpoint(svc certs.Service) endpoint.Endpoint {
}, nil
}
}

func getDownloadCATokenEndpoint(svc certs.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
token, err := svc.RetrieveCAToken(ctx)
if err != nil {
return requestCertDownloadTokenRes{}, err
}

return requestCertDownloadTokenRes{Token: token}, nil
}
}

func downloadCAEndpoint(svc certs.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
req := request.(downloadReq)
if err := req.validate(); err != nil {
return fileDownloadRes{}, err
}

cert, err := svc.GetSigningCA(ctx, req.token)
if err != nil {
return fileDownloadRes{}, err
}

return fileDownloadRes{
Certificate: cert.Certificate,
PrivateKey: cert.Key,
Filename: "ca.zip",
ContentType: "application/zip",
}, nil
}
}

func viewCAEndpoint(svc certs.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
req := request.(downloadReq)
if err := req.validate(); err != nil {
return viewCertRes{}, err
}

cert, err := svc.GetSigningCA(ctx, req.token)
if err != nil {
return viewCertRes{}, err
}

return viewCertRes{
Certificate: string(cert.Certificate),
Key: string(cert.Key),
}, nil
}
}
3 changes: 0 additions & 3 deletions api/http/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ type downloadReq struct {
}

func (req downloadReq) validate() error {
if req.id == "" {
return errors.Wrap(certs.ErrMalformedEntity, ErrEmptySerialNo)
}
if req.token == "" {
return errors.Wrap(certs.ErrMalformedEntity, ErrEmptyToken)
}
Expand Down
28 changes: 5 additions & 23 deletions api/http/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,24 +75,6 @@ func (res requestCertDownloadTokenRes) Empty() bool {
return false
}

type downloadCertRes struct {
Certificate []byte `json:"certificate"`
PrivateKey []byte `json:"private_key"`
CA []byte `json:"ca"`
}

func (res downloadCertRes) Code() int {
return http.StatusOK
}

func (res downloadCertRes) Headers() map[string]string {
return map[string]string{}
}

func (res downloadCertRes) Empty() bool {
return false
}

type issueCertRes struct {
SerialNumber string `json:"serial_number"`
Certificate string `json:"certificate,omitempty"`
Expand Down Expand Up @@ -138,12 +120,12 @@ func (res listCertsRes) Empty() bool {
}

type viewCertRes struct {
SerialNumber string `json:"serial_number"`
SerialNumber string `json:"serial_number,omitempty"`
Certificate string `json:"certificate,omitempty"`
Key string `json:"key,omitempty"`
Revoked bool `json:"revoked"`
ExpiryTime time.Time `json:"expiry_time"`
EntityID string `json:"entity_id"`
Key string `json:"key,omitempty,omitempty"`
Revoked bool `json:"revoked,omitempty"`
ExpiryTime time.Time `json:"expiry_time,omitempty"`
EntityID string `json:"entity_id,omitempty"`
}

func (res viewCertRes) Code() int {
Expand Down
65 changes: 65 additions & 0 deletions api/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,24 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http
EncodeResponse,
opts...,
), "generate_crl").ServeHTTP)
r.Get("/get-ca/token", otelhttp.NewHandler(kithttp.NewServer(
getDownloadCATokenEndpoint(svc),
decodeView,
EncodeResponse,
opts...,
), "get_ca_token").ServeHTTP)
r.Get("/view-ca", otelhttp.NewHandler(kithttp.NewServer(
viewCAEndpoint(svc),
decodeDownloadCA,
EncodeResponse,
opts...,
), "view_ca").ServeHTTP)
r.Get("/download-ca", otelhttp.NewHandler(kithttp.NewServer(
downloadCAEndpoint(svc),
decodeDownloadCA,
encodeCADownloadResponse,
opts...,
), "download_ca").ServeHTTP)
})

r.Get("/health", certs.Health("certs", instanceID))
Expand Down Expand Up @@ -139,6 +157,18 @@ func decodeDownloadCerts(_ context.Context, r *http.Request) (interface{}, error
return req, nil
}

func decodeDownloadCA(_ context.Context, r *http.Request) (interface{}, error) {
token, err := readStringQuery(r, token, "")
if err != nil {
return nil, err
}
req := downloadReq{
token: token,
}

return req, nil
}

func decodeOCSPRequest(_ context.Context, r *http.Request) (interface{}, error) {
body, err := io.ReadAll(r.Body)
if err != nil {
Expand Down Expand Up @@ -280,6 +310,41 @@ func encodeFileDownloadResponse(_ context.Context, w http.ResponseWriter, respon
return err
}

func encodeCADownloadResponse(_ context.Context, w http.ResponseWriter, response interface{}) error {
resp := response.(fileDownloadRes)
var buffer bytes.Buffer
zw := zip.NewWriter(&buffer)

f, err := zw.Create("ca.crt")
if err != nil {
return err
}

if _, err = f.Write(resp.Certificate); err != nil {
return err
}

f, err = zw.Create("ca.key")
if err != nil {
return err
}

if _, err = f.Write(resp.PrivateKey); err != nil {
return err
}

if err := zw.Close(); err != nil {
return err
}

w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", resp.Filename))
w.Header().Set("Content-Type", resp.ContentType)

_, err = w.Write(buffer.Bytes())

return err
}

// loggingErrorEncoder is a go-kit error encoder logging decorator.
func loggingErrorEncoder(logger *slog.Logger, enc kithttp.ErrorEncoder) kithttp.ErrorEncoder {
return func(ctx context.Context, err error, w http.ResponseWriter) {
Expand Down
26 changes: 25 additions & 1 deletion api/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (lm *loggingMiddleware) RevokeCert(ctx context.Context, serialNumber string

func (lm *loggingMiddleware) RetrieveCertDownloadToken(ctx context.Context, serialNumber string) (tokenString string, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method get_cert_download_token for cert %s took %s to complete", serialNumber, time.Since(begin))
message := fmt.Sprintf("Method get_cert_download_token for cert took %s to complete", time.Since(begin))
if err != nil {
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
return
Expand All @@ -73,6 +73,18 @@ func (lm *loggingMiddleware) RetrieveCertDownloadToken(ctx context.Context, seri
return lm.svc.RetrieveCertDownloadToken(ctx, serialNumber)
}

func (lm *loggingMiddleware) RetrieveCAToken(ctx context.Context) (tokenString string, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method get_cert_download_token for cert took %s to complete", time.Since(begin))
if err != nil {
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
return
}
lm.logger.Info(message)
}(time.Now())
return lm.svc.RetrieveCAToken(ctx)
}

func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (cert certs.Certificate, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method issue_cert for took %s to complete", time.Since(begin))
Expand Down Expand Up @@ -144,3 +156,15 @@ func (lm *loggingMiddleware) GenerateCRL(ctx context.Context, caType certs.CertT
}(time.Now())
return lm.svc.GenerateCRL(ctx, caType)
}

func (lm *loggingMiddleware) GetSigningCA(ctx context.Context, token string) (cert certs.Certificate, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method get_signing_ca took %s to complete", time.Since(begin))
if err != nil {
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
return
}
lm.logger.Info(message)
}(time.Now())
return lm.svc.GetSigningCA(ctx, token)
}
19 changes: 19 additions & 0 deletions api/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,19 @@ func (mm *metricsMiddleware) RetrieveCertDownloadToken(ctx context.Context, seri
mm.counter.With("method", "get_certificate_download_token").Add(1)
mm.latency.With("method", "get_certificate_download_token").Observe(time.Since(begin).Seconds())
}(time.Now())

return mm.svc.RetrieveCertDownloadToken(ctx, serialNumber)
}

func (mm *metricsMiddleware) RetrieveCAToken(ctx context.Context) (string, error) {
defer func(begin time.Time) {
mm.counter.With("method", "get_CA_token").Add(1)
mm.latency.With("method", "get_CA_token").Observe(time.Since(begin).Seconds())
}(time.Now())

return mm.svc.RetrieveCAToken(ctx)
}

func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) {
defer func(begin time.Time) {
mm.counter.With("method", "issue_certificate").Add(1)
Expand All @@ -82,6 +92,7 @@ func (mm *metricsMiddleware) ViewCert(ctx context.Context, serialNumber string)
mm.counter.With("method", "view_certificate").Add(1)
mm.latency.With("method", "view_certificate").Observe(time.Since(begin).Seconds())
}(time.Now())

return mm.svc.ViewCert(ctx, serialNumber)
}

Expand All @@ -108,3 +119,11 @@ func (mm *metricsMiddleware) GenerateCRL(ctx context.Context, caType certs.CertT
}(time.Now())
return mm.svc.GenerateCRL(ctx, caType)
}

func (mm *metricsMiddleware) GetSigningCA(ctx context.Context, token string) (certs.Certificate, error) {
defer func(begin time.Time) {
mm.counter.With("method", "get_signing_ca").Add(1)
mm.latency.With("method", "get_signing_ca").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.GetSigningCA(ctx, token)
}
14 changes: 11 additions & 3 deletions certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,22 @@ type Service interface {
RevokeCert(ctx context.Context, serialNumber string) error

// RetrieveCert retrieves a certificate record from the database.
RetrieveCert(ctx context.Context, token string, serialNumber string) (Certificate, []byte, error)
RetrieveCert(ctx context.Context, token, serialNumber string) (Certificate, []byte, error)

// ViewCert retrieves a certificate record from the database.
ViewCert(ctx context.Context, serialNumber string) (Certificate, error)

// ListCerts retrieves the certificates from the database while applying filters.
ListCerts(ctx context.Context, pm PageMetadata) (CertificatePage, error)

// RetrieveCertDownloadToken retrieves a certificate download token.
// RetrieveCertDownloadToken generates a certificate download token.
// The token is needed to download the client certificate.
RetrieveCertDownloadToken(ctx context.Context, serialNumber string) (string, error)

// RetrieveCAToken generates a CA download and view token.
// The token is needed to view and download the CA certificate.
RetrieveCAToken(ctx context.Context) (string, error)

// IssueCert issues a certificate from the database.
IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions) (Certificate, error)

Expand All @@ -60,8 +65,11 @@ type Service interface {
// GetEntityID retrieves the entity ID for a certificate.
GetEntityID(ctx context.Context, serialNumber string) (string, error)

// GenerateCRL creates
// GenerateCRL creates cert revocation list.
GenerateCRL(ctx context.Context, caType CertType) ([]byte, error)

// Retrieves the signing CA.
GetSigningCA(ctx context.Context, token string) (Certificate, error)
}

type Repository interface {
Expand Down
53 changes: 52 additions & 1 deletion cli/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ var cmdCerts = []cobra.Command{
},
},
{
Use: "view <serial_number> ",
Use: "view <serial_number>",
Short: "View certificate",
Long: `Views a certificate for a given serial number.`,
Run: func(cmd *cobra.Command, args []string) {
Expand All @@ -155,6 +155,57 @@ var cmdCerts = []cobra.Command{
logJSONCmd(*cmd, cert)
},
},
{
Use: "view-ca <token>",
Short: "View-ca certificate",
Long: `Views ca certificate key with a given token.`,
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 1 {
logUsageCmd(*cmd, cmd.Use)
return
}
cert, err := sdk.ViewCA(args[0])
if err != nil {
logErrorCmd(*cmd, err)
return
}
logJSONCmd(*cmd, cert)
},
},
{
Use: "download-ca <token>",
Short: "Download signing CA",
Long: `Download intermediate cert and ca with a given token.`,
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 1 {
logUsageCmd(*cmd, cmd.Use)
return
}
bundle, err := sdk.DownloadCA(args[0])
if err != nil {
logErrorCmd(*cmd, err)
return
}
logSaveCAFiles(*cmd, bundle)
},
},
{
Use: "token-ca",
Short: "Get CA token",
Long: `Gets a download token for CA.`,
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 0 {
logUsageCmd(*cmd, cmd.Use)
return
}
token, err := sdk.GetCAToken()
if err != nil {
logErrorCmd(*cmd, err)
return
}
logJSONCmd(*cmd, token)
},
},
}

// NewCertsCmd returns certificate command.
Expand Down
Loading

0 comments on commit 6509013

Please sign in to comment.