From fbc8ac53b2b3f9ff3a33df9da710063a71c46974 Mon Sep 17 00:00:00 2001 From: Patryk Kalinowski Date: Mon, 19 Feb 2024 19:05:44 +0100 Subject: [PATCH] rpc(tests): add test for normalized issuer in Identity --- rpc/helpers_test.go | 31 +++++++++++++++++++++++++------ rpc/sessions_test.go | 27 ++++++++++++++++++++------- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/rpc/helpers_test.go b/rpc/helpers_test.go index 38c3c484..16106602 100644 --- a/rpc/helpers_test.go +++ b/rpc/helpers_test.go @@ -68,7 +68,7 @@ func getTestingCtxValue(ctx context.Context, k string) string { func initRPC(cfg *config.Config, enc *enclave.Enclave, dbClient *dbMock) *rpc.RPC { svc := &rpc.RPC{ Config: cfg, - HTTPClient: http.DefaultClient, + HTTPClient: httpClient{}, Enclave: enc, Wallets: newWalletServiceMock(nil), Tenants: data.NewTenantTable(dbClient, "Tenants"), @@ -166,7 +166,7 @@ QwIDAQAB } } -func issueAccessTokenAndRunJwksServer(t *testing.T, optTokenBuilderFn ...func(*jwt.Builder)) (iss string, tok string, close func()) { +func issueAccessTokenAndRunJwksServer(t *testing.T, optTokenBuilderFn ...func(*jwt.Builder, string)) (iss string, tok string, close func()) { jwtKeyRaw, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) jwtKey, err := jwk.FromRaw(jwtKeyRaw) @@ -204,7 +204,7 @@ func issueAccessTokenAndRunJwksServer(t *testing.T, optTokenBuilderFn ...func(*j Subject("subject") if len(optTokenBuilderFn) > 0 && optTokenBuilderFn[0] != nil { - optTokenBuilderFn[0](tokBuilder) + optTokenBuilderFn[0](tokBuilder, jwksServer.URL) } tokRaw, err := tokBuilder.Build() @@ -459,9 +459,12 @@ func newTenant(t *testing.T, enc *enclave.Enclave, issuer string) (*data.Tenant, }, UpgradeCode: "CHANGEME", WaasAccessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJwYXJ0bmVyX2lkIjozfQ.g2fWwLrKPhTUpLFc7ZM9pMm4kEHGu8haCMzMOOGiqSM", - OIDCProviders: []*proto.OpenIdProvider{{Issuer: issuer, Audience: []string{"audience"}}}, - AllowedOrigins: []string{"http://localhost"}, - KMSKeys: []string{"SessionKey"}, + OIDCProviders: []*proto.OpenIdProvider{ + {Issuer: issuer, Audience: []string{"audience"}}, + {Issuer: "https://" + strings.TrimPrefix(issuer, "http://"), Audience: []string{"audience"}}, + }, + AllowedOrigins: []string{"http://localhost"}, + KMSKeys: []string{"SessionKey"}, } encryptedKey, algorithm, ciphertext, err := crypto.EncryptData(context.Background(), att, "TenantKey", payload) @@ -748,3 +751,19 @@ func (w walletServiceMock) FinishValidateSession(ctx context.Context, sessionId } var _ proto_wallet.WaaS = (*walletServiceMock)(nil) + +type httpClient struct{} + +func (httpClient) Do(req *http.Request) (*http.Response, error) { + req.URL.Scheme = "http" + return http.DefaultClient.Do(req) +} + +func (httpClient) Get(s string) (*http.Response, error) { + if strings.HasPrefix(s, "https://") { + s = "http://" + strings.TrimPrefix(s, "https://") + } + return http.DefaultClient.Get(s) +} + +var _ rpc.HTTPClient = (*httpClient)(nil) diff --git a/rpc/sessions_test.go b/rpc/sessions_test.go index fb7c0594..e8e0ad3b 100644 --- a/rpc/sessions_test.go +++ b/rpc/sessions_test.go @@ -8,6 +8,7 @@ import ( mathrand "math/rand" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -41,7 +42,7 @@ func TestRPC_RegisterSession(t *testing.T) { } testCases := map[string]struct { assertFn func(t *testing.T, sess *proto.Session, err error, p assertionParams) - tokBuilderFn func(b *jwt.Builder) + tokBuilderFn func(b *jwt.Builder, url string) intentBuilderFn func(t *testing.T, data intents.IntentDataOpenSession) *proto.Intent }{ "Basic": { @@ -60,14 +61,14 @@ func TestRPC_RegisterSession(t *testing.T) { }, }, "WithInvalidIssuer": { - tokBuilderFn: func(b *jwt.Builder) { b.Issuer("https://id.example.com") }, + tokBuilderFn: func(b *jwt.Builder, url string) { b.Issuer("https://id.example.com") }, assertFn: func(t *testing.T, sess *proto.Session, err error, p assertionParams) { require.Nil(t, sess) require.ErrorContains(t, err, `issuer "https://id.example.com" not valid for this tenant`) }, }, "WithValidNonce": { - tokBuilderFn: func(b *jwt.Builder) { b.Claim("nonce", sessHash) }, + tokBuilderFn: func(b *jwt.Builder, url string) { b.Claim("nonce", sessHash) }, assertFn: func(t *testing.T, sess *proto.Session, err error, p assertionParams) { require.NoError(t, err) require.NotNil(t, sess) @@ -76,14 +77,14 @@ func TestRPC_RegisterSession(t *testing.T) { }, }, "WithInvalidNonce": { - tokBuilderFn: func(b *jwt.Builder) { b.Claim("nonce", "0x1234567890abcdef") }, + tokBuilderFn: func(b *jwt.Builder, url string) { b.Claim("nonce", "0x1234567890abcdef") }, assertFn: func(t *testing.T, sess *proto.Session, err error, p assertionParams) { require.Nil(t, sess) require.ErrorContains(t, err, "JWT validation: nonce not satisfied") }, }, "WithInvalidNonceButValidSessionAddressClaim": { - tokBuilderFn: func(b *jwt.Builder) { + tokBuilderFn: func(b *jwt.Builder, url string) { b.Claim("nonce", "0x1234567890abcdef"). Claim("sequence:session_hash", sessHash) }, @@ -95,7 +96,7 @@ func TestRPC_RegisterSession(t *testing.T) { }, }, "WithVerifiedEmail": { - tokBuilderFn: func(b *jwt.Builder) { + tokBuilderFn: func(b *jwt.Builder, url string) { b.Claim("email", "user@example.com").Claim("email_verified", "true") }, assertFn: func(t *testing.T, sess *proto.Session, err error, p assertionParams) { @@ -106,7 +107,7 @@ func TestRPC_RegisterSession(t *testing.T) { }, }, "WithUnverifiedEmail": { - tokBuilderFn: func(b *jwt.Builder) { + tokBuilderFn: func(b *jwt.Builder, url string) { b.Claim("email", "user@example.com").Claim("email_verified", "false") }, assertFn: func(t *testing.T, sess *proto.Session, err error, p assertionParams) { @@ -131,6 +132,18 @@ func TestRPC_RegisterSession(t *testing.T) { assert.ErrorContains(t, err, "intent is invalid: no signatures") }, }, + "IssuerMissingScheme": { + tokBuilderFn: func(b *jwt.Builder, url string) { + b.Issuer(strings.TrimPrefix(url, "http://")) + }, + assertFn: func(t *testing.T, sess *proto.Session, err error, p assertionParams) { + require.NoError(t, err) + require.NotNil(t, sess) + + httpsIssuer := "https://" + strings.TrimPrefix(p.issuer, "http://") + assert.Equal(t, httpsIssuer, sess.Identity.Issuer) + }, + }, } for label, testCase := range testCases {