Skip to content

Commit

Permalink
added docstrings and cleaned the codes
Browse files Browse the repository at this point in the history
  • Loading branch information
deepmancer committed Aug 13, 2024
1 parent 899d689 commit 4bb666f
Show file tree
Hide file tree
Showing 22 changed files with 1,011 additions and 173 deletions.
141 changes: 127 additions & 14 deletions fastapi_auth_jwt/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,33 @@
import jwt
from pydantic import BaseModel

from .config import AuthConfig, StorageConfig, User
from .config.auth_token import AuthConfig
from .config.storage import StorageConfig
from .config.user_schema import User
from .repository.base import BaseRepository
from .repository.factory import RepositoryFactory
from .utils import JWTHandler
from .utils.jwt_token import JWTHandler
from .utils.time_helpers import cast_to_seconds


class JWTAuthBackend:
_instance = None
"""
A backend class for handling JWT-based authentication.
def __new__(cls, *args, **kwargs):
This class provides methods for creating, validating, and invalidating JWT tokens.
It supports configurable authentication settings, user schemas, and storage backends.
Attributes:
_config (AuthConfig): Configuration for authentication (e.g., secret key, algorithm).
_user_schema (Type[BaseModel]): Schema for validating user data.
_storage_config (StorageConfig): Configuration for the storage backend.
_cache (BaseRepository): The cache repository for storing and retrieving token data.
_jwt_handler (JWTHandler): Handler for encoding and decoding JWT tokens.
"""

_instance: Optional["JWTAuthBackend"] = None

