Skip to content

Commit

Permalink
✨ Refresh & document OIDC support (#326)
Browse files Browse the repository at this point in the history
Parent issue: sequentech/meta#256
  • Loading branch information
edulix authored Nov 13, 2023
1 parent c105962 commit fc807ae
Show file tree
Hide file tree
Showing 17 changed files with 701 additions and 295 deletions.
18 changes: 18 additions & 0 deletions iam/api/migrations/0053_authapi_oidc_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Edulix on 2023-11-10 08:16

from django.db import migrations, models
from django.contrib.postgres.fields import JSONField

class Migration(migrations.Migration):
dependencies = [
('api', '0052_authapi_scheduled_events'),
]

operations = [
migrations.AddField(
model_name='authevent',
name='oidc_providers',
field=JSONField(blank=True, db_index=False, null=True),
preserve_default=False
)
]
145 changes: 131 additions & 14 deletions iam/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@

from contracts.base import check_contract
from contracts import CheckException
from marshmallow import Schema, fields as marshmallow_fields
from marshmallow import (
Schema,
fields as marshmallow_fields,
decorators
)
from marshmallow.exceptions import ValidationError as MarshMallowValidationError
from django.db.models import CharField
from django.db.models.functions import Length
Expand Down Expand Up @@ -335,13 +339,81 @@ class ScheduledEventsSchema(Schema):
)


class OIDCPPublicInfoSchema(Schema):
'''
Schema for an OIDC Provider Public Info Configuration
'''
id = marshmallow_fields.String(
required=True, allow_none=False
)
title = marshmallow_fields.String(
required=True, allow_none=False
)
description = marshmallow_fields.String(
required=True, allow_none=False
)
icon = marshmallow_fields.Url(
required=True, allow_none=False, schemes=["https"]
)
authorization_endpoint = marshmallow_fields.Url(
required=True, allow_none=False, schemes=["https"]
)
client_id = marshmallow_fields.String(
required=True, allow_none=False
)
scope = marshmallow_fields.String(
required=True, allow_none=False
)
issuer = marshmallow_fields.Url(
required=True, allow_none=False, schemes=["https"]
)
token_endpoint = marshmallow_fields.Url(
required=True, allow_none=False, schemes=["https"]
)
jwks_uri = marshmallow_fields.Url(
required=True, allow_none=False, schemes=["https"]
)
logout_uri = marshmallow_fields.Url(
required=True, allow_none=False, schemes=["https"]
)


class OIDCPPrivateInfoSchema(Schema):
'''
Schema for an OIDC Provider Private Info Configuration
'''
client_secret = marshmallow_fields.String(
required=True, allow_none=False
)

class OIDCProviderSchema(Schema):
'''
Schema for an OIDC Provider Configuration
'''
public_info = marshmallow_fields.Nested(
OIDCPPublicInfoSchema,
allow_none=False,
required=True
)
private_info = marshmallow_fields.Nested(
OIDCPPrivateInfoSchema,
allow_none=False,
required=True
)

@decorators.validates_schema(pass_original=True)
def validate_schema(self, data, original_data, **kwargs):
if not original_data:
raise ValidationError('The list must have at least one element.')


def get_schema_validator(klass):
'''
Given a Marshmallow Schema class, returns a Django field validator function
'''
def validator(data):
try:
klass().validate(data)
klass().validate(data=data)
except MarshMallowValidationError as error:
# Convert the marshmallow validation error to a Django one,
# because Django doesn't know how to handle marshmallow exceptions.
Expand Down Expand Up @@ -411,6 +483,34 @@ class AuthEvent(models.Model):
validators=[get_schema_validator(ScheduledEventsSchema)]
)

# the OIDC providers used in this election
# Example:
# [
#  {
#  "public_info": {
#  "id": "google",
#  "title": "Google",
#  "description": "Authenticate with Google",
#  "icon": "https://www.google.com/favicon.ico",
#  "authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth",
#  "client_id": "<CLIENT_ID>.apps.googleusercontent.com",
#  "issuer": "https://accounts.google.com",
#  "token_endpoint": "https://oauth2.googleapis.com/token",
#  "jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
#  "logout_uri": "https://accounts.google.com/o/oauth2/v2/auth_logout"
#  },
#  "private_info": {
#  "client_secret": "<CLIENT_SECRET>"
#  }
#  }
# ]
oidc_providers = JSONField(
blank=True,
db_index=False,
null=True,
validators=[get_schema_validator(OIDCProviderSchema(many=True))]
)

