Skip to content

Commit

Permalink
Merge branch '2fa-dev' of https://github.com/introlab/opentera into 2…
Browse files Browse the repository at this point in the history
…fa-dev
  • Loading branch information
SBriere committed Oct 7, 2024
2 parents f8c9a3f + b7a8f8d commit 3ec3a63
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 63 deletions.
126 changes: 65 additions & 61 deletions teraserver/python/modules/DatabaseModule/DBManager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.orm import joinedload
from sqlalchemy import event, inspect, update
from sqlalchemy import event, inspect, update, select
from sqlalchemy.engine import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlite3 import Connection as SQLite3Connection
Expand Down Expand Up @@ -104,85 +104,89 @@ def setup_events_for_2fa_sites(self):
def site_updated_or_inserted(mapper, connection, target: TeraSite):
# Check if 2FA is enabled for this site
if target and target.site_2fa_required:
# Efficiently load all related users with joinedload
service_roles = TeraServiceRole.query.options(
joinedload(TeraServiceRole.service_role_user_groups).joinedload(
TeraUserGroup.user_group_users
)
).filter(TeraServiceRole.id_site == target.id_site).all()

# Get all users
user_ids = set()
for role in service_roles:
if role.service_role_user_groups:
for group in role.service_role_user_groups:
for user in group.user_group_users:
user_ids.add(user.id_user)

# Perform a bulk update for all users at once
if user_ids:
# Get all users that have access to this site
users = TeraServiceAccess.query.join(TeraServiceRole, TeraServiceAccess.id_service_role == TeraServiceRole.id_service_role) \
.join(TeraUserUserGroup, TeraServiceAccess.id_user_group == TeraUserUserGroup.id_user_group) \
.join(TeraUser, TeraUserUserGroup.id_user == TeraUser.id_user) \
.join(TeraSite, TeraServiceRole.id_site == TeraSite.id_site) \
.filter(TeraSite.id_site == target.id_site) \
.with_entities(TeraUser).all() # Return the user information only

# Enable 2FA for all users found
for user in users:
connection.execute(
update(TeraUser)
.where(TeraUser.id_user.in_(user_ids))
.where(TeraUser.id_user == user.id_user)
.values(user_2fa_enabled=True)
)

@event.listens_for(TeraUserGroup, 'after_update')
@event.listens_for(TeraUserGroup, 'after_insert')
def user_group_updated_or_inserted(mapper, connection, target: TeraUserGroup):
# Check if 2FA is enabled for a related site
if target and target.user_group_services_roles:
for role in target.user_group_services_roles:
if role.id_site and role.service_role_site.site_2fa_required:
# Efficiently load all related users with joinedload
user_ids = set()
for user in target.user_group_users:
user_ids.add(user.id_user)

# Perform a bulk update for all users at once
if user_ids:
connection.execute(
update(TeraUser)
.where(TeraUser.id_user.in_(user_ids))
.values(user_2fa_enabled=True)
)

# Check if 2FA is enabled for a related site in a single sql query
if target:
# Get users from the group that have access to a site with 2FA enabled
users = TeraUser.query.join(TeraUserUserGroup, TeraUser.id_user == TeraUserUserGroup.id_user) \
.join(TeraServiceAccess, TeraUserUserGroup.id_user_group == TeraServiceAccess.id_user_group) \
.join(TeraServiceRole, TeraServiceAccess.id_service_role == TeraServiceRole.id_service_role) \
.join(TeraSite, TeraServiceRole.id_site == TeraSite.id_site) \
.filter(TeraUserUserGroup.id_user_group == target.id_user_group) \
.filter(TeraSite.site_2fa_required == True) \
.with_entities(TeraUser).all() # Return the user information only

# Enable 2FA for all users found
for user in users:
connection.execute(
update(TeraUser)
.where(TeraUser.id_user == user.id_user)
.values(user_2fa_enabled=True)
)

@event.listens_for(TeraUserUserGroup, 'after_update')
@event.listens_for(TeraUserUserGroup, 'after_insert')
def user_user_group_updated_or_inserted(mapper, connection, target: TeraUserUserGroup):
# Check if 2FA is enabled for a related site
if target and target.user_user_group_user_group and target.user_user_group_user_group.user_group_services_roles:
for role in target.user_user_group_user_group.user_group_services_roles:
if role.id_site and role.service_role_site.site_2fa_required:
# If the user in the usergroup has access to a site with 2FA enabled, enable 2FA for the user
if target:
sites = TeraServiceAccess.query.join(TeraServiceRole, TeraServiceAccess.id_service_role ==
TeraServiceRole.id_service_role) \
.join(TeraSite, TeraServiceRole.id_site == TeraSite.id_site) \
.filter(TeraServiceAccess.id_user_group == target.id_user_group) \
.with_entities(TeraSite).all() # Return the site information only

for site in sites:
if site.site_2fa_required:
# Perform single update for user
connection.execute(
update(TeraUser)
.where(TeraUser.id_user == target.user_user_group_user.id_user)
.where(TeraUser.id_user == target.id_user)
.values(user_2fa_enabled=True)
)
)
break

