diff --git a/api/http/endpoint.go b/api/http/endpoint.go index 8d0fe65..bcbf3ab 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -91,12 +91,19 @@ func issueCertEndpoint(svc certs.Service) endpoint.Endpoint { return issueCertRes{}, err } - serialNumber, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs, req.Options) + cert, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs, req.Options) if err != nil { return issueCertRes{}, err } - return issueCertRes{issued: true, SerialNumber: serialNumber}, nil + return issueCertRes{ + SerialNumber: cert.SerialNumber, + Certificate: string(cert.Certificate), + ExpiryTime: cert.ExpiryTime, + EntityID: cert.EntityID, + Revoked: cert.Revoked, + issued: true, + }, nil } } diff --git a/api/http/responses.go b/api/http/responses.go index 5635b0b..ed333bd 100644 --- a/api/http/responses.go +++ b/api/http/responses.go @@ -94,7 +94,11 @@ func (res downloadCertRes) Empty() bool { } type issueCertRes struct { - SerialNumber string `json:"serial_number"` + SerialNumber string `json:"serial_number"` + Certificate string `json:"certificate,omitempty"` + Revoked bool `json:"revoked"` + ExpiryTime time.Time `json:"expiry_time"` + EntityID string `json:"entity_id"` issued bool } diff --git a/api/logging.go b/api/logging.go index e2078bd..407f139 100644 --- a/api/logging.go +++ b/api/logging.go @@ -73,7 +73,7 @@ func (lm *loggingMiddleware) RetrieveCertDownloadToken(ctx context.Context, seri return lm.svc.RetrieveCertDownloadToken(ctx, serialNumber) } -func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (serialNumber string, err error) { +func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (cert certs.Certificate, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method issue_cert for took %s to complete", time.Since(begin)) if err != nil { diff --git a/api/metrics.go b/api/metrics.go index ab64e87..663bbac 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -61,7 +61,7 @@ func (mm *metricsMiddleware) RetrieveCertDownloadToken(ctx context.Context, seri return mm.svc.RetrieveCertDownloadToken(ctx, serialNumber) } -func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (string, error) { +func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { defer func(begin time.Time) { mm.counter.With("method", "issue_certificate").Add(1) mm.latency.With("method", "issue_certificate").Observe(time.Since(begin).Seconds()) diff --git a/certs.go b/certs.go index d8d68b8..4adb46b 100644 --- a/certs.go +++ b/certs.go @@ -52,7 +52,7 @@ type Service interface { RetrieveCertDownloadToken(ctx context.Context, serialNumber string) (string, error) // IssueCert issues a certificate from the database. - IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions) (string, error) + IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions) (Certificate, error) // OCSP retrieves the OCSP response for a certificate. OCSP(ctx context.Context, serialNumber string) (*Certificate, int, *x509.Certificate, error) diff --git a/cli/certs.go b/cli/certs.go index ab5cbf9..9a16e2f 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -185,12 +185,12 @@ func NewCertsCmd() *cobra.Command { } } - serial, err := sdk.IssueCert(args[0], ttl, ipAddrs, option) + cert, err := sdk.IssueCert(args[0], ttl, ipAddrs, option) if err != nil { logErrorCmd(*cmd, err) return } - logJSONCmd(*cmd, serial) + logJSONCmd(*cmd, cert) }, } diff --git a/cli/certs_test.go b/cli/certs_test.go index e562b08..c81ca2f 100644 --- a/cli/certs_test.go +++ b/cli/certs_test.go @@ -45,14 +45,14 @@ func TestIssueCertCmd(t *testing.T) { ipAddrs := "[\"192.168.100.22\"]" - var sn sdk.SerialNumber + var cert sdk.Certificate cases := []struct { desc string args []string sdkErr errors.SDKError errLogMessage string logType outputLog - serial sdk.SerialNumber + cert sdk.Certificate }{ { desc: "issue cert successfully", @@ -62,7 +62,7 @@ func TestIssueCertCmd(t *testing.T) { ipAddrs, }, logType: entityLog, - serial: sdk.SerialNumber{SerialNumber: serialNumber}, + cert: sdk.Certificate{SerialNumber: serialNumber}, }, { desc: "issue cert with invalid args", @@ -92,19 +92,19 @@ func TestIssueCertCmd(t *testing.T) { "{\"organization\":[\"organization_name\"]}", }, logType: entityLog, - serial: sdk.SerialNumber{SerialNumber: serialNumber}, + cert: sdk.Certificate{SerialNumber: serialNumber}, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("IssueCert", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.serial, tc.sdkErr) + sdkCall := sdkMock.On("IssueCert", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.cert, tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{issueCmd}, tc.args...)...) switch tc.logType { case entityLog: - err := json.Unmarshal([]byte(out), &sn) + err := json.Unmarshal([]byte(out), &cert) assert.Nil(t, err) - assert.Equal(t, tc.serial, sn, fmt.Sprintf("%s unexpected response: expected: %v, got: %v", tc.desc, tc.serial, sn)) + assert.Equal(t, tc.cert, cert, fmt.Sprintf("%s unexpected response: expected: %v, got: %v", tc.desc, tc.cert, cert)) case errLog: assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) case usageLog: diff --git a/mocks/service.go b/mocks/service.go index 9eb3634..00bf73e 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -145,22 +145,22 @@ func (_c *MockService_GetEntityID_Call) RunAndReturn(run func(context.Context, s } // IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs, option -func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions) (string, error) { +func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions) (certs.Certificate, error) { ret := _m.Called(ctx, entityID, ttl, ipAddrs, option) if len(ret) == 0 { panic("no return value specified for IssueCert") } - var r0 string + var r0 certs.Certificate var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions) (string, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions) (certs.Certificate, error)); ok { return rf(ctx, entityID, ttl, ipAddrs, option) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions) string); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions) certs.Certificate); ok { r0 = rf(ctx, entityID, ttl, ipAddrs, option) } else { - r0 = ret.Get(0).(string) + r0 = ret.Get(0).(certs.Certificate) } if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions) error); ok { @@ -194,12 +194,12 @@ func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID return _c } -func (_c *MockService_IssueCert_Call) Return(_a0 string, _a1 error) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) Return(_a0 certs.Certificate, _a1 error) *MockService_IssueCert_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions) (string, error)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions) (certs.Certificate, error)) *MockService_IssueCert_Call { _c.Call.Return(run) return _c } diff --git a/sdk/certs_test.go b/sdk/certs_test.go index b106f29..b2a23db 100644 --- a/sdk/certs_test.go +++ b/sdk/certs_test.go @@ -55,7 +55,7 @@ func TestIssueCert(t *testing.T) { ttl string ipAddrs []string commonName string - svcresp string + svcresp certs.Certificate svcerr error err errors.SDKError sdkCert sdk.Certificate @@ -66,7 +66,9 @@ func TestIssueCert(t *testing.T) { ttl: ttl, ipAddrs: ipAddr, commonName: commonName, - svcresp: serialNum, + svcresp: certs.Certificate{ + SerialNumber: serialNum, + }, sdkCert: sdk.Certificate{ SerialNumber: serialNum, }, @@ -79,7 +81,7 @@ func TestIssueCert(t *testing.T) { ttl: ttl, ipAddrs: ipAddr, commonName: commonName, - svcresp: "", + svcresp: certs.Certificate{}, svcerr: certs.ErrCreateEntity, err: errors.NewSDKErrorWithStatus(certs.ErrCreateEntity, http.StatusUnprocessableEntity), }, @@ -89,7 +91,7 @@ func TestIssueCert(t *testing.T) { ttl: ttl, ipAddrs: ipAddr, commonName: commonName, - svcresp: "", + svcresp: certs.Certificate{}, svcerr: certs.ErrMalformedEntity, err: errors.NewSDKErrorWithStatus(certs.ErrMalformedEntity, http.StatusBadRequest), }, @@ -98,7 +100,7 @@ func TestIssueCert(t *testing.T) { entityID: id, ttl: ttl, commonName: commonName, - svcresp: serialNum, + svcresp: certs.Certificate{SerialNumber: serialNum}, sdkCert: sdk.Certificate{ SerialNumber: serialNum, }, @@ -111,7 +113,7 @@ func TestIssueCert(t *testing.T) { ttl: "", ipAddrs: ipAddr, commonName: commonName, - svcresp: serialNum, + svcresp: certs.Certificate{SerialNumber: serialNum}, sdkCert: sdk.Certificate{ SerialNumber: serialNum, }, @@ -124,7 +126,7 @@ func TestIssueCert(t *testing.T) { ttl: ttl, ipAddrs: ipAddr, commonName: "", - svcresp: "", + svcresp: certs.Certificate{}, svcerr: httpapi.ErrMissingCN, err: errors.NewSDKErrorWithStatus(httpapi.ErrMissingCN, http.StatusBadRequest), }, diff --git a/sdk/mocks/sdk.go b/sdk/mocks/sdk.go index e95c94d..b76a655 100644 --- a/sdk/mocks/sdk.go +++ b/sdk/mocks/sdk.go @@ -87,22 +87,22 @@ func (_c *MockSDK_DownloadCert_Call) RunAndReturn(run func(string, string) (sdk. } // IssueCert provides a mock function with given fields: entityID, ttl, ipAddrs, opts -func (_m *MockSDK) IssueCert(entityID string, ttl string, ipAddrs []string, opts sdk.Options) (sdk.SerialNumber, errors.SDKError) { +func (_m *MockSDK) IssueCert(entityID string, ttl string, ipAddrs []string, opts sdk.Options) (sdk.Certificate, errors.SDKError) { ret := _m.Called(entityID, ttl, ipAddrs, opts) if len(ret) == 0 { panic("no return value specified for IssueCert") } - var r0 sdk.SerialNumber + var r0 sdk.Certificate var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(string, string, []string, sdk.Options) (sdk.SerialNumber, errors.SDKError)); ok { + if rf, ok := ret.Get(0).(func(string, string, []string, sdk.Options) (sdk.Certificate, errors.SDKError)); ok { return rf(entityID, ttl, ipAddrs, opts) } - if rf, ok := ret.Get(0).(func(string, string, []string, sdk.Options) sdk.SerialNumber); ok { + if rf, ok := ret.Get(0).(func(string, string, []string, sdk.Options) sdk.Certificate); ok { r0 = rf(entityID, ttl, ipAddrs, opts) } else { - r0 = ret.Get(0).(sdk.SerialNumber) + r0 = ret.Get(0).(sdk.Certificate) } if rf, ok := ret.Get(1).(func(string, string, []string, sdk.Options) errors.SDKError); ok { @@ -137,12 +137,12 @@ func (_c *MockSDK_IssueCert_Call) Run(run func(entityID string, ttl string, ipAd return _c } -func (_c *MockSDK_IssueCert_Call) Return(_a0 sdk.SerialNumber, _a1 errors.SDKError) *MockSDK_IssueCert_Call { +func (_c *MockSDK_IssueCert_Call) Return(_a0 sdk.Certificate, _a1 errors.SDKError) *MockSDK_IssueCert_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockSDK_IssueCert_Call) RunAndReturn(run func(string, string, []string, sdk.Options) (sdk.SerialNumber, errors.SDKError)) *MockSDK_IssueCert_Call { +func (_c *MockSDK_IssueCert_Call) RunAndReturn(run func(string, string, []string, sdk.Options) (sdk.Certificate, errors.SDKError)) *MockSDK_IssueCert_Call { _c.Call.Return(run) return _c } diff --git a/sdk/sdk.go b/sdk/sdk.go index 29500c0..0416b34 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -60,10 +60,6 @@ type Options struct { PostalCode []string `json:"postal_code"` } -type SerialNumber struct { - SerialNumber string `json:"serial_number"` -} - type Token struct { Token string `json:"token"` } @@ -113,9 +109,9 @@ type SDK interface { // IssueCert issues a certificate for a thing required for mTLS. // // example: - // serial , _ := sdk.IssueCert("entityID", "10h", []string{"ipAddr1", "ipAddr2"}, sdk.Options{CommonName: "commonName"}) - // fmt.Println(serial) - IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (SerialNumber, errors.SDKError) + // cert , _ := sdk.IssueCert("entityID", "10h", []string{"ipAddr1", "ipAddr2"}, sdk.Options{CommonName: "commonName"}) + // fmt.Println(cert) + IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (Certificate, errors.SDKError) // DownloadCert returns a certificate given certificate ID // @@ -167,7 +163,7 @@ type SDK interface { OCSP(serialNumber string) (*ocsp.Response, errors.SDKError) } -func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (SerialNumber, errors.SDKError) { +func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (Certificate, errors.SDKError) { r := certReq{ IpAddrs: ipAddrs, TTL: ttl, @@ -175,24 +171,24 @@ func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string, opts Options) } d, err := json.Marshal(r) if err != nil { - return SerialNumber{}, errors.NewSDKError(err) + return Certificate{}, errors.NewSDKError(err) } url := fmt.Sprintf("%s/%s", issueCertEndpoint, entityID) url, err = sdk.withQueryParams(sdk.certsURL, url, PageMetadata{CommonName: opts.CommonName}) if err != nil { - return SerialNumber{}, errors.NewSDKError(err) + return Certificate{}, errors.NewSDKError(err) } _, body, sdkerr := sdk.processRequest(http.MethodPost, url, d, nil, http.StatusCreated) if sdkerr != nil { - return SerialNumber{}, sdkerr + return Certificate{}, sdkerr } - var sn SerialNumber - if err := json.Unmarshal(body, &sn); err != nil { - return SerialNumber{}, errors.NewSDKError(err) + var cert Certificate + if err := json.Unmarshal(body, &cert); err != nil { + return Certificate{}, errors.NewSDKError(err) } - return sn, nil + return cert, nil } func (sdk mgSDK) DownloadCert(token, serialNumber string) (CertificateBundle, errors.SDKError) { diff --git a/service.go b/service.go index 55b336d..079840f 100644 --- a/service.go +++ b/service.go @@ -150,19 +150,19 @@ func NewService(ctx context.Context, repo Repository) (Service, error) { // using the provided template and the generated private key. // The certificate is then stored in the repository using the CreateCert method. // If the root CA is not found, it returns an error. -func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions) (string, error) { +func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions) (Certificate, error) { privKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) if err != nil { - return "", err + return Certificate{}, err } serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { - return "", err + return Certificate{}, err } if s.intermediateCA.Certificate == nil || s.intermediateCA.PrivateKey == nil { - return "", ErrIntermediateCANotFound + return Certificate{}, ErrIntermediateCANotFound } // Parse the TTL if provided, otherwise use the default certValidityPeriod. @@ -170,7 +170,7 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [ if ttl != "" { validity, err = time.ParseDuration(ttl) if err != nil { - return "", errors.Wrap(ErrMalformedEntity, err) + return Certificate{}, errors.Wrap(ErrMalformedEntity, err) } } else { validity = certValidityPeriod @@ -191,7 +191,7 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [ certBytes, err := x509.CreateCertificate(rand.Reader, &template, s.intermediateCA.Certificate, &privKey.PublicKey, s.intermediateCA.PrivateKey) if err != nil { - return "", err + return Certificate{}, err } dbCert := Certificate{ Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privKey)}), @@ -202,10 +202,17 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [ Type: ClientCert, } if err = s.repo.CreateCert(ctx, dbCert); err != nil { - return "", errors.Wrap(ErrCreateEntity, err) + return Certificate{}, errors.Wrap(ErrCreateEntity, err) } - return dbCert.SerialNumber, nil + return Certificate{ + Certificate: dbCert.Certificate, + SerialNumber: dbCert.SerialNumber, + EntityID: dbCert.EntityID, + ExpiryTime: dbCert.ExpiryTime, + Revoked: dbCert.Revoked, + Type: dbCert.Type, + }, nil } // RevokeCert revokes a certificate identified by its serial number. diff --git a/tracing/certs.go b/tracing/certs.go index 765b803..61779af 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -47,7 +47,7 @@ func (tm *tracingMiddleware) RetrieveCertDownloadToken(ctx context.Context, seri return tm.svc.RetrieveCertDownloadToken(ctx, serialNumber) } -func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (string, error) { +func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { ctx, span := tm.tracer.Start(ctx, "issue_cert") defer span.End() return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options)