Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ATO-983: Set InternalCommonSubjectIdentifier In AuthSession (Auth ICSID 1/3) #5396

Merged
1 commit merged into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import uk.gov.di.authentication.frontendapi.lambda.VerifyMfaCodeHandler;
import uk.gov.di.authentication.shared.entity.Session;
import uk.gov.di.authentication.shared.entity.AuthSessionItem;
import uk.gov.di.authentication.shared.entity.UserProfile;
import uk.gov.di.authentication.shared.helpers.ClientSubjectHelper;
import uk.gov.di.authentication.shared.services.AuthSessionService;
import uk.gov.di.authentication.shared.services.AuthenticationService;
import uk.gov.di.authentication.shared.services.ConfigurationService;
import uk.gov.di.authentication.shared.services.SessionService;
Expand All @@ -16,16 +17,18 @@ public class SessionHelper {

public static void updateSessionWithSubject(
UserContext userContext,
AuthenticationService authenticationService,
ConfigurationService configurationService,
AuthSessionItem authSession,
SessionService sessionService,
Session session) {
AuthSessionService authSessionService,
AuthenticationService authenticationService,
ConfigurationService configurationService) {
LOG.info("Calculating internal common subject identifier");
var session = userContext.getSession();
UserProfile userProfile =
userContext.getUserProfile().isPresent()
? userContext.getUserProfile().get()
: authenticationService.getUserProfileByEmail(session.getEmailAddress());
var internalCommonSubjectIdentifier =
var internalCommonSubjectId =
session.getInternalCommonSubjectIdentifier() != null
? session.getInternalCommonSubjectIdentifier()
: ClientSubjectHelper.getSubjectWithSectorIdentifier(
Expand All @@ -34,9 +37,9 @@ public static void updateSessionWithSubject(
authenticationService)
.getValue();
LOG.info("Setting internal common subject identifier in user session");
sessionService.storeOrUpdateSession(
userContext
.getSession()
.setInternalCommonSubjectIdentifier(internalCommonSubjectIdentifier));
session.setInternalCommonSubjectIdentifier(internalCommonSubjectId);
sessionService.storeOrUpdateSession(session);
authSession.setInternalCommonSubjectId(internalCommonSubjectId);
authSessionService.updateSession(authSession);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import uk.gov.di.authentication.frontendapi.entity.CheckUserExistsResponse;
import uk.gov.di.authentication.frontendapi.entity.LockoutInformation;
import uk.gov.di.authentication.shared.domain.AuditableEvent;
import uk.gov.di.authentication.shared.entity.AuthSessionItem;
import uk.gov.di.authentication.shared.entity.ErrorResponse;
import uk.gov.di.authentication.shared.entity.JourneyType;
import uk.gov.di.authentication.shared.entity.MFAMethodType;
Expand All @@ -22,6 +23,7 @@
import uk.gov.di.authentication.shared.lambda.BaseFrontendHandler;
import uk.gov.di.authentication.shared.serialization.Json.JsonException;
import uk.gov.di.authentication.shared.services.AuditService;
import uk.gov.di.authentication.shared.services.AuthSessionService;
import uk.gov.di.authentication.shared.services.AuthenticationService;
import uk.gov.di.authentication.shared.services.ClientService;
import uk.gov.di.authentication.shared.services.ClientSessionService;
Expand Down Expand Up @@ -50,10 +52,12 @@ public class CheckUserExistsHandler extends BaseFrontendHandler<CheckUserExistsR

private final AuditService auditService;
private final CodeStorageService codeStorageService;
private final AuthSessionService authSessionService;

public CheckUserExistsHandler(
ConfigurationService configurationService,
SessionService sessionService,
AuthSessionService authSessionService,
ClientSessionService clientSessionService,
ClientService clientService,
AuthenticationService authenticationService,
Expand All @@ -68,6 +72,7 @@ public CheckUserExistsHandler(
authenticationService);
this.auditService = auditService;
this.codeStorageService = codeStorageService;
this.authSessionService = authSessionService;
}

public CheckUserExistsHandler() {
Expand All @@ -78,13 +83,15 @@ public CheckUserExistsHandler(ConfigurationService configurationService) {
super(CheckUserExistsRequest.class, configurationService);
this.auditService = new AuditService(configurationService);
this.codeStorageService = new CodeStorageService(configurationService);
this.authSessionService = new AuthSessionService(configurationService);
}

public CheckUserExistsHandler(
ConfigurationService configurationService, RedisConnectionService redis) {
super(CheckUserExistsRequest.class, configurationService, redis);
this.auditService = new AuditService(configurationService);
this.codeStorageService = new CodeStorageService(configurationService, redis);
this.authSessionService = new AuthSessionService(configurationService);
}

@Override
Expand Down Expand Up @@ -150,6 +157,17 @@ public APIGatewayProxyResponseEvent handleRequestWithUserContext(
AuditableEvent auditableEvent;
var rpPairwiseId = AuditService.UNKNOWN;
var userMfaDetail = new UserMfaDetail();
var session = userContext.getSession();

var optionalAuthSession =
authSessionService.getSessionFromRequestHeaders(input.getHeaders());

if (optionalAuthSession.isEmpty()) {
return generateApiGatewayProxyErrorResponse(400, ErrorResponse.ERROR_1000);
}

AuthSessionItem authSession = optionalAuthSession.get();

if (userExists) {
auditableEvent = FrontendAuditableEvent.AUTH_CHECK_USER_KNOWN_EMAIL;
rpPairwiseId =
Expand All @@ -159,16 +177,17 @@ public APIGatewayProxyResponseEvent handleRequestWithUserContext(
authenticationService,
configurationService.getInternalSectorUri())
.getValue();
var internalPairwiseId =
var internalCommonSubjectId =
ClientSubjectHelper.getSubjectWithSectorIdentifier(
userProfile.get(),
configurationService.getInternalSectorUri(),
authenticationService)
.getValue();

LOG.info("Setting internal common subject identifier in user session");
userContext.getSession().setInternalCommonSubjectIdentifier(internalPairwiseId);

session.setInternalCommonSubjectIdentifier(internalCommonSubjectId);
authSession.setInternalCommonSubjectId(internalCommonSubjectId);
var isPhoneNumberVerified = userProfile.get().isPhoneNumberVerified();
var userCredentials =
authenticationService.getUserCredentialsFromEmail(emailAddress);
Expand All @@ -178,9 +197,10 @@ public APIGatewayProxyResponseEvent handleRequestWithUserContext(
userCredentials,
userProfile.get().getPhoneNumber(),
isPhoneNumberVerified);
auditContext = auditContext.withSubjectId(internalPairwiseId);
auditContext = auditContext.withSubjectId(internalCommonSubjectId);
} else {
userContext.getSession().setInternalCommonSubjectIdentifier(null);
session.setInternalCommonSubjectIdentifier(null);
authSession.setInternalCommonSubjectId(null);
auditableEvent = FrontendAuditableEvent.AUTH_CHECK_USER_NO_ACCOUNT_WITH_EMAIL;
}

Expand Down Expand Up @@ -212,7 +232,8 @@ public APIGatewayProxyResponseEvent handleRequestWithUserContext(
userMfaDetail.getMfaMethodType(),
getLastDigitsOfPhoneNumber(userMfaDetail),
lockoutInformation);
sessionService.storeOrUpdateSession(userContext.getSession());
sessionService.storeOrUpdateSession(session);
authSessionService.updateSession(authSession);

LOG.info("Successfully processed request");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ public APIGatewayProxyResponseEvent handleRequestWithUserContext(
UserCredentials userCredentials = userContext.getUserCredentials().get();
auditContext = auditContext.withPhoneNumber(userProfile.getPhoneNumber());

var internalCommonSubjectIdentifier = getInternalCommonSubjectIdentifier(userProfile);
auditContext = auditContext.withUserId(internalCommonSubjectIdentifier);
var internalCommonSubjectId = getInternalCommonSubjectId(userProfile);
auditContext = auditContext.withUserId(internalCommonSubjectId);

if (isReauthJourneyWithFlagsEnabled(isReauthJourney)) {
var reauthCounts =
Expand Down Expand Up @@ -271,7 +271,7 @@ public APIGatewayProxyResponseEvent handleRequestWithUserContext(
return handleValidCredentials(
request,
userContext,
internalCommonSubjectIdentifier,
internalCommonSubjectId,
userCredentials,
userProfile,
auditContext,
Expand Down Expand Up @@ -308,7 +308,9 @@ private APIGatewayProxyResponseEvent handleValidCredentials(
.setInternalCommonSubjectIdentifier(internalCommonSubjectIdentifier));

authSessionService.updateSession(
authSessionItem.withAccountState(AuthSessionItem.AccountState.EXISTING));
authSessionItem
.withAccountState(AuthSessionItem.AccountState.EXISTING)
.withInternalCommonSubjectId(internalCommonSubjectIdentifier));

var userMfaDetail =
getUserMFADetail(
Expand Down Expand Up @@ -451,13 +453,13 @@ private int retrieveIncorrectPasswordCount(String email, boolean isReauthJourney
: codeStorageService.getIncorrectPasswordCount(email);
}

private String getInternalCommonSubjectIdentifier(UserProfile userProfile) {
var internalCommonSubjectIdentifier =
private String getInternalCommonSubjectId(UserProfile userProfile) {
var internalCommonSubjectId =
ClientSubjectHelper.getSubjectWithSectorIdentifier(
userProfile,
configurationService.getInternalSectorUri(),
authenticationService);
return internalCommonSubjectIdentifier.getValue();
return internalCommonSubjectId.getValue();
}

private boolean isTermsAndConditionsAccepted(UserContext userContext, UserProfile userProfile) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,12 @@ public APIGatewayProxyResponseEvent handleRequestWithUserContext(
LocalDateTime.now(ZoneId.of("UTC")).toString()));

LOG.info("Calculating internal common subject identifier");
var internalCommonSubjectIdentifier =
var internalCommonSubjectId =
ClientSubjectHelper.getSubjectWithSectorIdentifier(
user.getUserProfile(),
configurationService.getInternalSectorUri(),
authenticationService);
auditContext = auditContext.withSubjectId(internalCommonSubjectIdentifier.getValue());
auditContext = auditContext.withSubjectId(internalCommonSubjectId.getValue());

LOG.info("Calculating RP pairwise identifier");
var rpPairwiseId =
Expand Down Expand Up @@ -181,10 +181,12 @@ public APIGatewayProxyResponseEvent handleRequestWithUserContext(
.getSession()
.setEmailAddress(request.getEmail())
.setInternalCommonSubjectIdentifier(
internalCommonSubjectIdentifier.getValue()));
internalCommonSubjectId.getValue()));

authSessionService.updateSession(
authSessionItem.withAccountState(AuthSessionItem.AccountState.NEW));
authSessionItem
.withAccountState(AuthSessionItem.AccountState.NEW)
.withInternalCommonSubjectId(internalCommonSubjectId.getValue()));
LOG.info("Successfully processed request");
return generateApiGatewayProxyResponse(200, "");
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,11 @@ public APIGatewayProxyResponseEvent handleRequestWithUserContext(
if (codeRequestType.equals(CodeRequestType.PW_RESET_MFA_SMS)) {
SessionHelper.updateSessionWithSubject(
userContext,
authenticationService,
configurationService,
authSession.get(),
sessionService,
session);
authSessionService,
authenticationService,
configurationService);
}

processSuccessfulCodeRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,11 @@ private APIGatewayProxyResponseEvent verifyCode(
if (JourneyType.PASSWORD_RESET_MFA.equals(codeRequest.getJourneyType())) {
SessionHelper.updateSessionWithSubject(
userContext,
authenticationService,
configurationService,
authSession,
sessionService,
session);
authSessionService,
authenticationService,
configurationService);
}

var errorResponseMaybe = mfaCodeProcessor.validateCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import uk.gov.di.authentication.frontendapi.domain.FrontendAuditableEvent;
import uk.gov.di.authentication.frontendapi.entity.CheckUserExistsResponse;
import uk.gov.di.authentication.frontendapi.helpers.CommonTestVariables;
import uk.gov.di.authentication.shared.entity.AuthSessionItem;
import uk.gov.di.authentication.shared.entity.ClientRegistry;
import uk.gov.di.authentication.shared.entity.ClientSession;
import uk.gov.di.authentication.shared.entity.ErrorResponse;
Expand All @@ -34,6 +35,7 @@
import uk.gov.di.authentication.shared.helpers.NowHelper;
import uk.gov.di.authentication.shared.serialization.Json;
import uk.gov.di.authentication.shared.services.AuditService;
import uk.gov.di.authentication.shared.services.AuthSessionService;
import uk.gov.di.authentication.shared.services.AuthenticationService;
import uk.gov.di.authentication.shared.services.ClientService;
import uk.gov.di.authentication.shared.services.ClientSessionService;
Expand Down Expand Up @@ -87,13 +89,15 @@ class CheckUserExistsHandlerTest {
private final AuthenticationService authenticationService = mock(AuthenticationService.class);
private final AuditService auditService = mock(AuditService.class);
private final SessionService sessionService = mock(SessionService.class);
private final AuthSessionService authSessionService = mock(AuthSessionService.class);
private final ConfigurationService configurationService = mock(ConfigurationService.class);
private final ClientSessionService clientSessionService = mock(ClientSessionService.class);
private final ClientService clientService = mock(ClientService.class);
private final CodeStorageService codeStorageService = mock(CodeStorageService.class);
private CheckUserExistsHandler handler;
private static final Json objectMapper = SerializationService.getInstance();
private final Session session = mock(Session.class);
private final AuthSessionItem authSession = mock(AuthSessionItem.class);
private static final String CLIENT_ID = "test-client-id";
private static final String CLIENT_NAME = "test-client-name";
private static final Subject SUBJECT = new Subject();
Expand Down Expand Up @@ -135,6 +139,7 @@ void setup() {
new CheckUserExistsHandler(
configurationService,
sessionService,
authSessionService,
clientSessionService,
clientService,
authenticationService,
Expand All @@ -148,6 +153,7 @@ class WhenUserExists {
@BeforeEach
void setup() {
usingValidSession();
authSessionExists();
var userProfile =
generateUserProfile().withPhoneNumber(CommonTestVariables.UK_MOBILE_NUMBER);
setupUserProfileAndClient(Optional.of(userProfile));
Expand Down Expand Up @@ -286,6 +292,7 @@ void shouldReturn400AndSaveEmailInUserSessionIfUserAccountIsLocked() {
@Test
void shouldReturn200IfUserDoesNotExist() throws Json.JsonException {
usingValidSession();
authSessionExists();

setupUserProfileAndClient(Optional.empty());

Expand Down Expand Up @@ -342,11 +349,32 @@ void shouldReturn400IfEmailAddressIsInvalid() {
AUDIT_CONTEXT.withEmail("joe.bloggs"));
}

@Test
void shouldReturn400IfAuthSessionExpired() {
usingValidSession();
authSessionMissing();
setupClient();

var result = handler.handleRequest(userExistsRequest(EMAIL_ADDRESS), context);

assertThat(result, hasStatus(400));
assertThat(result, hasJsonBody(ErrorResponse.ERROR_1000));
}

private void usingValidSession() {
when(sessionService.getSessionFromRequestHeaders(anyMap()))
.thenReturn(Optional.of(session));
}

private void authSessionExists() {
when(authSessionService.getSessionFromRequestHeaders(any()))
.thenReturn(Optional.of(authSession));
}

private void authSessionMissing() {
when(authSessionService.getSessionFromRequestHeaders(any())).thenReturn(Optional.empty());
}

private UserProfile generateUserProfile() {
return new UserProfile()
.withEmail(EMAIL_ADDRESS)
Expand Down
Loading
Loading