Skip to content

Commit

Permalink
Merge pull request #1162 from alphagov/govsi-1161-fix-vtm-claim
Browse files Browse the repository at this point in the history
GOVSI-1161 - Fix empty vtm claim
  • Loading branch information
JHjava authored Dec 6, 2021
2 parents 959531b + 33fbcba commit c5eb5c0
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ public OIDCTokenResponse generateTokenResponse(
AccessToken accessToken =
generateAndStoreAccessToken(
clientID, internalSubject, scopesForToken, publicSubject);
AccessTokenHash accessTokenHash = AccessTokenHash.compute(accessToken, TOKEN_ALGORITHM);
AccessTokenHash accessTokenHash =
AccessTokenHash.compute(accessToken, TOKEN_ALGORITHM, null);
SignedJWT idToken =
generateIDToken(
clientID, publicSubject, additionalTokenClaims, accessTokenHash, vot);
Expand Down Expand Up @@ -256,7 +257,7 @@ private SignedJWT generateIDToken(
idTokenClaims.setAccessTokenHash(accessTokenHash);
idTokenClaims.putAll(additionalTokenClaims);
idTokenClaims.setClaim("vot", vot);
idTokenClaims.setClaim("vtm", trustMarkUri);
idTokenClaims.setClaim("vtm", trustMarkUri.toString());
try {
return generateSignedJWT(idTokenClaims.toJWTClaimsSet());
} catch (com.nimbusds.oauth2.sdk.ParseException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.nimbusds.openid.connect.sdk.Nonce;
import com.nimbusds.openid.connect.sdk.OIDCScopeValue;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponse;
import com.nimbusds.openid.connect.sdk.claims.AccessTokenHash;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import uk.gov.di.authentication.shared.entity.AccessTokenStore;
Expand Down Expand Up @@ -56,14 +57,14 @@

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static uk.gov.di.authentication.shared.helpers.ConstructUriHelper.buildURI;

public class TokenServiceTest {

Expand All @@ -84,10 +85,12 @@ public class TokenServiceTest {
OIDCScopeValue.EMAIL,
OIDCScopeValue.PHONE,
OIDCScopeValue.OFFLINE_ACCESS);
private Nonce nonce;
private static final String CLIENT_ID = "client-id";
private static final String AUTH_CODE = new AuthorizationCode().toString();
private static final String REDIRECT_URI = "http://localhost/redirect";
private static final String BASE_URL = "http://example.com";
private static final String TOKEN_URI = "http://localhost/token";
private static final String BASE_URL = "https://example.com";
private static final String KEY_ID = "14342354354353";
private static final String REFRESH_TOKEN_PREFIX = "REFRESH_TOKEN:";
private static final String ACCESS_TOKEN_PREFIX = "ACCESS_TOKEN:";
Expand All @@ -99,12 +102,12 @@ public void setUp() {
when(configurationService.getAccessTokenExpiry()).thenReturn(300L);
when(configurationService.getIDTokenExpiry()).thenReturn(120L);
when(configurationService.getSessionExpiry()).thenReturn(300L);
nonce = new Nonce();
}

@Test
public void shouldGenerateTokenResponseWithRefreshToken()
throws ParseException, JOSEException, JsonProcessingException {
Nonce nonce = new Nonce();
when(configurationService.getTokenSigningKeyAlias()).thenReturn(KEY_ID);
createSignedIdToken();
createSignedAccessToken();
Expand All @@ -128,22 +131,9 @@ public void shouldGenerateTokenResponseWithRefreshToken()
LocalDateTime.now(ZoneId.of("UTC")).toString())),
false);

assertEquals(
BASE_URL, tokenResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getIssuer());
assertEquals(
PUBLIC_SUBJECT.getValue(),
tokenResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getClaim("sub"));
assertSuccessfullTokenResponse(tokenResponse);

assertNotNull(tokenResponse.getOIDCTokens().getRefreshToken());
String accessTokenKey = ACCESS_TOKEN_PREFIX + CLIENT_ID + "." + PUBLIC_SUBJECT;
AccessTokenStore accessTokenStore =
new AccessTokenStore(
tokenResponse.getOIDCTokens().getAccessToken().getValue(),
INTERNAL_SUBJECT.getValue());
verify(redisConnectionService)
.saveWithExpiry(
accessTokenKey,
new ObjectMapper().writeValueAsString(accessTokenStore),
300L);
String refreshTokenKey = REFRESH_TOKEN_PREFIX + CLIENT_ID + "." + PUBLIC_SUBJECT;
RefreshTokenStore refreshTokenStore =
new RefreshTokenStore(
Expand All @@ -154,15 +144,11 @@ public void shouldGenerateTokenResponseWithRefreshToken()
refreshTokenKey,
new ObjectMapper().writeValueAsString(refreshTokenStore),
300L);
assertEquals(
nonce.getValue(),
tokenResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getClaim("nonce"));
}

@Test
public void shouldGenerateTokenResponseWithoutRefreshTokenWhenOfflineAccessScopeIsMissing()
throws ParseException, JOSEException, JsonProcessingException {
Nonce nonce = new Nonce();
when(configurationService.getTokenSigningKeyAlias()).thenReturn(KEY_ID);
when(configurationService.getAccessTokenExpiry()).thenReturn(300L);
createSignedIdToken();
Expand All @@ -186,70 +172,50 @@ public void shouldGenerateTokenResponseWithoutRefreshTokenWhenOfflineAccessScope
LocalDateTime.now(ZoneId.of("UTC")).toString())),
false);

