From 32fc3740e5dfb3e8ad64002088db8e5a27dff669 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 11 Dec 2024 14:42:58 +0200 Subject: [PATCH] Added core functionality --- redis/auth/__init__.py | 0 redis/auth/err.py | 27 ++ redis/auth/idp.py | 25 ++ redis/auth/token.py | 121 ++++++ redis/auth/token_manager.py | 342 ++++++++++++++++ tests/test_auth/__init__.py | 0 tests/test_auth/test_token.py | 73 ++++ tests/test_auth/test_token_manager.py | 551 ++++++++++++++++++++++++++ 8 files changed, 1139 insertions(+) create mode 100644 redis/auth/__init__.py create mode 100644 redis/auth/err.py create mode 100644 redis/auth/idp.py create mode 100644 redis/auth/token.py create mode 100644 redis/auth/token_manager.py create mode 100644 tests/test_auth/__init__.py create mode 100644 tests/test_auth/test_token.py create mode 100644 tests/test_auth/test_token_manager.py diff --git a/redis/auth/__init__.py b/redis/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/auth/err.py b/redis/auth/err.py new file mode 100644 index 0000000000..9c6a95bb05 --- /dev/null +++ b/redis/auth/err.py @@ -0,0 +1,27 @@ +from typing import Iterable + + +class RequestTokenErr(Exception): + """ + Represents an exception during token request. + """ + def __init__(self, *args): + super().__init__(*args) + + +class InvalidTokenSchemaErr(Exception): + """ + Represents an exception related to invalid token schema. + """ + def __init__(self, missing_fields: Iterable[str] = []): + super().__init__( + "Unexpected token schema. Following fields are missing: " + ", ".join(missing_fields) + ) + + +class TokenRenewalErr(Exception): + """ + Represents an exception during token renewal process. + """ + def __init__(self, *args): + super().__init__(*args) \ No newline at end of file diff --git a/redis/auth/idp.py b/redis/auth/idp.py new file mode 100644 index 0000000000..955cad76fb --- /dev/null +++ b/redis/auth/idp.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod +from redis.auth.token import TokenInterface + +""" +This interface is the facade of an identity provider +""" + + +class IdentityProviderInterface(ABC): + """ + Receive a token from the identity provider. Receiving a token only works when being authenticated. + """ + + @abstractmethod + def request_token(self, force_refresh=False) -> TokenInterface: + pass + + +class IdentityProviderConfigInterface(ABC): + """ + Configuration class that provides a configured identity provider. + """ + @abstractmethod + def get_provider(self) -> IdentityProviderInterface: + pass diff --git a/redis/auth/token.py b/redis/auth/token.py new file mode 100644 index 0000000000..be999c67ec --- /dev/null +++ b/redis/auth/token.py @@ -0,0 +1,121 @@ +from abc import ABC, abstractmethod + +import jwt +from datetime import datetime, timezone + +from redis.auth.err import InvalidTokenSchemaErr + + +class TokenInterface(ABC): + @abstractmethod + def is_expired(self) -> bool: + pass + + @abstractmethod + def ttl(self) -> float: + pass + + @abstractmethod + def try_get(self, key: str) -> str: + pass + + @abstractmethod + def get_value(self) -> str: + pass + + @abstractmethod + def get_expires_at_ms(self) -> float: + pass + + @abstractmethod + def get_received_at_ms(self) -> float: + pass + + +class TokenResponse: + def __init__(self, token: TokenInterface): + self._token = token + + def get_token(self) -> TokenInterface: + return self._token + + def get_ttl_ms(self) -> float: + return self._token.get_expires_at_ms() - self._token.get_received_at_ms() + + +class SimpleToken(TokenInterface): + def __init__(self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict) -> None: + self.value = value + self.expires_at = expires_at_ms + self.received_at = received_at_ms + self.claims = claims + + def ttl(self) -> float: + if self.expires_at == -1: + return -1 + + return self.expires_at - (datetime.now(timezone.utc).timestamp() * 1000) + + def is_expired(self) -> bool: + if self.expires_at == -1: + return False + + return self.ttl() <= 0 + + def try_get(self, key: str) -> str: + return self.claims.get(key) + + def get_value(self) -> str: + return self.value + + def get_expires_at_ms(self) -> float: + return self.expires_at + + def get_received_at_ms(self) -> float: + return self.received_at + + +class JWToken(TokenInterface): + + REQUIRED_FIELDS = {'exp'} + + def __init__(self, token: str): + self._value = token + self._decoded = jwt.decode( + self._value, + options={"verify_signature": False}, + algorithms=[jwt.get_unverified_header(self._value).get('alg')] + ) + self._validate_token() + + def is_expired(self) -> bool: + exp = self._decoded['exp'] + if exp == -1: + return False + + return self._decoded['exp'] * 1000 <= datetime.now(timezone.utc).timestamp() * 1000 + + def ttl(self) -> float: + exp = self._decoded['exp'] + if exp == -1: + return -1 + + return self._decoded['exp'] * 1000 - datetime.now(timezone.utc).timestamp() * 1000 + + def try_get(self, key: str) -> str: + return self._decoded.get(key) + + def get_value(self) -> str: + return self._value + + def get_expires_at_ms(self) -> float: + return float(self._decoded['exp'] * 1000) + + def get_received_at_ms(self) -> float: + return datetime.now(timezone.utc).timestamp() * 1000 + + def _validate_token(self): + actual_fields = {x for x in self._decoded.keys()} + + if len(self.REQUIRED_FIELDS - actual_fields) != 0: + raise InvalidTokenSchemaErr(self.REQUIRED_FIELDS - actual_fields) \ No newline at end of file diff --git a/redis/auth/token_manager.py b/redis/auth/token_manager.py new file mode 100644 index 0000000000..2720fa3e8b --- /dev/null +++ b/redis/auth/token_manager.py @@ -0,0 +1,342 @@ +import threading +from datetime import datetime, timezone +from time import sleep +from typing import Callable, Any, Awaitable, Coroutine, Union + +import asyncio + +from redis.auth.err import RequestTokenErr, TokenRenewalErr +from redis.auth.idp import IdentityProviderInterface +from redis.auth.token import TokenResponse + + +class CredentialsListener: + """ + Listeners that will be notified on events related to credentials. + Accepts callbacks and awaitable callbacks. + """ + def __init__(self): + self._on_next = None + self._on_error = None + + @property + def on_next(self) -> Union[Callable[[Any], None], Awaitable]: + return self._on_next + + @on_next.setter + def on_next(self, callback: Union[Callable[[Any], None], Awaitable]) -> None: + self._on_next = callback + + @property + def on_error(self) -> Union[Callable[[Exception], None], Awaitable]: + return self._on_error + + @on_error.setter + def on_error(self, callback: Union[Callable[[Exception], None], Awaitable]) -> None: + self._on_error = callback + + +class RetryPolicy: + def __init__(self, max_attempts: int, delay_in_ms: float): + self.max_attempts = max_attempts + self.delay_in_ms = delay_in_ms + + def get_max_attempts(self) -> int: + """ + Retry attempts before exception will be thrown. + + :return: int + """ + return self.max_attempts + + def get_delay_in_ms(self) -> float: + """ + Delay between retries in seconds. + + :return: int + """ + return self.delay_in_ms + + +class TokenManagerConfig: + def __init__( + self, + expiration_refresh_ratio: float, + lower_refresh_bound_millis: int, + token_request_execution_timeout_in_ms: int, + retry_policy: RetryPolicy, + ): + self._expiration_refresh_ratio = expiration_refresh_ratio + self._lower_refresh_bound_millis = lower_refresh_bound_millis + self._token_request_execution_timeout_in_ms = token_request_execution_timeout_in_ms + self._retry_policy = retry_policy + + def get_expiration_refresh_ratio(self) -> float: + """ + Represents the ratio of a token's lifetime at which a refresh should be triggered. + For example, a value of 0.75 means the token should be refreshed when 75% of its + lifetime has elapsed (or when 25% of its lifetime remains). + + :return: float + """ + + return self._expiration_refresh_ratio + + def get_lower_refresh_bound_millis(self) -> int: + """ + Represents the minimum time in milliseconds before token expiration to trigger a refresh, in milliseconds. + This value sets a fixed lower bound for when a token refresh should occur, regardless + of the token's total lifetime. + If set to 0 there will be no lower bound and the refresh will be triggered + based on the expirationRefreshRatio only. + + :return: int + """ + return self._lower_refresh_bound_millis + + def get_token_request_execution_timeout_in_ms(self) -> int: + """ + Represents the maximum time in milliseconds to wait for a token request to complete. + + :return: int + """ + return self._token_request_execution_timeout_in_ms + + def get_retry_policy(self) -> RetryPolicy: + """ + Represents the retry policy for token requests. + + :return: RetryPolicy + """ + return self._retry_policy + + +class TokenManager: + def __init__(self, identity_provider: IdentityProviderInterface, config: TokenManagerConfig): + self._idp = identity_provider + self._config = config + self._next_timer = None + self._listener = None + self._init_timer = None + self._retries = 0 + + def __del__(self): + self.stop() + + def start( + self, + listener: CredentialsListener, + skip_initial: bool = False, + ) -> Callable[[], None]: + self._listener = listener + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + # Run loop in a separate thread to unblock main thread. + loop = asyncio.new_event_loop() + thread = threading.Thread( + target=_start_event_loop_in_thread, + args=(loop,), + daemon=True + ) + thread.start() + + # Event to block for initial execution. + init_event = asyncio.Event() + self._init_timer = loop.call_later( + 0, + self._renew_token, + skip_initial, + init_event + ) + + # Blocks in thread-safe maner. + asyncio.run_coroutine_threadsafe(init_event.wait(), loop).result() + return self.stop + + async def start_async( + self, + listener: CredentialsListener, + block_for_initial: bool = False, + initial_delay_in_ms: float = 0, + skip_initial: bool = False, + ) -> Callable[[], Coroutine[Any, Any, None]]: + self._listener = listener + + loop = asyncio.get_running_loop() + init_event = asyncio.Event() + + # Wraps the async callback with async wrapper to schedule with loop.call_later() + wrapped = _async_to_sync_wrapper(loop, self._renew_token_async, skip_initial, init_event) + self._init_timer = loop.call_later(initial_delay_in_ms / 1000, wrapped) + + if block_for_initial: + await init_event.wait() + + return self.stop_async + + def stop(self): + if self._next_timer is not None: + self._next_timer.cancel() + + async def stop_async(self): + return self.stop() + + def acquire_token(self, force_refresh=False) -> TokenResponse: + try: + token = self._idp.request_token(force_refresh) + except RequestTokenErr as e: + if self._retries < self._config.get_retry_policy().get_max_attempts(): + self._retries += 1 + sleep(self._config.get_retry_policy().get_delay_in_ms() / 1000) + return self.acquire_token(force_refresh) + else: + raise e + + self._retries = 0 + return TokenResponse(token) + + async def acquire_token_async(self, force_refresh=False) -> TokenResponse: + try: + token = self._idp.request_token(force_refresh) + except RequestTokenErr as e: + if self._retries < self._config.get_retry_policy().get_max_attempts(): + self._retries += 1 + await asyncio.sleep(self._config.get_retry_policy().get_delay_in_ms() / 1000) + return await self.acquire_token_async(force_refresh) + else: + raise e + + self._retries = 0 + return TokenResponse(token) + + def _calculate_renewal_delay(self, expire_date: float, issue_date: float) -> float: + delay_for_lower_refresh = self._delay_for_lower_refresh(expire_date) + delay_for_ratio_refresh = self._delay_for_ratio_refresh(expire_date, issue_date) + delay = min(delay_for_ratio_refresh, delay_for_lower_refresh) + + return 0 if delay < 0 else delay / 1000 + + def _delay_for_lower_refresh(self, expire_date: float): + return (expire_date - self._config.get_lower_refresh_bound_millis() - + (datetime.now(timezone.utc).timestamp() * 1000)) + + def _delay_for_ratio_refresh(self, expire_date: float, issue_date: float): + token_ttl = expire_date - issue_date + refresh_before = token_ttl - (token_ttl * self._config.get_expiration_refresh_ratio()) + + return expire_date - refresh_before - (datetime.now(timezone.utc).timestamp() * 1000) + + def _renew_token(self, skip_initial: bool = False, init_event: asyncio.Event = None): + """ + Task to renew token from identity provider. + Schedules renewal tasks based on token TTL. + """ + + try: + token_res = self.acquire_token(force_refresh=True) + delay = self._calculate_renewal_delay( + token_res.get_token().get_expires_at_ms(), + token_res.get_token().get_received_at_ms() + ) + + if token_res.get_token().is_expired(): + raise TokenRenewalErr("Requested token is expired") + + if self._listener.on_next is None: + return + + if not skip_initial: + try: + self._listener.on_next(token_res.get_token()) + except Exception as e: + raise TokenRenewalErr(e) + + if delay <= 0: + return + + loop = asyncio.get_running_loop() + self._next_timer = loop.call_later(delay, self._renew_token) + return token_res + except Exception as e: + if self._listener.on_error is None: + raise e + + self._listener.on_error(e) + finally: + if init_event: + init_event.set() + + async def _renew_token_async(self, skip_initial: bool = False, init_event: asyncio.Event = None): + """ + Async task to renew tokens from identity provider. + Schedules renewal tasks based on token TTL. + """ + + try: + token_res = await self.acquire_token_async(force_refresh=True) + delay = self._calculate_renewal_delay( + token_res.get_token().get_expires_at_ms(), + token_res.get_token().get_received_at_ms() + ) + + if token_res.get_token().is_expired(): + raise TokenRenewalErr("Requested token is expired") + + if self._listener.on_next is None: + return + + if not skip_initial: + try: + await self._listener.on_next(token_res.get_token()) + except Exception as e: + raise TokenRenewalErr(e) + + if delay <= 0: + return + + loop = asyncio.get_running_loop() + wrapped = _async_to_sync_wrapper(loop, self._renew_token_async) + loop.call_later(delay, wrapped) + except Exception as e: + if init_event: + init_event.set() + + if self._listener.on_error is None: + raise e + + await self._listener.on_error(e) + finally: + if init_event: + init_event.set() + + +def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs): + """ + Wraps an asynchronous function so it can be used with loop.call_later. + + :param loop: The event loop in which the coroutine will be executed. + :param coro_func: The coroutine function to wrap. + :param args: Positional arguments to pass to the coroutine function. + :param kwargs: Keyword arguments to pass to the coroutine function. + :return: A regular function suitable for loop.call_later. + """ + + def wrapped(): + # Schedule the coroutine in the event loop + asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop) + + return wrapped + + +def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop): + """ + Starts event loop in a thread. + Used to be able to schedule tasks using loop.call_later. + + :param event_loop: + :return: + """ + asyncio.set_event_loop(event_loop) + event_loop.run_forever() \ No newline at end of file diff --git a/tests/test_auth/__init__.py b/tests/test_auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_auth/test_token.py b/tests/test_auth/test_token.py new file mode 100644 index 0000000000..ae397e6274 --- /dev/null +++ b/tests/test_auth/test_token.py @@ -0,0 +1,73 @@ +from datetime import datetime, timezone + +import jwt +import pytest + +from redis.auth.err import InvalidTokenSchemaErr +from redis.auth.token import SimpleToken, JWToken + + +class TestToken: + + def test_simple_token(self): + token = SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 1000, + (datetime.now(timezone.utc).timestamp() * 1000), + {"key": "value"} + ) + + assert token.ttl() == pytest.approx(1000, 10) + assert token.is_expired() is False + assert token.try_get('key') == "value" + assert token.get_value() == "value" + assert token.get_expires_at_ms() == pytest.approx((datetime.now(timezone.utc).timestamp() * 1000) + 100, 10) + assert token.get_received_at_ms() == pytest.approx((datetime.now(timezone.utc).timestamp() * 1000), 10) + + token = SimpleToken( + 'value', + -1, + (datetime.now(timezone.utc).timestamp() * 1000), + {"key": "value"} + ) + + assert token.ttl() == -1 + assert token.is_expired() is False + assert token.get_expires_at_ms() == -1 + + def test_jwt_token(self): + token = { + 'exp': datetime.now(timezone.utc).timestamp() + 100, + 'iat': datetime.now(timezone.utc).timestamp(), + 'key': "value" + } + encoded = jwt.encode(token, "secret", algorithm='HS256') + jwt_token = JWToken(encoded) + + assert jwt_token.ttl() == pytest.approx(100000, 10) + assert jwt_token.is_expired() is False + assert jwt_token.try_get('key') == "value" + assert jwt_token.get_value() == encoded + assert jwt_token.get_expires_at_ms() == pytest.approx( + (datetime.now(timezone.utc).timestamp() * 1000) + 100000, 10 + ) + assert jwt_token.get_received_at_ms() == pytest.approx((datetime.now(timezone.utc).timestamp() * 1000), 10) + + token = { + 'exp': -1, + 'iat': datetime.now(timezone.utc).timestamp(), + 'key': "value" + } + encoded = jwt.encode(token, "secret", algorithm='HS256') + jwt_token = JWToken(encoded) + + assert jwt_token.ttl() == -1 + assert jwt_token.is_expired() is False + assert jwt_token.get_expires_at_ms() == -1000 + + with pytest.raises(InvalidTokenSchemaErr): + token = { + 'key': "value" + } + encoded = jwt.encode(token, "secret", algorithm='HS256') + JWToken(encoded) \ No newline at end of file diff --git a/tests/test_auth/test_token_manager.py b/tests/test_auth/test_token_manager.py new file mode 100644 index 0000000000..b9d973a9a7 --- /dev/null +++ b/tests/test_auth/test_token_manager.py @@ -0,0 +1,551 @@ +from datetime import datetime, timezone +from time import sleep +from unittest.mock import Mock + +import asyncio +import pytest + +from redis.auth.err import RequestTokenErr, TokenRenewalErr +from redis.auth.idp import IdentityProviderInterface +from redis.auth.token_manager import ( + CredentialsListener, + TokenManagerConfig, + RetryPolicy, + TokenManager +) +from redis.auth.token import SimpleToken + + +class TestTokenManager: + @pytest.mark.parametrize( + "exp_refresh_ratio,tokens_refreshed", + [ + (0.9, 2), + (0.28, 4), + ], + ids=[ + "Refresh ratio = 0.9, 2 tokens in 0,1 second", + "Refresh ratio = 0.28, 4 tokens in 0,1 second", + ] + ) + def test_success_token_renewal(self, exp_refresh_ratio, tokens_refreshed): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'}), + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 130, + (datetime.now(timezone.utc).timestamp() * 1000) + 30, + {"oid": 'test'}), + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 160, + (datetime.now(timezone.utc).timestamp() * 1000) + 60, + {"oid": 'test'}), + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 190, + (datetime.now(timezone.utc).timestamp() * 1000) + 90, + {"oid": 'test'}), + ] + + def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(exp_refresh_ratio, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + mgr.start(mock_listener) + sleep(0.1) + + assert len(tokens) == tokens_refreshed + + @pytest.mark.parametrize( + "exp_refresh_ratio,tokens_refreshed", + [ + (0.9, 2), + (0.28, 4), + ], + ids=[ + "Refresh ratio = 0.9, 2 tokens in 0,1 second", + "Refresh ratio = 0.28, 4 tokens in 0,1 second", + ] + ) + @pytest.mark.asyncio + async def test_async_success_token_renewal(self, exp_refresh_ratio, tokens_refreshed): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'}), + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 130, + (datetime.now(timezone.utc).timestamp() * 1000) + 30, + {"oid": 'test'}), + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 160, + (datetime.now(timezone.utc).timestamp() * 1000) + 60, + {"oid": 'test'}), + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 190, + (datetime.now(timezone.utc).timestamp() * 1000) + 90, + {"oid": 'test'}), + ] + + async def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(exp_refresh_ratio, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, block_for_initial=True) + await asyncio.sleep(0.1) + + assert len(tokens) == tokens_refreshed + + @pytest.mark.parametrize( + "block_for_initial,tokens_acquired", + [ + (True, 1), + (False, 0), + ], + ids=[ + "Block for initial, callback will triggered once", + "Non blocked, callback wont be triggered", + ] + ) + @pytest.mark.asyncio + async def test_async_request_token_blocking_behaviour(self, block_for_initial, tokens_acquired): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.return_value = SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'} + ) + + async def on_next(token): + nonlocal tokens + sleep(0.1) + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, block_for_initial=block_for_initial) + + assert len(tokens) == tokens_acquired + + def test_token_renewal_with_skip_initial(self): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'} + ), + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 120, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'} + ), + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 140, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'} + ), + ] + + def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(3, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + mgr.start(mock_listener, skip_initial=True) + # Should be less than a 0.1, or it will be flacky due to additional token renewal. + sleep(0.2) + + assert len(tokens) == 2 + + @pytest.mark.asyncio + async def test_async_token_renewal_with_skip_initial(self): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'} + ), + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 120, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'} + ), + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 140, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'} + ), + ] + + async def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(3, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, skip_initial=True) + # Should be less than a 0.1, or it will be flacky due to additional token renewal. + await asyncio.sleep(0.2) + + assert len(tokens) == 2 + + def test_success_token_renewal_with_retry(self): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + RequestTokenErr, + RequestTokenErr, + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'} + ), + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'} + ) + ] + + def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(3, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + mgr.start(mock_listener) + # Should be less than a 0.1, or it will be flacky due to additional token renewal. + sleep(0.08) + + assert mock_provider.request_token.call_count in {3, 4} + assert len(tokens) == 1 + + @pytest.mark.asyncio + async def test_async_success_token_renewal_with_retry(self): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + RequestTokenErr, + RequestTokenErr, + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'} + ), + SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'} + ) + ] + + async def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + mock_listener.on_error = None + + retry_policy = RetryPolicy(3, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, block_for_initial=True) + # Should be less than a 0.1, or it will be flacky due to additional token renewal. + await asyncio.sleep(0.08) + + assert mock_provider.request_token.call_count in {3, 4} + assert len(tokens) == 1 + + def test_no_token_renewal_on_process_complete(self): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.return_value = SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 1000, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'} + ) + + def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(0.9, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + mgr.start(mock_listener) + sleep(0.2) + + assert len(tokens) == 1 + + @pytest.mark.asyncio + async def test_async_no_token_renewal_on_process_complete(self): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.return_value = SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 1000, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'} + ) + + async def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(0.9, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, block_for_initial=True) + await asyncio.sleep(0.2) + + assert len(tokens) == 1 + + def test_failed_token_renewal_with_retry(self): + tokens = [] + exceptions = [] + + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + RequestTokenErr, + RequestTokenErr, + RequestTokenErr, + RequestTokenErr, + ] + + def on_next(token): + nonlocal tokens + tokens.append(token) + + def on_error(exception): + nonlocal exceptions + exceptions.append(exception) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + mock_listener.on_error = on_error + + retry_policy = RetryPolicy(3, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + mgr.start(mock_listener) + sleep(0.1) + + assert mock_provider.request_token.call_count == 4 + assert len(tokens) == 0 + assert len(exceptions) == 1 + + @pytest.mark.asyncio + async def test_async_failed_token_renewal_with_retry(self): + tokens = [] + exceptions = [] + + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + RequestTokenErr, + RequestTokenErr, + RequestTokenErr, + RequestTokenErr, + ] + + async def on_next(token): + nonlocal tokens + tokens.append(token) + + async def on_error(exception): + nonlocal exceptions + exceptions.append(exception) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + mock_listener.on_error = on_error + + retry_policy = RetryPolicy(3, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, block_for_initial=True) + sleep(0.1) + + assert mock_provider.request_token.call_count == 4 + assert len(tokens) == 0 + assert len(exceptions) == 1 + + def test_failed_renewal_on_expired_token(self): + errors = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.return_value = SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) - 100, + (datetime.now(timezone.utc).timestamp() * 1000) - 1000, + {"oid": 'test'} + ) + + def on_error(error: TokenRenewalErr): + nonlocal errors + errors.append(error) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_error = on_error + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + mgr.start(mock_listener) + + assert len(errors) == 1 + assert isinstance(errors[0], TokenRenewalErr) + assert str(errors[0]) == "Requested token is expired" + + @pytest.mark.asyncio + async def test_async_failed_renewal_on_expired_token(self): + errors = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.return_value = SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) - 100, + (datetime.now(timezone.utc).timestamp() * 1000) - 1000, + {"oid": 'test'} + ) + + async def on_error(error: TokenRenewalErr): + nonlocal errors + errors.append(error) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_error = on_error + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, block_for_initial=True) + + assert len(errors) == 1 + assert isinstance(errors[0], TokenRenewalErr) + assert str(errors[0]) == "Requested token is expired" + + def test_failed_renewal_on_callback_error(self): + errors = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.return_value = SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 1000, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'} + ) + + def on_next(token): + raise Exception("Some exception") + + def on_error(error): + nonlocal errors + errors.append(error) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + mock_listener.on_error = on_error + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + mgr.start(mock_listener) + + assert len(errors) == 1 + assert isinstance(errors[0], TokenRenewalErr) + assert str(errors[0]) == "Some exception" + + @pytest.mark.asyncio + async def test_async_failed_renewal_on_callback_error(self): + errors = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.return_value = SimpleToken( + 'value', + (datetime.now(timezone.utc).timestamp() * 1000) + 1000, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": 'test'} + ) + + async def on_next(token): + raise Exception("Some exception") + + async def on_error(error): + nonlocal errors + errors.append(error) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + mock_listener.on_error = on_error + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, block_for_initial=True) + + assert len(errors) == 1 + assert isinstance(errors[0], TokenRenewalErr) + assert str(errors[0]) == "Some exception" \ No newline at end of file