# used by iam_celery to know what tallies to launch, and to serialize
# those launches one by one. set/get with (s|g)et_tally_status api calls
tally_status = models.CharField(
Expand Down Expand Up @@ -501,6 +601,19 @@ def check_allow_user_resend(self):
)
)

def get_public_config(self):
from authmethods import METHODS
base_config = {
'allow_user_resend': self.check_allow_user_resend()
}
if not hasattr(METHODS[self.auth_method], "get_public_config"):
return base_config
public_config = {
**base_config,
**METHODS[self.auth_method].get_public_config(self),
}
return public_config

def serialize(self, restrict=False):
'''
Used to serialize data when the user has priviledges to see all the data
Expand All @@ -510,6 +623,11 @@ def serialize(self, restrict=False):
# auth codes sent by authmethod
from authmethods.models import Code

def none_list(e):
if e is None:
return []
return e

d = {
'id': self.id,
'auth_method': self.auth_method,
Expand All @@ -527,23 +645,16 @@ def serialize(self, restrict=False):
'parent_id': self.parent.id if self.parent is not None else None,
'children_election_info': self.children_election_info,
'auth_method_config': {
'config': {
'allow_user_resend': self.check_allow_user_resend()
}
'config': self.get_public_config()
},
'scheduled_events': self.scheduled_events,
'openid_connect_providers': [
provider['public_info']
for provider in settings.OPENID_CONNECT_PROVIDERS_CONF
'oidc_providers': [
dict(public_info=provider["public_info"])
for provider in none_list(self.oidc_providers)
],
'support_otl_enabled': self.support_otl_enabled,
'inside_authenticate_otl_period': self.inside_authenticate_otl_period
}

def none_list(e):
if e is None:
return []
return e

def restrict_extra_fields(fields):
return [
Expand All @@ -552,6 +663,7 @@ def restrict_extra_fields(fields):
]

def restrict_alt_auth_methods():
from authmethods import patch_auth_event
if self.alternative_auth_methods is None:
return self.alternative_auth_methods

Expand All @@ -564,7 +676,12 @@ def restrict_alt_auth_methods():
),
public_name=alt_auth_method["public_name"],
public_name_i18n=alt_auth_method["public_name_i18n"],
icon=alt_auth_method["icon"]
icon=alt_auth_method["icon"],
auth_method_config=dict(
config=patch_auth_event(
self.id, alt_auth_method
).get_public_config()
)
)
for alt_auth_method in self.alternative_auth_methods
]
Expand Down
2 changes: 1 addition & 1 deletion iam/api/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,7 +1180,7 @@ def test_get_authevent(self):
'hide_default_login_lookup_field': False,
'parent_id': None,
'children_election_info': None,
'openid_connect_providers': [],
'oidc_providers': [],
'scheduled_events': None,
'total_votes': 0,
'tally_status': 'notstarted',
Expand Down
58 changes: 41 additions & 17 deletions iam/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@
BallotBox,
TallySheet,
children_election_info_validator,
ScheduledEventsSchema
ScheduledEventsSchema,
OIDCProviderSchema
)

from .tasks import (
Expand Down Expand Up @@ -1568,7 +1569,7 @@ def post(request, pk=None):
})
config = req.get('auth_method_config', None)
if config:
msg += check_config(config, auth_method)
msg += check_config(config, auth_method, req)

extra_fields = req.get('extra_fields', None)
if extra_fields:
Expand All @@ -1579,17 +1580,23 @@ def post(request, pk=None):

alternative_auth_methods = req.get('alternative_auth_methods', None)
if alternative_auth_methods:
msg += check_alt_auth_methods(
alternative_auth_methods, extra_fields
)
msg += check_alt_auth_methods(req)
update_alt_methods_config(alternative_auth_methods)

scheduled_events = req.get('scheduled_events', None)
if scheduled_events:
try:
ScheduledEventsSchema().load(scheduled_events)
except MarshMallowValidationError as error:
msg += str(error.messages)
msg += "scheduled_events: " + str(error.messages)

oidc_providers = req.get('oidc_providers', None)
if oidc_providers:
try:
OIDCProviderSchema(many=True)\
.load(data=oidc_providers)
except MarshMallowValidationError as error:
msg += "oidc_providers: " + str(error.messages)

admin_fields = req.get('admin_fields', None)
if admin_fields:
Expand Down Expand Up @@ -1707,6 +1714,7 @@ def post(request, pk=None):
support_otl_enabled=support_otl_enabled,
alternative_auth_methods=alternative_auth_methods,
scheduled_events=scheduled_events,
oidc_providers=oidc_providers,
)
# If the election exists, we are doing an update. Else, we are
# doing an insert. We use this update method instead of just
Expand Down Expand Up @@ -1802,17 +1810,16 @@ def post(request, pk=None):

config = req.get('auth_method_config', None)
if config:
msg += check_config(config, auth_method)
msg += check_config(config, auth_method, req)

extra_fields = req.get('extra_fields', None)
if extra_fields:
msg += check_extra_fields(extra_fields)

alternative_auth_methods = req.get('alternative_auth_methods', None)
if alternative_auth_methods:
msg += check_alt_auth_methods(
alternative_auth_methods, extra_fields
)
msg += check_alt_auth_methods(req)
update_alt_methods_config(alternative_auth_methods)

if msg:
return json_response(status=400, message=msg)
Expand Down Expand Up @@ -2163,18 +2170,26 @@ def post(self, request, pk):

data = {'msg': 'Sent successful'}
# first, validate input
e = get_object_or_404(AuthEvent, pk=pk)
auth_event = get_object_or_404(AuthEvent, pk=pk)

try:
req = parse_json_request(request)
except:
return json_response(status=400, error_codename=ErrorCodes.BAD_REQUEST)
return json_response(
status=400, error_codename=ErrorCodes.BAD_REQUEST
)

userids = req.get("user-ids", None)
if userids is None:
permission_required(request.user, 'AuthEvent', ['edit', 'send-auth-all'], pk)
permission_required(
request.user, 'AuthEvent', ['edit', 'send-auth-all'], pk
)
extra_req = req.get('extra', {})
auth_method = req.get("auth-method", None)
auth_method = (
req.get("auth-method", None)
if req.get("auth-method", None)
else auth_event.auth_method
)
# force extra_req type to be a dict
if not isinstance(extra_req, dict):
return json_response(
Expand Down Expand Up @@ -2217,19 +2232,28 @@ def post(self, request, pk):
return json_response(data)

if config.get('msg', None) is not None:
if type(config.get('msg', '')) != str or len(config.get('msg', '')) > settings.MAX_AUTH_MSG_SIZE[e.auth_method]:
if (
type(config.get('msg', '')) != str or
len(config.get('msg', '')) > settings.MAX_AUTH_MSG_SIZE[auth_method]
):
return json_response(
status=400,
error_codename=ErrorCodes.BAD_REQUEST)

if config.get('html_message', None) is not None:
if type(config.get('html_message', '')) != str or len(config.get('html_message', '')) > settings.MAX_AUTH_MSG_SIZE[e.auth_method]:
if (
type(config.get('html_message', '')) != str or
len(config.get('html_message', '')) > settings.MAX_AUTH_MSG_SIZE[auth_method]
):
return json_response(
status=400,
error_codename=ErrorCodes.BAD_REQUEST)

if config.get('filter', None) is not None:
if type(config.get('filter', None)) != str or config.get('filter', None) not in ['voted', 'not_voted']:
if (
type(config.get('filter', None)) != str or
config.get('filter', None) not in ['voted', 'not_voted']
):
return json_response(
status=400,
error_codename=ErrorCodes.BAD_REQUEST)
Expand Down
Loading

0 comments on commit fc807ae

Please sign in to comment.