@event.listens_for(TeraUser, 'after_update')
@event.listens_for(TeraUser, 'after_insert')
def user_updated_or_inserted(mapper, connection, target: TeraUser):
# Check if 2FA is enabled for a related site
if target and target.user_user_groups:
for group in target.user_user_groups:
if group.user_group_services_roles:
for role in group.user_group_services_roles:
if role.id_site and role.service_role_site.site_2fa_required:

otp_enabled = target.user_2fa_otp_enabled

# Do not allow to change 2FA status if user has 2FA enabled
# and OTP set with secret
if target.user_2fa_otp_secret:
otp_enabled = True

# Perform single update for user
connection.execute(
update(TeraUser)
.where(TeraUser.id_user == target.id_user)
.values(user_2fa_enabled=True, user_2fa_otp_enabled=otp_enabled)
)
# Check if 2FA is enabled for a related site through user groups
if target:
sites = TeraServiceAccess.query.join(TeraUserUserGroup, TeraServiceAccess.id_user_group == TeraUserUserGroup.id_user_group) \
.join(TeraServiceRole, TeraServiceAccess.id_service_role == TeraServiceRole.id_service_role) \
.join(TeraSite, TeraServiceRole.id_site == TeraSite.id_site) \
.filter(TeraUserUserGroup.id_user == target.id_user) \
.with_entities(TeraSite).all() # Return the site information only


for site in sites:
if site.site_2fa_required:
# Perform single update for user
connection.execute(
update(TeraUser)
.where(TeraUser.id_user == target.id_user)
.values(user_2fa_enabled=True)
)
break



def setup_events_for_class(self, cls, event_name):
import json
Expand Down
85 changes: 83 additions & 2 deletions teraserver/python/tests/opentera/db/models/test_TeraSite.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from opentera.db.models.TeraSessionTypeSite import TeraSessionTypeSite
from opentera.db.models.TeraTestTypeSite import TeraTestTypeSite
from opentera.db.models.TeraDevice import TeraDevice

from opentera.db.models.TeraUser import TeraUser
from opentera.db.models.TeraUserGroup import TeraUserGroup
from opentera.db.models.TeraUserUserGroup import TeraUserUserGroup
from opentera.db.models.TeraServiceAccess import TeraServiceAccess

class TeraSiteTest(BaseModelsTest):

Expand Down Expand Up @@ -233,9 +236,87 @@ def test_undelete(self):
self.assertIsNotNone(TeraSessionTypeSite.get_session_type_site_by_id(id_session_type))
self.assertIsNotNone(TeraTestTypeSite.get_test_type_site_by_id(id_test_type))

def test_2fa_required_site(self):
with self._flask_app.app_context():
site = TeraSiteTest.new_test_site(name='2FA Site', site_2fa_required=True)
self.assertTrue(site.site_2fa_required)
self.db.session.add(site)
self.db.session.commit()
id_site = site.id_site
self.db.session.rollback()
same_site = TeraSite.get_site_by_id(id_site)
self.assertTrue(same_site.site_2fa_required)

def test_enable_2fa_in_site_should_enable_in_users(self):
with self._flask_app.app_context():
site = TeraSiteTest.new_test_site(name='2FA Site', site_2fa_required=False)
self.assertIsNotNone(site)
group = TeraSiteTest.new_test_user_group('Test Group', site.id_site)
self.assertIsNotNone(group)
user = TeraSiteTest.new_test_user('test_user', 'password', group.id_user_group)
self.assertIsNotNone(user)

# Enable 2fa in site
site.site_2fa_required = True
self.db.session.add(site)
self.db.session.commit()

# User should be updated automatically with 2fa
self.assertTrue(user.user_2fa_enabled)






@staticmethod
def new_test_site(name: str = 'Test Site') -> TeraSite:
def new_test_site(name: str = 'Test Site', site_2fa_required: bool = False) -> TeraSite:
site = TeraSite()
site.site_name = name
site.site_2fa_required = site_2fa_required
TeraSite.insert(site)
return site

@staticmethod
def new_test_user_group(name: str, id_site: int ) -> TeraUserGroup:

# Create Service Role first
service_role = TeraServiceRole()
service_role.service_role_name = 'Test Site Role'
service_role.id_service = 1 # TeraServer by default
service_role.id_site = id_site
TeraServiceRole.insert(service_role)

# Create User Group
group: TeraUserGroup = TeraUserGroup()
group.user_group_name = name
TeraUserGroup.insert(group)

# Update Service Access
service_access = TeraServiceAccess()
service_access.id_service_role = service_role.id_service_role
service_access.id_user_group = group.id_user_group
TeraServiceAccess.insert(service_access)

return group


@staticmethod
def new_test_user(username: str, password: str, id_user_group: int) -> TeraUser:
user = TeraUser()
user.user_username = username
user.user_password = password
user.user_firstname = username
user.user_lastname = username
user.user_email = f"{username}@test.com"
user.user_enabled = True
user.user_profile = {}
TeraUser.insert(user)

# Update user group
user_user_group = TeraUserUserGroup()
user_user_group.id_user = user.id_user
user_user_group.id_user_group = id_user_group
TeraUserUserGroup.insert(user_user_group)

return user

0 comments on commit 3ec3a63

Please sign in to comment.