diff --git a/teraserver/python/modules/DatabaseModule/DBManager.py b/teraserver/python/modules/DatabaseModule/DBManager.py index 6e1c612a..1bb9e14a 100755 --- a/teraserver/python/modules/DatabaseModule/DBManager.py +++ b/teraserver/python/modules/DatabaseModule/DBManager.py @@ -1,6 +1,6 @@ from flask_sqlalchemy import SQLAlchemy from sqlalchemy.orm import joinedload -from sqlalchemy import event, inspect, update +from sqlalchemy import event, inspect, update, select from sqlalchemy.engine import Engine from sqlalchemy.engine.reflection import Inspector from sqlite3 import Connection as SQLite3Connection @@ -104,85 +104,89 @@ def setup_events_for_2fa_sites(self): def site_updated_or_inserted(mapper, connection, target: TeraSite): # Check if 2FA is enabled for this site if target and target.site_2fa_required: - # Efficiently load all related users with joinedload - service_roles = TeraServiceRole.query.options( - joinedload(TeraServiceRole.service_role_user_groups).joinedload( - TeraUserGroup.user_group_users - ) - ).filter(TeraServiceRole.id_site == target.id_site).all() - - # Get all users - user_ids = set() - for role in service_roles: - if role.service_role_user_groups: - for group in role.service_role_user_groups: - for user in group.user_group_users: - user_ids.add(user.id_user) - - # Perform a bulk update for all users at once - if user_ids: + # Get all users that have access to this site + users = TeraServiceAccess.query.join(TeraServiceRole, TeraServiceAccess.id_service_role == TeraServiceRole.id_service_role) \ + .join(TeraUserUserGroup, TeraServiceAccess.id_user_group == TeraUserUserGroup.id_user_group) \ + .join(TeraUser, TeraUserUserGroup.id_user == TeraUser.id_user) \ + .join(TeraSite, TeraServiceRole.id_site == TeraSite.id_site) \ + .filter(TeraSite.id_site == target.id_site) \ + .with_entities(TeraUser).all() # Return the user information only + + # Enable 2FA for all users found + for user in users: connection.execute( update(TeraUser) - .where(TeraUser.id_user.in_(user_ids)) + .where(TeraUser.id_user == user.id_user) .values(user_2fa_enabled=True) ) + @event.listens_for(TeraUserGroup, 'after_update') @event.listens_for(TeraUserGroup, 'after_insert') def user_group_updated_or_inserted(mapper, connection, target: TeraUserGroup): - # Check if 2FA is enabled for a related site - if target and target.user_group_services_roles: - for role in target.user_group_services_roles: - if role.id_site and role.service_role_site.site_2fa_required: - # Efficiently load all related users with joinedload - user_ids = set() - for user in target.user_group_users: - user_ids.add(user.id_user) - - # Perform a bulk update for all users at once - if user_ids: - connection.execute( - update(TeraUser) - .where(TeraUser.id_user.in_(user_ids)) - .values(user_2fa_enabled=True) - ) + + # Check if 2FA is enabled for a related site in a single sql query + if target: + # Get users from the group that have access to a site with 2FA enabled + users = TeraUser.query.join(TeraUserUserGroup, TeraUser.id_user == TeraUserUserGroup.id_user) \ + .join(TeraServiceAccess, TeraUserUserGroup.id_user_group == TeraServiceAccess.id_user_group) \ + .join(TeraServiceRole, TeraServiceAccess.id_service_role == TeraServiceRole.id_service_role) \ + .join(TeraSite, TeraServiceRole.id_site == TeraSite.id_site) \ + .filter(TeraUserUserGroup.id_user_group == target.id_user_group) \ + .filter(TeraSite.site_2fa_required == True) \ + .with_entities(TeraUser).all() # Return the user information only + + # Enable 2FA for all users found + for user in users: + connection.execute( + update(TeraUser) + .where(TeraUser.id_user == user.id_user) + .values(user_2fa_enabled=True) + ) @event.listens_for(TeraUserUserGroup, 'after_update') @event.listens_for(TeraUserUserGroup, 'after_insert') def user_user_group_updated_or_inserted(mapper, connection, target: TeraUserUserGroup): - # Check if 2FA is enabled for a related site - if target and target.user_user_group_user_group and target.user_user_group_user_group.user_group_services_roles: - for role in target.user_user_group_user_group.user_group_services_roles: - if role.id_site and role.service_role_site.site_2fa_required: + # If the user in the usergroup has access to a site with 2FA enabled, enable 2FA for the user + if target: + sites = TeraServiceAccess.query.join(TeraServiceRole, TeraServiceAccess.id_service_role == + TeraServiceRole.id_service_role) \ + .join(TeraSite, TeraServiceRole.id_site == TeraSite.id_site) \ + .filter(TeraServiceAccess.id_user_group == target.id_user_group) \ + .with_entities(TeraSite).all() # Return the site information only + + for site in sites: + if site.site_2fa_required: # Perform single update for user connection.execute( update(TeraUser) - .where(TeraUser.id_user == target.user_user_group_user.id_user) + .where(TeraUser.id_user == target.id_user) .values(user_2fa_enabled=True) - ) + ) + break @event.listens_for(TeraUser, 'after_update') @event.listens_for(TeraUser, 'after_insert') def user_updated_or_inserted(mapper, connection, target: TeraUser): - # Check if 2FA is enabled for a related site - if target and target.user_user_groups: - for group in target.user_user_groups: - if group.user_group_services_roles: - for role in group.user_group_services_roles: - if role.id_site and role.service_role_site.site_2fa_required: - - otp_enabled = target.user_2fa_otp_enabled - - # Do not allow to change 2FA status if user has 2FA enabled - # and OTP set with secret - if target.user_2fa_otp_secret: - otp_enabled = True - - # Perform single update for user - connection.execute( - update(TeraUser) - .where(TeraUser.id_user == target.id_user) - .values(user_2fa_enabled=True, user_2fa_otp_enabled=otp_enabled) - ) + # Check if 2FA is enabled for a related site through user groups + if target: + sites = TeraServiceAccess.query.join(TeraUserUserGroup, TeraServiceAccess.id_user_group == TeraUserUserGroup.id_user_group) \ + .join(TeraServiceRole, TeraServiceAccess.id_service_role == TeraServiceRole.id_service_role) \ + .join(TeraSite, TeraServiceRole.id_site == TeraSite.id_site) \ + .filter(TeraUserUserGroup.id_user == target.id_user) \ + .with_entities(TeraSite).all() # Return the site information only + + + for site in sites: + if site.site_2fa_required: + # Perform single update for user + connection.execute( + update(TeraUser) + .where(TeraUser.id_user == target.id_user) + .values(user_2fa_enabled=True) + ) + break + + def setup_events_for_class(self, cls, event_name): import json diff --git a/teraserver/python/tests/opentera/db/models/test_TeraSite.py b/teraserver/python/tests/opentera/db/models/test_TeraSite.py index d5921588..d5c5857e 100644 --- a/teraserver/python/tests/opentera/db/models/test_TeraSite.py +++ b/teraserver/python/tests/opentera/db/models/test_TeraSite.py @@ -10,7 +10,10 @@ from opentera.db.models.TeraSessionTypeSite import TeraSessionTypeSite from opentera.db.models.TeraTestTypeSite import TeraTestTypeSite from opentera.db.models.TeraDevice import TeraDevice - +from opentera.db.models.TeraUser import TeraUser +from opentera.db.models.TeraUserGroup import TeraUserGroup +from opentera.db.models.TeraUserUserGroup import TeraUserUserGroup +from opentera.db.models.TeraServiceAccess import TeraServiceAccess class TeraSiteTest(BaseModelsTest): @@ -233,9 +236,87 @@ def test_undelete(self): self.assertIsNotNone(TeraSessionTypeSite.get_session_type_site_by_id(id_session_type)) self.assertIsNotNone(TeraTestTypeSite.get_test_type_site_by_id(id_test_type)) + def test_2fa_required_site(self): + with self._flask_app.app_context(): + site = TeraSiteTest.new_test_site(name='2FA Site', site_2fa_required=True) + self.assertTrue(site.site_2fa_required) + self.db.session.add(site) + self.db.session.commit() + id_site = site.id_site + self.db.session.rollback() + same_site = TeraSite.get_site_by_id(id_site) + self.assertTrue(same_site.site_2fa_required) + + def test_enable_2fa_in_site_should_enable_in_users(self): + with self._flask_app.app_context(): + site = TeraSiteTest.new_test_site(name='2FA Site', site_2fa_required=False) + self.assertIsNotNone(site) + group = TeraSiteTest.new_test_user_group('Test Group', site.id_site) + self.assertIsNotNone(group) + user = TeraSiteTest.new_test_user('test_user', 'password', group.id_user_group) + self.assertIsNotNone(user) + + # Enable 2fa in site + site.site_2fa_required = True + self.db.session.add(site) + self.db.session.commit() + + # User should be updated automatically with 2fa + self.assertTrue(user.user_2fa_enabled) + + + + + + @staticmethod - def new_test_site(name: str = 'Test Site') -> TeraSite: + def new_test_site(name: str = 'Test Site', site_2fa_required: bool = False) -> TeraSite: site = TeraSite() site.site_name = name + site.site_2fa_required = site_2fa_required TeraSite.insert(site) return site + + @staticmethod + def new_test_user_group(name: str, id_site: int ) -> TeraUserGroup: + + # Create Service Role first + service_role = TeraServiceRole() + service_role.service_role_name = 'Test Site Role' + service_role.id_service = 1 # TeraServer by default + service_role.id_site = id_site + TeraServiceRole.insert(service_role) + + # Create User Group + group: TeraUserGroup = TeraUserGroup() + group.user_group_name = name + TeraUserGroup.insert(group) + + # Update Service Access + service_access = TeraServiceAccess() + service_access.id_service_role = service_role.id_service_role + service_access.id_user_group = group.id_user_group + TeraServiceAccess.insert(service_access) + + return group + + + @staticmethod + def new_test_user(username: str, password: str, id_user_group: int) -> TeraUser: + user = TeraUser() + user.user_username = username + user.user_password = password + user.user_firstname = username + user.user_lastname = username + user.user_email = f"{username}@test.com" + user.user_enabled = True + user.user_profile = {} + TeraUser.insert(user) + + # Update user group + user_user_group = TeraUserUserGroup() + user_user_group.id_user = user.id_user + user_user_group.id_user_group = id_user_group + TeraUserUserGroup.insert(user_user_group) + + return user