Skip to content

Commit

Permalink
Type hints for class based views (and other tools) (#2060)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardebeling authored Nov 20, 2023
1 parent 42c6df4 commit 419971c
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 51 deletions.
77 changes: 43 additions & 34 deletions evap/evaluation/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Iterable, Mapping
from typing import Any, TypeVar
from typing import Any, Protocol, TypeVar
from urllib.parse import quote

import xlwt
from django.conf import settings
from django.core.exceptions import SuspiciousOperation, ValidationError
from django.db.models import Model
from django.http import HttpResponse
from django.forms.formsets import BaseFormSet
from django.http import HttpRequest, HttpResponse
from django.shortcuts import get_object_or_404
from django.utils.translation import get_language
from django.views.generic import FormView
Expand Down Expand Up @@ -42,7 +43,7 @@ def get_object_from_dict_pk_entry_or_logged_40x(model_cls: type[M], dict_obj: Ma
raise SuspiciousOperation from e


def is_prefetched(instance, attribute_name: str):
def is_prefetched(instance, attribute_name: str) -> bool:
"""
Is the given related attribute prefetched? Can be used to do ordering or counting in python and avoid additional
database queries
Expand All @@ -58,52 +59,52 @@ def is_prefetched(instance, attribute_name: str):
return False


def discard_cached_related_objects(instance):
def discard_cached_related_objects(instance: M) -> M:
"""
Discard all cached related objects (for ForeignKey and M2M Fields). Useful
if there were changes, but django's caching would still give us the old
values. Also useful for pickling objects without pickling the whole model
hierarchy (e.g. for storing instances in a cache)
"""
# Extracted from django's refresh_from_db, which sadly doesn't offer this part alone (without hitting the DB).
for field in instance._meta.concrete_fields:
for field in instance._meta.concrete_fields: # type: ignore
if field.is_relation and field.is_cached(instance):
field.delete_cached_value(instance)

for field in instance._meta.related_objects:
for field in instance._meta.related_objects: # type: ignore
if field.is_cached(instance):
field.delete_cached_value(instance)

instance._prefetched_objects_cache = {}
instance._prefetched_objects_cache = {} # type: ignore

return instance


def is_external_email(email):
def is_external_email(email: str) -> bool:
return not any(email.endswith("@" + domain) for domain in settings.INSTITUTION_EMAIL_DOMAINS)


def sort_formset(request, formset):
def sort_formset(request: HttpRequest, formset: BaseFormSet) -> None:
if request.POST: # if not, there will be no cleaned_data and the models should already be sorted anyways
formset.is_valid() # make sure all forms have cleaned_data
formset.forms.sort(key=lambda f: f.cleaned_data.get("order", 9001))


def date_to_datetime(date):
def date_to_datetime(date: datetime.date) -> datetime.datetime:
return datetime.datetime(year=date.year, month=date.month, day=date.day)


def vote_end_datetime(vote_end_date):
def vote_end_datetime(vote_end_date: datetime.date) -> datetime.datetime:
# The evaluation actually ends at EVALUATION_END_OFFSET_HOURS:00 of the day AFTER self.vote_end_date.
return date_to_datetime(vote_end_date) + datetime.timedelta(hours=24 + settings.EVALUATION_END_OFFSET_HOURS)


def get_parameter_from_url_or_session(request, parameter, default=False):
result = request.GET.get(parameter, None)
if result is None: # if no parameter is given take session value
def get_parameter_from_url_or_session(request: HttpRequest, parameter: str, default=False) -> bool:
result_str = request.GET.get(parameter, None)
if result_str is None: # if no parameter is given take session value
result = request.session.get(parameter, default)
else:
result = {"true": True, "false": False}.get(result.lower()) # convert parameter to boolean
result = {"true": True, "false": False}.get(result_str.lower()) # convert parameter to boolean
request.session[parameter] = result # store value for session
return result

Expand All @@ -115,7 +116,10 @@ def translate(**kwargs):
return property(lambda self: getattr(self, kwargs[get_language() or "en"]))


def clean_email(email):
EmailT = TypeVar("EmailT", str, None)


def clean_email(email: EmailT) -> EmailT:
if email:
email = email.strip().lower()
# Replace email domains in case there are multiple alias domains used in the organisation and all emails should
Expand All @@ -126,11 +130,11 @@ def clean_email(email):
return email


def capitalize_first(string):
def capitalize_first(string: str) -> str:
return string[0].upper() + string[1:]


def ilen(iterable):
def ilen(iterable: Iterable) -> int:
return sum(1 for _ in iterable)


Expand All @@ -148,7 +152,7 @@ class FormsetView(FormView):
def form_class(self):
return self.formset_class

def get_context_data(self, **kwargs):
def get_context_data(self, **kwargs) -> dict[str, Any]:
context = super().get_context_data(**kwargs)
context["formset"] = context.pop("form")
return context
Expand All @@ -157,19 +161,24 @@ def get_context_data(self, **kwargs):
# `get_formset_kwargs`. Users can thus override `get_formset_kwargs` instead. If it is not overridden, we delegate
# to the original `get_form_kwargs` instead. The same approach is used for the other renamed methods.

def get_form_kwargs(self):
def get_form_kwargs(self) -> dict:
return self.get_formset_kwargs()

def get_formset_kwargs(self):
def get_formset_kwargs(self) -> dict:
return super().get_form_kwargs()

def form_valid(self, form):
def form_valid(self, form) -> HttpResponse:
return self.formset_valid(form)

def formset_valid(self, formset):
def formset_valid(self, formset) -> HttpResponse:
return super().form_valid(formset)


class HasFormValid(Protocol):
def form_valid(self, form):
pass


class SaveValidFormMixin:
"""
Call `form.save()` if the submitted form is valid.
Expand All @@ -178,7 +187,7 @@ class SaveValidFormMixin:
example if a formset for a collection of objects is submitted.
"""

def form_valid(self, form):
def form_valid(self: HasFormValid, form) -> HttpResponse:
form.save()
return super().form_valid(form)

Expand All @@ -193,11 +202,11 @@ class AttachmentResponse(HttpResponse):
_to the response instance_ as if it was a writable file.
"""

def __init__(self, filename, content_type=None, **kwargs):
def __init__(self, filename: str, content_type=None, **kwargs) -> None:
super().__init__(content_type=content_type, **kwargs)
self.set_content_disposition(filename)

def set_content_disposition(self, filename):
def set_content_disposition(self, filename: str) -> None:
try:
filename.encode("ascii")
self["Content-Disposition"] = f'attachment; filename="{filename}"'
Expand All @@ -215,7 +224,7 @@ class HttpResponseNoContent(HttpResponse):

status_code = 204

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
del self["content-type"]

Expand Down Expand Up @@ -244,7 +253,7 @@ class ExcelExporter(ABC):
# have a sheet added at initialization.
default_sheet_name: str | None = None

def __init__(self):
def __init__(self) -> None:
self.workbook = xlwt.Workbook()
self.cur_row = 0
self.cur_col = 0
Expand All @@ -253,7 +262,7 @@ def __init__(self):
else:
self.cur_sheet = None

def write_cell(self, label="", style="default"):
def write_cell(self, label: str | None = "", style: str = "default") -> None:
"""Write a single cell and move to the next column."""
self.cur_sheet.write(
self.cur_row,
Expand All @@ -263,11 +272,11 @@ def write_cell(self, label="", style="default"):
)
self.cur_col += 1

def next_row(self):
def next_row(self) -> None:
self.cur_col = 0
self.cur_row += 1

def write_row(self, vals, style="default"):
def write_row(self, vals: Iterable[str], style: str = "default") -> None:
"""
Write a cell for every value and go to the next row.
Styling can be chosen
Expand All @@ -278,16 +287,16 @@ def write_row(self, vals, style="default"):
self.write_cell(val, style=style(val) if callable(style) else style)
self.next_row()

def write_empty_row_with_styles(self, styles):
def write_empty_row_with_styles(self, styles: Iterable[str]) -> None:
for style in styles:
self.write_cell(None, style)
self.next_row()

@abstractmethod
def export_impl(self, *args, **kwargs):
def export_impl(self, *args, **kwargs) -> None:
"""Specify the logic to insert the data into the sheet here."""

def export(self, response, *args, **kwargs):
def export(self, response, *args, **kwargs) -> None:
"""Convenience method to avoid some boilerplate."""
self.export_impl(*args, **kwargs)
self.workbook.save(response)
16 changes: 9 additions & 7 deletions evap/grades/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from django.conf import settings
from django.contrib import messages
from django.core.exceptions import PermissionDenied, SuspiciousOperation
Expand All @@ -23,7 +25,7 @@
class IndexView(TemplateView):
template_name = "grades_index.html"

def get_context_data(self, **kwargs):
def get_context_data(self, **kwargs) -> dict[str, Any]:
return super().get_context_data(**kwargs) | {
"semesters": Semester.objects.filter(grade_documents_are_deleted=False),
"disable_breadcrumb_grades": True,
Expand Down Expand Up @@ -51,19 +53,19 @@ class SemesterView(DetailView):

object: Semester

def get_object(self, *args, **kwargs):
def get_object(self, *args, **kwargs) -> Semester:
semester = super().get_object(*args, **kwargs)
if semester.grade_documents_are_deleted:
raise PermissionDenied
return semester

def get_context_data(self, **kwargs):
courses = (
def get_context_data(self, **kwargs) -> dict[str, Any]:
query = (
self.object.courses.filter(evaluations__wait_for_grade_upload_before_publishing=True)
.exclude(evaluations__state=Evaluation.State.NEW)
.distinct()
)
courses = course_grade_document_count_tuples(courses)
courses = course_grade_document_count_tuples(query)

return super().get_context_data(**kwargs) | {
"courses": courses,
Expand All @@ -77,13 +79,13 @@ class CourseView(DetailView):
model = Course
pk_url_kwarg = "course_id"

def get_object(self, *args, **kwargs):
def get_object(self, *args, **kwargs) -> Course:
course = super().get_object(*args, **kwargs)
if course.semester.grade_documents_are_deleted:
raise PermissionDenied
return course

def get_context_data(self, **kwargs):
def get_context_data(self, **kwargs) -> dict[str, Any]:
return super().get_context_data(**kwargs) | {
"semester": self.object.semester,
"grade_documents": self.object.grade_documents.all(),
Expand Down
24 changes: 14 additions & 10 deletions evap/staff/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from django.db import IntegrityError, transaction
from django.db.models import BooleanField, Case, Count, ExpressionWrapper, IntegerField, Prefetch, Q, Sum, When
from django.dispatch import receiver
from django.forms import formset_factory
from django.forms import BaseForm, formset_factory
from django.forms.models import inlineformset_factory, modelformset_factory
from django.http import Http404, HttpRequest, HttpResponse, HttpResponseBadRequest, HttpResponseRedirect
from django.shortcuts import get_object_or_404, redirect, render
Expand Down Expand Up @@ -580,7 +580,8 @@ class SemesterCreateView(SuccessMessageMixin, CreateView):
form_class = SemesterForm
success_message = gettext_lazy("Successfully created semester.")

def get_success_url(self):
def get_success_url(self) -> str:
assert self.object is not None
return reverse("staff:semester_view", args=[self.object.id])


Expand All @@ -592,7 +593,7 @@ class SemesterEditView(SuccessMessageMixin, UpdateView):
pk_url_kwarg = "semester_id"
success_message = gettext_lazy("Successfully updated semester.")

def get_success_url(self):
def get_success_url(self) -> str:
return reverse("staff:semester_view", args=[self.object.id])


Expand Down Expand Up @@ -1050,13 +1051,13 @@ class CourseEditView(SuccessMessageMixin, UpdateView):

object: Course

def get_object(self, *args, **kwargs):
def get_object(self, *args, **kwargs) -> Course:
course = super().get_object(*args, **kwargs)
if self.request.method == "POST" and not course.can_be_edited_by_manager:
raise SuspiciousOperation("Modifying this course is not allowed.")
return course

def get_context_data(self, **kwargs):
def get_context_data(self, **kwargs) -> dict[str, Any]:
context_data = super().get_context_data(**kwargs) | {
"semester": self.object.semester,
"editable": self.object.can_be_edited_by_manager,
Expand All @@ -1065,7 +1066,9 @@ def get_context_data(self, **kwargs):
context_data["course_form"] = context_data.pop("form")
return context_data

def form_valid(self, form):
def form_valid(self, form: BaseForm) -> HttpResponse:
assert isinstance(form, CourseForm) # https://www.github.com/typeddjango/django-stubs/issues/1809

if self.request.POST.get("operation") not in ("save", "save_create_evaluation", "save_create_single_result"):
raise SuspiciousOperation("Invalid POST operation")

Expand All @@ -1074,14 +1077,15 @@ def form_valid(self, form):
update_template_cache_of_published_evaluations_in_course(self.object)
return response

def get_success_url(self):
def get_success_url(self) -> str:
match self.request.POST["operation"]:
case "save":
return reverse("staff:semester_view", args=[self.object.semester.id])
case "save_create_evaluation":
return reverse("staff:evaluation_create_for_course", args=[self.object.id])
case "save_create_single_result":
return reverse("staff:single_result_create_for_course", args=[self.object.id])
raise SuspiciousOperation("Unexpected operation")


@require_POST
Expand Down Expand Up @@ -2290,7 +2294,7 @@ class UserMergeSelectionView(FormView):
form_class = UserMergeSelectionForm
template_name = "staff_user_merge_selection.html"

def form_valid(self, form):
def form_valid(self, form: UserMergeSelectionForm) -> HttpResponse:
return redirect(
"staff:user_merge",
form.cleaned_data["main_user"].id,
Expand Down Expand Up @@ -2334,7 +2338,7 @@ class TemplateEditView(SuccessMessageMixin, UpdateView):
success_url = reverse_lazy("staff:index")
template_name = "staff_template_form.html"

def get_context_data(self, **kwargs) -> dict:
def get_context_data(self, **kwargs) -> dict[str, Any]:
context = super().get_context_data(**kwargs)
template = context["template"] = context.pop("emailtemplate")

Expand Down Expand Up @@ -2377,7 +2381,7 @@ class FaqIndexView(SuccessMessageMixin, SaveValidFormMixin, FormsetView):
success_url = reverse_lazy("staff:faq_index")
success_message = gettext_lazy("Successfully updated the FAQ sections.")

def get_context_data(self, **kwargs):
def get_context_data(self, **kwargs) -> dict[str, Any]:
return super().get_context_data(**kwargs) | {"sections": FaqSection.objects.all()}


Expand Down

0 comments on commit 419971c

Please sign in to comment.