-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6dae71b
commit 32fc374
Showing
8 changed files
with
1,139 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from typing import Iterable | ||
|
||
|
||
class RequestTokenErr(Exception): | ||
""" | ||
Represents an exception during token request. | ||
""" | ||
def __init__(self, *args): | ||
super().__init__(*args) | ||
|
||
|
||
class InvalidTokenSchemaErr(Exception): | ||
""" | ||
Represents an exception related to invalid token schema. | ||
""" | ||
def __init__(self, missing_fields: Iterable[str] = []): | ||
super().__init__( | ||
"Unexpected token schema. Following fields are missing: " + ", ".join(missing_fields) | ||
) | ||
|
||
|
||
class TokenRenewalErr(Exception): | ||
""" | ||
Represents an exception during token renewal process. | ||
""" | ||
def __init__(self, *args): | ||
super().__init__(*args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from abc import ABC, abstractmethod | ||
from redis.auth.token import TokenInterface | ||
|
||
""" | ||
This interface is the facade of an identity provider | ||
""" | ||
|
||
|
||
class IdentityProviderInterface(ABC): | ||
""" | ||
Receive a token from the identity provider. Receiving a token only works when being authenticated. | ||
""" | ||
|
||
@abstractmethod | ||
def request_token(self, force_refresh=False) -> TokenInterface: | ||
pass | ||
|
||
|
||
class IdentityProviderConfigInterface(ABC): | ||
""" | ||
Configuration class that provides a configured identity provider. | ||
""" | ||
@abstractmethod | ||
def get_provider(self) -> IdentityProviderInterface: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
import jwt | ||
from datetime import datetime, timezone | ||
|
||
from redis.auth.err import InvalidTokenSchemaErr | ||
|
||
|
||
class TokenInterface(ABC): | ||
@abstractmethod | ||
def is_expired(self) -> bool: | ||
pass | ||
|
||
@abstractmethod | ||
def ttl(self) -> float: | ||
pass | ||
|
||
@abstractmethod | ||
def try_get(self, key: str) -> str: | ||
pass | ||
|
||
@abstractmethod | ||
def get_value(self) -> str: | ||
pass | ||
|
||
@abstractmethod | ||
def get_expires_at_ms(self) -> float: | ||
pass | ||
|
||
@abstractmethod | ||
def get_received_at_ms(self) -> float: | ||
pass | ||
|
||
|
||
class TokenResponse: | ||
def __init__(self, token: TokenInterface): | ||
self._token = token | ||
|
||
def get_token(self) -> TokenInterface: | ||
return self._token | ||
|
||
def get_ttl_ms(self) -> float: | ||
return self._token.get_expires_at_ms() - self._token.get_received_at_ms() | ||
|
||
|
||
class SimpleToken(TokenInterface): | ||
def __init__(self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict) -> None: | ||
self.value = value | ||
self.expires_at = expires_at_ms | ||
self.received_at = received_at_ms | ||
self.claims = claims | ||
|
||
def ttl(self) -> float: | ||
if self.expires_at == -1: | ||
return -1 | ||
|
||
return self.expires_at - (datetime.now(timezone.utc).timestamp() * 1000) | ||
|
||
def is_expired(self) -> bool: | ||
if self.expires_at == -1: | ||
return False | ||
|
||
return self.ttl() <= 0 | ||
|
||
def try_get(self, key: str) -> str: | ||
return self.claims.get(key) | ||
|
||
def get_value(self) -> str: | ||
return self.value | ||
|
||
def get_expires_at_ms(self) -> float: | ||
return self.expires_at | ||
|
||
def get_received_at_ms(self) -> float: | ||
return self.received_at | ||
|
||
|
||
class JWToken(TokenInterface): | ||
|
||
REQUIRED_FIELDS = {'exp'} | ||
|
||
def __init__(self, token: str): | ||
self._value = token | ||
self._decoded = jwt.decode( | ||
self._value, | ||
options={"verify_signature": False}, | ||
algorithms=[jwt.get_unverified_header(self._value).get('alg')] | ||
) | ||
self._validate_token() | ||
|
||
def is_expired(self) -> bool: | ||
exp = self._decoded['exp'] | ||
if exp == -1: | ||
return False | ||
|
||
return self._decoded['exp'] * 1000 <= datetime.now(timezone.utc).timestamp() * 1000 | ||
|
||
def ttl(self) -> float: | ||
exp = self._decoded['exp'] | ||
if exp == -1: | ||
return -1 | ||
|
||
return self._decoded['exp'] * 1000 - datetime.now(timezone.utc).timestamp() * 1000 | ||
|
||
def try_get(self, key: str) -> str: | ||
return self._decoded.get(key) | ||
|
||
def get_value(self) -> str: | ||
return self._value | ||
|
||
def get_expires_at_ms(self) -> float: | ||
return float(self._decoded['exp'] * 1000) | ||
|
||
def get_received_at_ms(self) -> float: | ||
return datetime.now(timezone.utc).timestamp() * 1000 | ||
|
||
def _validate_token(self): | ||
actual_fields = {x for x in self._decoded.keys()} | ||
|
||
if len(self.REQUIRED_FIELDS - actual_fields) != 0: | ||
raise InvalidTokenSchemaErr(self.REQUIRED_FIELDS - actual_fields) |
Oops, something went wrong.