Skip to content

Commit

Permalink
Added core functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
vladvildanov committed Dec 11, 2024
1 parent 6dae71b commit 32fc374
Show file tree
Hide file tree
Showing 8 changed files with 1,139 additions and 0 deletions.
Empty file added redis/auth/__init__.py
Empty file.
27 changes: 27 additions & 0 deletions redis/auth/err.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Iterable


class RequestTokenErr(Exception):
"""
Represents an exception during token request.
"""
def __init__(self, *args):
super().__init__(*args)


class InvalidTokenSchemaErr(Exception):
"""
Represents an exception related to invalid token schema.
"""
def __init__(self, missing_fields: Iterable[str] = []):
super().__init__(
"Unexpected token schema. Following fields are missing: " + ", ".join(missing_fields)
)


class TokenRenewalErr(Exception):
"""
Represents an exception during token renewal process.
"""
def __init__(self, *args):
super().__init__(*args)
25 changes: 25 additions & 0 deletions redis/auth/idp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
from redis.auth.token import TokenInterface

"""
This interface is the facade of an identity provider
"""


class IdentityProviderInterface(ABC):
"""
Receive a token from the identity provider. Receiving a token only works when being authenticated.
"""

@abstractmethod
def request_token(self, force_refresh=False) -> TokenInterface:
pass


class IdentityProviderConfigInterface(ABC):
"""
Configuration class that provides a configured identity provider.
"""
@abstractmethod
def get_provider(self) -> IdentityProviderInterface:
pass
121 changes: 121 additions & 0 deletions redis/auth/token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from abc import ABC, abstractmethod

import jwt
from datetime import datetime, timezone

from redis.auth.err import InvalidTokenSchemaErr


class TokenInterface(ABC):
@abstractmethod
def is_expired(self) -> bool:
pass

@abstractmethod
def ttl(self) -> float:
pass

@abstractmethod
def try_get(self, key: str) -> str:
pass

@abstractmethod
def get_value(self) -> str:
pass

@abstractmethod
def get_expires_at_ms(self) -> float:
pass

@abstractmethod
def get_received_at_ms(self) -> float:
pass


class TokenResponse:
def __init__(self, token: TokenInterface):
self._token = token

def get_token(self) -> TokenInterface:
return self._token

def get_ttl_ms(self) -> float:
return self._token.get_expires_at_ms() - self._token.get_received_at_ms()


class SimpleToken(TokenInterface):
def __init__(self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict) -> None:
self.value = value
self.expires_at = expires_at_ms
self.received_at = received_at_ms
self.claims = claims

def ttl(self) -> float:
if self.expires_at == -1:
return -1

return self.expires_at - (datetime.now(timezone.utc).timestamp() * 1000)

def is_expired(self) -> bool:
if self.expires_at == -1:
return False

return self.ttl() <= 0

def try_get(self, key: str) -> str:
return self.claims.get(key)

def get_value(self) -> str:
return self.value

def get_expires_at_ms(self) -> float:
return self.expires_at

def get_received_at_ms(self) -> float:
return self.received_at


class JWToken(TokenInterface):

REQUIRED_FIELDS = {'exp'}

def __init__(self, token: str):
self._value = token
self._decoded = jwt.decode(
self._value,
options={"verify_signature": False},
algorithms=[jwt.get_unverified_header(self._value).get('alg')]
)
self._validate_token()

def is_expired(self) -> bool:
exp = self._decoded['exp']
if exp == -1:
return False

return self._decoded['exp'] * 1000 <= datetime.now(timezone.utc).timestamp() * 1000

def ttl(self) -> float:
exp = self._decoded['exp']
if exp == -1:
return -1

return self._decoded['exp'] * 1000 - datetime.now(timezone.utc).timestamp() * 1000

def try_get(self, key: str) -> str:
return self._decoded.get(key)

def get_value(self) -> str:
return self._value

def get_expires_at_ms(self) -> float:
return float(self._decoded['exp'] * 1000)

def get_received_at_ms(self) -> float:
return datetime.now(timezone.utc).timestamp() * 1000

def _validate_token(self):
actual_fields = {x for x in self._decoded.keys()}

if len(self.REQUIRED_FIELDS - actual_fields) != 0:
raise InvalidTokenSchemaErr(self.REQUIRED_FIELDS - actual_fields)
Loading

0 comments on commit 32fc374

Please sign in to comment.