assertEquals(
BASE_URL, tokenResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getIssuer());
assertEquals(
PUBLIC_SUBJECT.getValue(),
tokenResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getClaim("sub"));
assertSuccessfullTokenResponse(tokenResponse);

assertNull(tokenResponse.getOIDCTokens().getRefreshToken());
String accessTokenKey = ACCESS_TOKEN_PREFIX + CLIENT_ID + "." + PUBLIC_SUBJECT;
AccessTokenStore accessTokenStore =
new AccessTokenStore(
tokenResponse.getOIDCTokens().getAccessToken().getValue(),
INTERNAL_SUBJECT.getValue());
verify(redisConnectionService)
.saveWithExpiry(
accessTokenKey,
new ObjectMapper().writeValueAsString(accessTokenStore),
300L);
assertEquals(
nonce.getValue(),
tokenResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getClaim("nonce"));
}

@Test
public void shouldSuccessfullyValidatePrivateKeyJWT() throws JOSEException, ParseException {
public void shouldSuccessfullyValidatePrivateKeyJWT() throws JOSEException {
KeyPair keyPair = generateRsaKeyPair();
String publicKey = Base64.getMimeEncoder().encodeToString(keyPair.getPublic().getEncoded());
LocalDateTime localDateTime = LocalDateTime.now().plusMinutes(5);
Date expiryDate = Date.from(localDateTime.atZone(ZoneId.of("UTC")).toInstant());
String requestParams = generateSerialisedPrivateKeyJWT(keyPair, expiryDate.getTime());
assertThat(
tokenService.validatePrivateKeyJWT(
requestParams, publicKey, "http://localhost/token", CLIENT_ID),
tokenService.validatePrivateKeyJWT(requestParams, publicKey, TOKEN_URI, CLIENT_ID),
equalTo(Optional.empty()));
}

@Test
public void shouldFailToValidatePrivateKeyJWTIfExpired() throws JOSEException, ParseException {
public void shouldFailToValidatePrivateKeyJWTIfExpired() throws JOSEException {
KeyPair keyPair = generateRsaKeyPair();
String publicKey = Base64.getMimeEncoder().encodeToString(keyPair.getPublic().getEncoded());
LocalDateTime localDateTime = LocalDateTime.now().minusMinutes(2);
Date expiryDate = Date.from(localDateTime.atZone(ZoneId.of("UTC")).toInstant());
String requestParams = generateSerialisedPrivateKeyJWT(keyPair, expiryDate.getTime());
assertThat(
tokenService.validatePrivateKeyJWT(
requestParams, publicKey, "http://localhost/token", CLIENT_ID),
tokenService.validatePrivateKeyJWT(requestParams, publicKey, TOKEN_URI, CLIENT_ID),
equalTo(Optional.of(OAuth2Error.INVALID_GRANT)));
}

@Test
public void shouldFailToValidatePrivateKeyJWTIfInvalidClientId()
throws JOSEException, ParseException {
public void shouldFailToValidatePrivateKeyJWTIfInvalidClientId() throws JOSEException {
KeyPair keyPair = generateRsaKeyPair();
String publicKey = Base64.getMimeEncoder().encodeToString(keyPair.getPublic().getEncoded());
LocalDateTime localDateTime = LocalDateTime.now().plusMinutes(5);
Date expiryDate = Date.from(localDateTime.atZone(ZoneId.of("UTC")).toInstant());
String requestParams = generateSerialisedPrivateKeyJWT(keyPair, expiryDate.getTime());
assertThat(
tokenService.validatePrivateKeyJWT(
requestParams, publicKey, "http://localhost/token", "wrong-client-id"),
requestParams, publicKey, TOKEN_URI, "wrong-client-id"),
equalTo(Optional.of(OAuth2Error.INVALID_CLIENT)));
}

@Test
public void shouldReturnErrorIfUnableToValidatePrivateKeyJWTSignature()
throws JOSEException, ParseException {
public void shouldReturnErrorIfUnableToValidatePrivateKeyJWTSignature() throws JOSEException {
KeyPair keyPair = generateRsaKeyPair();
KeyPair keyPairTwo = generateRsaKeyPair();
String publicKey =
Expand All @@ -258,8 +224,7 @@ public void shouldReturnErrorIfUnableToValidatePrivateKeyJWTSignature()
Date expiryDate = Date.from(localDateTime.atZone(ZoneId.of("UTC")).toInstant());
String requestParams = generateSerialisedPrivateKeyJWT(keyPair, expiryDate.getTime());
assertThat(
tokenService.validatePrivateKeyJWT(
requestParams, publicKey, "http://localhost/token", CLIENT_ID),
tokenService.validatePrivateKeyJWT(requestParams, publicKey, TOKEN_URI, CLIENT_ID),
equalTo(Optional.of(OAuth2Error.INVALID_CLIENT)));
}

Expand Down Expand Up @@ -401,8 +366,7 @@ public void shouldReturnErrorWhenValidatingRefreshTokenRequestWithWrongGrant() {
private String generateSerialisedPrivateKeyJWT(KeyPair keyPair, long expiryTime)
throws JOSEException {
JWTAuthenticationClaimsSet claimsSet =
new JWTAuthenticationClaimsSet(
new ClientID(CLIENT_ID), new Audience("http://localhost/token"));
new JWTAuthenticationClaimsSet(new ClientID(CLIENT_ID), new Audience(TOKEN_URI));
claimsSet.getExpirationTime().setTime(expiryTime);
PrivateKeyJWT privateKeyJWT =
new PrivateKeyJWT(
Expand Down Expand Up @@ -473,4 +437,42 @@ private KeyPair generateRsaKeyPair() {
kpg.initialize(2048);
return kpg.generateKeyPair();
}

private void assertSuccessfullTokenResponse(OIDCTokenResponse tokenResponse)
throws ParseException, JsonProcessingException {
String accessTokenKey = ACCESS_TOKEN_PREFIX + CLIENT_ID + "." + PUBLIC_SUBJECT;
assertNotNull(tokenResponse.getOIDCTokens().getAccessToken());
AccessTokenStore accessTokenStore =
new AccessTokenStore(
tokenResponse.getOIDCTokens().getAccessToken().getValue(),
INTERNAL_SUBJECT.getValue());
verify(redisConnectionService)
.saveWithExpiry(
accessTokenKey,
new ObjectMapper().writeValueAsString(accessTokenStore),
300L);
assertThat(
tokenResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getClaims().size(),
equalTo(9));
assertThat(
tokenResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getClaim("sub"),
equalTo(PUBLIC_SUBJECT.getValue()));
assertThat(
tokenResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getClaim("nonce"),
equalTo(nonce.getValue()));
assertThat(
tokenResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getClaim("vtm"),
equalTo(buildURI(BASE_URL, "/trustmark").toString()));
assertThat(
tokenResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getIssuer(),
equalTo(BASE_URL));
assertThat(
tokenResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getClaim("at_hash"),
equalTo(
AccessTokenHash.compute(
tokenResponse.getOIDCTokens().getAccessToken(),
JWSAlgorithm.ES256,
null)
.toString()));
}
}

0 comments on commit c5eb5c0

Please sign in to comment.