Skip to content

Commit

Permalink
Replace List type hint by list (Python 3.9+)
Browse files Browse the repository at this point in the history
  • Loading branch information
angely-dev committed Aug 22, 2024
1 parent a2bf975 commit 9171928
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 23 deletions.
7 changes: 3 additions & 4 deletions src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pydantic import BaseModel
from pyfreeradius import User, Group, Nas
from pyfreeradius import UserRepository, GroupRepository, NasRepository
from typing import List

#
# We want our REST API endpoints to be KISS!
Expand Down Expand Up @@ -39,7 +38,7 @@ def read_root():
return {"Welcome!": f"API docs is available at {API_URL}/docs"}


@router.get("/nas", tags=["nas"], status_code=200, response_model=List[str])
@router.get("/nas", tags=["nas"], status_code=200, response_model=list[str])
def get_nases(response: Response, from_nasname: str | None = None):
nasnames = nas_repo.find_nasnames(from_nasname)
if nasnames:
Expand All @@ -48,7 +47,7 @@ def get_nases(response: Response, from_nasname: str | None = None):
return nasnames


@router.get("/users", tags=["users"], status_code=200, response_model=List[str])
@router.get("/users", tags=["users"], status_code=200, response_model=list[str])
def get_users(response: Response, from_username: str | None = None):
usernames = user_repo.find_usernames(from_username)
if usernames:
Expand All @@ -57,7 +56,7 @@ def get_users(response: Response, from_username: str | None = None):
return usernames


@router.get("/groups", tags=["groups"], status_code=200, response_model=List[str])
@router.get("/groups", tags=["groups"], status_code=200, response_model=list[str])
def get_groups(response: Response, from_groupname: str | None = None):
groupnames = group_repo.find_groupnames(from_groupname)
if groupnames:
Expand Down
37 changes: 18 additions & 19 deletions src/pyfreeradius.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from pydantic import BaseModel, StringConstraints, Field, model_validator
from typing import List
from typing_extensions import Annotated

#
Expand Down Expand Up @@ -31,9 +30,9 @@ class GroupUser(BaseModel):

class User(BaseModel):
username: Annotated[str, StringConstraints(min_length=1)]
checks: List[AttributeOpValue] = []
replies: List[AttributeOpValue] = []
groups: List[UserGroup] = []
checks: list[AttributeOpValue] = []
replies: list[AttributeOpValue] = []
groups: list[UserGroup] = []

@model_validator(mode="after")
def check_fields_on_init(self):
Expand Down Expand Up @@ -71,9 +70,9 @@ def check_fields_on_init(self):

class Group(BaseModel):
groupname: Annotated[str, StringConstraints(min_length=1)]
checks: List[AttributeOpValue] = []
replies: List[AttributeOpValue] = []
users: List[GroupUser] = []
checks: list[AttributeOpValue] = []
replies: list[AttributeOpValue] = []
users: list[GroupUser] = []

@model_validator(mode="after")
def check_fields_on_init(self):
Expand Down Expand Up @@ -164,7 +163,7 @@ def exists(self, username: str) -> bool:
counts = [count for count, in db_cursor.fetchall()]
return sum(counts) > 0

def find_all_usernames(self) -> List[str]:
def find_all_usernames(self) -> list[str]:
with self._db_cursor() as db_cursor:
sql = f"""SELECT DISTINCT username FROM {self.radcheck}
UNION SELECT DISTINCT username FROM {self.radreply}
Expand All @@ -173,12 +172,12 @@ def find_all_usernames(self) -> List[str]:
usernames = [username for username, in db_cursor.fetchall()]
return usernames

def find_usernames(self, from_username: str | None = None) -> List[str]:
def find_usernames(self, from_username: str | None = None) -> list[str]:
if not from_username:
return self._find_first_usernames()
return self._find_next_usernames(from_username)

def _find_first_usernames(self) -> List[str]:
def _find_first_usernames(self) -> list[str]:
with self._db_cursor() as db_cursor:
sql = f"""
SELECT username FROM (
Expand All @@ -191,7 +190,7 @@ def _find_first_usernames(self) -> List[str]:
usernames = [username for username, in db_cursor.fetchall()]
return usernames

def _find_next_usernames(self, from_username: str) -> List[str]:
def _find_next_usernames(self, from_username: str) -> list[str]:
with self._db_cursor() as db_cursor:
sql = f"""
SELECT username FROM (
Expand Down Expand Up @@ -257,7 +256,7 @@ def exists(self, groupname: str) -> bool:
counts = [count for count, in db_cursor.fetchall()]
return sum(counts) > 0

def find_all_groupnames(self) -> List[str]:
def find_all_groupnames(self) -> list[str]:
with self._db_cursor() as db_cursor:
sql = f"""SELECT DISTINCT groupname FROM {self.radgroupcheck}
UNION SELECT DISTINCT groupname FROM {self.radgroupreply}
Expand All @@ -266,12 +265,12 @@ def find_all_groupnames(self) -> List[str]:
groupnames = [groupname for groupname, in db_cursor.fetchall()]
return groupnames

def find_groupnames(self, from_groupname: str | None = None) -> List[str]:
def find_groupnames(self, from_groupname: str | None = None) -> list[str]:
if not from_groupname:
return self._find_first_groupnames()
return self._find_next_groupnames(from_groupname)

def _find_first_groupnames(self) -> List[str]:
def _find_first_groupnames(self) -> list[str]:
with self._db_cursor() as db_cursor:
sql = f"""
SELECT groupname FROM (
Expand All @@ -284,7 +283,7 @@ def _find_first_groupnames(self) -> List[str]:
groupnames = [groupname for groupname, in db_cursor.fetchall()]
return groupnames

def _find_next_groupnames(self, from_groupname: str) -> List[str]:
def _find_next_groupnames(self, from_groupname: str) -> list[str]:
with self._db_cursor() as db_cursor:
sql = f"""
SELECT groupname FROM (
Expand Down Expand Up @@ -355,27 +354,27 @@ def exists(self, nasname: str) -> bool:
(count,) = db_cursor.fetchone()
return count > 0

def find_all_nasnames(self) -> List[str]:
def find_all_nasnames(self) -> list[str]:
with self._db_cursor() as db_cursor:
sql = f"SELECT DISTINCT nasname FROM {self.nas}"
db_cursor.execute(sql)
nasnames = [nasname for nasname, in db_cursor.fetchall()]
return nasnames

def find_nasnames(self, from_nasname: str | None = None) -> List[str]:
def find_nasnames(self, from_nasname: str | None = None) -> list[str]:
if not from_nasname:
return self._find_first_nasnames()
return self._find_next_nasnames(from_nasname)

def _find_first_nasnames(self) -> List[str]:
def _find_first_nasnames(self) -> list[str]:
with self._db_cursor() as db_cursor:
sql = f"""SELECT DISTINCT nasname FROM {self.nas}
ORDER BY nasname LIMIT {self._PER_PAGE}"""
db_cursor.execute(sql)
nasnames = [nasname for nasname, in db_cursor.fetchall()]
return nasnames

def _find_next_nasnames(self, from_nasname: str) -> List[str]:
def _find_next_nasnames(self, from_nasname: str) -> list[str]:
with self._db_cursor() as db_cursor:
sql = f"""SELECT DISTINCT nasname FROM {self.nas}
WHERE nasname > %s ORDER BY nasname LIMIT {self._PER_PAGE}"""
Expand Down

0 comments on commit 9171928

Please sign in to comment.