def __new__(cls, *args, **kwargs) -> "JWTAuthBackend":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
Expand All @@ -36,6 +53,25 @@ def __init__(
self._initialized = True

async def authenticate(self, token: str) -> Optional[BaseModel]:
"""
Authenticate a user based on a provided JWT token.
Args:
token (str): The JWT token to authenticate.
Returns:
Optional[BaseModel]: The authenticated user model, or None if authentication fails.
Raises:
jwt.PyJWTError: If there is an issue with decoding the JWT token.
Exception: For any other unexpected errors.
Examples:
>>> backend = JWTAuthBackend()
>>> user = await backend.authenticate("some_jwt_token")
>>> if user:
>>> print(f"Authenticated user: {user}")
"""
try:
user = await self.get_current_user(token)
return user
Expand All @@ -47,32 +83,88 @@ async def authenticate(self, token: str) -> Optional[BaseModel]:
async def create_token(
self,
payload: Dict[str, Any],
expiration_seconds: Optional[Union[int, float, timedelta]] = None,
expiration: Optional[Union[int, float, timedelta]] = None,
) -> str:
if expiration_seconds is None:
"""
Create a JWT token with an optional expiration time.
Args:
payload (Dict[str, Any]): The payload data to encode into the JWT.
expiration (Optional[Union[int, float, datetime.timedelta]]): Expiration time in seconds or timedelta.
Returns:
str: The generated JWT token.
Raises:
Exception: If there is an issue storing the token in the cache.
Examples:
>>> backend = JWTAuthBackend()
>>> token = await backend.create_token({"user_id": 123}, expiration=3600)
>>> print(f"Generated token: {token}")
"""
if expiration is None:
expiration_candidates = ["expire", "expiration", "exp"]
for field in self.config.model_fields.keys():
if any(candidate in field for candidate in expiration_candidates):
expiration_seconds = getattr(self.config, field)
expiration = cast_to_seconds(getattr(self.config, field))
break
else:
expiration = (
int(expiration.total_seconds())
if isinstance(expiration, timedelta)
else int(expiration)
)

token = self.jwt_handler.encode(payload=payload, expiration=expiration_seconds)
token = self.jwt_handler.encode(payload=payload, expiration=expiration)

try:
await self.cache.set(
key=token,
value=payload,
expiration=expiration_seconds,
expiration=expiration,
)
except Exception as e:
raise Exception(f"Failed to store token in cache: {e}")

return token

async def invalidate_token(self, token: str) -> None:
"""
Invalidate a JWT token by removing it from the cache.
Args:
token (str): The JWT token to invalidate.
Returns:
None
Examples:
>>> backend = JWTAuthBackend()
>>> await backend.invalidate_token("some_jwt_token")
"""
await self.cache.delete(token)

async def get_current_user(self, token: str) -> Optional[BaseModel]:
"""
Retrieve the current user based on a JWT token.
Args:
token (str): The JWT token to decode and validate.
Returns:
Optional[BaseModel]: The validated user model, or None if validation fails.
Raises:
jwt.InvalidTokenError: If the token payload does not match the cached payload.
Exception: For any other unexpected errors.
Examples:
>>> backend = JWTAuthBackend()
>>> user = await backend.get_current_user("some_jwt_token")
>>> if user:
>>> print(f"Current user: {user}")
"""
token_payload = self.jwt_handler.decode(token)
try:
cached_payload = await self.cache.get(token)
Expand All @@ -97,46 +189,67 @@ async def get_current_user(self, token: str) -> Optional[BaseModel]:

@classmethod
def get_instance(cls) -> Optional["JWTAuthBackend"]:
"""
Get the singleton instance of the JWTAuthBackend class.
Returns:
Optional[JWTAuthBackend]: The singleton instance.
Examples:
>>> backend = JWTAuthBackend.get_instance()
>>> if backend:
>>> print("JWTAuthBackend instance exists.")
"""
return cls._instance

@property
def config(self) -> AuthConfig:
"""Get the current authentication configuration."""
return self._config

@config.setter
def config(self, value: AuthConfig):
def config(self, value: AuthConfig) -> None:
"""Set a new authentication configuration."""
self._config = value

@property
def user_schema(self) -> Type[BaseModel]:
"""Get the current user schema."""
return self._user_schema

@user_schema.setter
def user_schema(self, value: Type[BaseModel]):
def user_schema(self, value: Type[BaseModel]) -> None:
"""Set a new user schema."""
self._user_schema = value

@property
def storage_config(self) -> StorageConfig:
"""Get the current storage configuration."""
return self._storage_config

@storage_config.setter
def storage_config(self, value: StorageConfig):
def storage_config(self, value: StorageConfig) -> None:
"""Set a new storage configuration."""
self._storage_config = value

@property
def cache(self) -> BaseRepository:
"""Get the current cache repository."""
return self._cache

@cache.setter
def cache(self, value: BaseRepository):
def cache(self, value: BaseRepository) -> None:
"""Set a new cache repository."""
self._cache = value

@property
def jwt_handler(self) -> JWTHandler:
"""Get the current JWT handler."""
return self._jwt_handler

@jwt_handler.setter
def jwt_handler(self, value: JWTHandler):
def jwt_handler(self, value: JWTHandler) -> None:
"""Set a new JWT handler."""
self._jwt_handler = value


Expand Down
4 changes: 0 additions & 4 deletions fastapi_auth_jwt/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +0,0 @@
from .auth_token import AuthConfig
from .storage import StorageConfig
from .storage_type import StorageTypes
from .user_schema import User
92 changes: 88 additions & 4 deletions fastapi_auth_jwt/config/auth_token.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,55 @@
import json
from typing import Optional

from pydantic import BaseModel, Field, computed_field, field_validator


class AuthConfig(BaseModel):
"""
Configuration class for authentication settings.
This class is used to configure authentication-related settings, such as the secret key,
algorithm used for token encoding, and token expiration time in seconds. It also provides
validation for these settings and computed fields for convenience.
Attributes:
secret (str): The secret key used for signing tokens. Defaults to "default_secret".
algorithm (str): The algorithm used for encoding tokens. Defaults to "HS256".
expiration_seconds (int): The token expiration time in seconds. Defaults to 3600.
"""

secret: str = Field(default="default_secret")
algorithm: str = Field(default="HS256")
expiration_seconds: int = Field(default=3600)

@field_validator("expiration_seconds", mode="before")
def validate_expiration_seconds(cls, v):
def validate_expiration_seconds(cls, v) -> int:
"""
Validate and convert the `expiration_seconds` field to an integer.
This method ensures that the `expiration_seconds` field is a positive integer.
If a float or string is provided, it is converted to an integer. If the value
is less than 0, a `ValueError` is raised.
Args:
v (Union[int, float, str]): The value to validate and convert.
Returns:
int: The validated and converted value.
Raises:
ValueError: If `v` is less than 0.
Examples:
>>> AuthConfig.validate_expiration_seconds(3600)
3600
>>> AuthConfig.validate_expiration_seconds("7200")
7200
>>> AuthConfig.validate_expiration_seconds(-100)
ValueError: expiration_seconds must be greater than 0
"""
if isinstance(v, (int, float, str)):
v = int(v)

Expand All @@ -19,14 +59,58 @@ def validate_expiration_seconds(cls, v):
return v

@computed_field(return_type=int)
def expiration_minutes(self) -> int:
def expiration_minutes(self) -> Optional[int]:
"""
Compute the token expiration time in minutes.
This method calculates the token expiration time in minutes based on the
`expiration_seconds` attribute.
Returns:
Optional[int]: The token expiration time in minutes.
Examples:
>>> config = AuthConfig(expiration_seconds=3600)
>>> config.expiration_minutes
60
>>> config = AuthConfig(expiration_seconds=4500)
>>> config.expiration_minutes
75
"""
return self.expiration_seconds // 60 if self.expiration_seconds else None

def __repr__(self):
def __repr__(self) -> str:
"""
Return a string representation of the authentication configuration object.
The representation includes all configuration attributes formatted as a JSON string.
Returns:
str: A JSON-formatted string representation of the authentication configuration.
Examples:
>>> config = AuthConfig(secret="my_secret", algorithm="HS512", expiration_seconds=7200)
>>> repr(config)
'<AuthConfig: { "secret": "my_secret", "algorithm": "HS512", "expiration_seconds": 7200, "expiration_minutes": 120 }>'
"""
dict_repr = json.dumps(self.model_dump(), indent=2)
return f"<AuthConfig: {dict_repr}>"

def __str__(self):
def __str__(self) -> str:
"""
Return a string representation of the authentication configuration object.
This method calls `__repr__` to provide a consistent string representation.
Returns:
str: A JSON-formatted string representation of the authentication configuration.
Examples:
>>> config = AuthConfig(secret="my_secret", algorithm="HS512", expiration_seconds=7200)
>>> str(config)
'<AuthConfig: { "secret": "my_secret", "algorithm": "HS512", "expiration_seconds": 7200, "expiration_minutes": 120 }>'
"""
return self.__repr__()


Expand Down
Loading

0 comments on commit 4bb666f

Please sign in to comment.