diff --git a/.github/actions/run-tests/action.yml b/.github/actions/run-tests/action.yml index 64d31969bd..5ca6bf5a09 100644 --- a/.github/actions/run-tests/action.yml +++ b/.github/actions/run-tests/action.yml @@ -103,7 +103,7 @@ runs: if (( $REDIS_MAJOR_VERSION < 7 )) && [ "$protocol" == "3" ]; then echo "Skipping module tests: Modules doesn't support RESP3 for Redis versions < 7" - invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}" --extra-markers="not redismod" + invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}" --extra-markers="not redismod and not cp_integration" else invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}" fi diff --git a/dev_requirements.txt b/dev_requirements.txt index adfa99e80c..945afc35dc 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -16,3 +16,4 @@ uvloop vulture>=2.3.0 wheel>=0.30.0 numpy>=1.24.0 +redispy-entraid-credentials @ git+https://github.com/redis-developer/redispy-entra-credentials.git/@main diff --git a/pytest.ini b/pytest.ini index bbb8d420c4..68fee2b603 100644 --- a/pytest.ini +++ b/pytest.ini @@ -10,6 +10,7 @@ markers = asyncio: marker for async tests replica: replica tests experimental: run only experimental tests + cp_integration: credential provider integration tests asyncio_mode = auto timeout = 30 filterwarnings = diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 9508849703..9478d539d7 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -53,6 +53,13 @@ list_or_args, ) from redis.credentials import CredentialProvider +from redis.event import ( + AfterPooledConnectionsInstantiationEvent, + AfterPubSubConnectionInstantiationEvent, + AfterSingleConnectionInstantiationEvent, + ClientType, + EventDispatcher, +) from redis.exceptions import ( ConnectionError, ExecAbortError, @@ -233,6 +240,7 @@ def __init__( redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, + event_dispatcher: Optional[EventDispatcher] = None, ): """ Initialize a new Redis client. @@ -242,6 +250,10 @@ def __init__( To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. """ kwargs: Dict[str, Any] + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher # auto_close_connection_pool only has an effect if connection_pool is # None. It is assumed that if connection_pool is not None, the user # wants to manage the connection pool themselves. 
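For context on the constructor change above: a minimal sketch of how the new `event_dispatcher` argument on the async client could be exercised. The host/port values and the idea of explicitly sharing one `EventDispatcher` between clients are illustrative assumptions; omitting the argument falls back to the internally created dispatcher, exactly as the `__init__` code in this hunk shows.

```python
# Sketch only: exercises the event_dispatcher keyword added in this diff.
import asyncio

import redis.asyncio as redis
from redis.event import EventDispatcher


async def main() -> None:
    # Passing a dispatcher explicitly is optional; event_dispatcher=None (the
    # default) makes the client create its own EventDispatcher internally.
    dispatcher = EventDispatcher()
    client = redis.Redis(host="localhost", port=6379, event_dispatcher=dispatcher)
    print(await client.ping())
    await client.aclose()


asyncio.run(main())
```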
@@ -320,9 +332,19 @@ def __init__( # This arg only used if no pool is passed in self.auto_close_connection_pool = auto_close_connection_pool connection_pool = ConnectionPool(**kwargs) + self._event_dispatcher.dispatch( + AfterPooledConnectionsInstantiationEvent( + [connection_pool], ClientType.ASYNC, credential_provider + ) + ) else: # If a pool is passed in, do not close it self.auto_close_connection_pool = False + self._event_dispatcher.dispatch( + AfterPooledConnectionsInstantiationEvent( + [connection_pool], ClientType.ASYNC, credential_provider + ) + ) self.connection_pool = connection_pool self.single_connection_client = single_connection_client @@ -354,6 +376,12 @@ async def initialize(self: _RedisT) -> _RedisT: async with self._single_conn_lock: if self.connection is None: self.connection = await self.connection_pool.get_connection("_") + + self._event_dispatcher.dispatch( + AfterSingleConnectionInstantiationEvent( + self.connection, ClientType.ASYNC, self._single_conn_lock + ) + ) return self def set_response_callback(self, command: str, callback: ResponseCallbackT): @@ -521,7 +549,9 @@ def pubsub(self, **kwargs) -> "PubSub": subscribe to channels and listen for messages that get published to them. """ - return PubSub(self.connection_pool, **kwargs) + return PubSub( + self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs + ) def monitor(self) -> "Monitor": return Monitor(self.connection_pool) @@ -759,7 +789,12 @@ def __init__( ignore_subscribe_messages: bool = False, encoder=None, push_handler_func: Optional[Callable] = None, + event_dispatcher: Optional["EventDispatcher"] = None, ): + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher self.connection_pool = connection_pool self.shard_hint = shard_hint self.ignore_subscribe_messages = ignore_subscribe_messages @@ -876,6 +911,12 @@ async def connect(self): if self.push_handler_func is not None and not HIREDIS_AVAILABLE: self.connection._parser.set_pubsub_push_handler(self.push_handler_func) + self._event_dispatcher.dispatch( + AfterPubSubConnectionInstantiationEvent( + self.connection, self.connection_pool, ClientType.ASYNC, self._lock + ) + ) + async def _disconnect_raise_connect(self, conn, error): """ Close the connection and raise an exception diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 4e82e5448f..408fa19363 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -29,6 +29,7 @@ from redis.asyncio.connection import Connection, DefaultParser, SSLConnection, parse_url from redis.asyncio.lock import Lock from redis.asyncio.retry import Retry +from redis.auth.token import TokenInterface from redis.backoff import default_backoff from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis from redis.cluster import ( @@ -45,6 +46,7 @@ from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.credentials import CredentialProvider +from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher from redis.exceptions import ( AskError, BusyLoadingError, @@ -57,6 +59,7 @@ MaxConnectionsError, MovedError, RedisClusterException, + RedisError, ResponseError, SlotNotCoveredError, TimeoutError, @@ -270,6 +273,7 @@ def __init__( ssl_ciphers: Optional[str] = None, protocol: Optional[int] = 2, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, + event_dispatcher: Optional[EventDispatcher] 
= None, ) -> None: if db: raise RedisClusterException( @@ -366,11 +370,17 @@ def __init__( if host and port: startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher + self.nodes_manager = NodesManager( startup_nodes, require_full_coverage, kwargs, address_remap=address_remap, + event_dispatcher=self._event_dispatcher, ) self.encoder = Encoder(encoding, encoding_errors, decode_responses) self.read_from_replicas = read_from_replicas @@ -929,6 +939,8 @@ class ClusterNode: __slots__ = ( "_connections", "_free", + "_lock", + "_event_dispatcher", "connection_class", "connection_kwargs", "host", @@ -966,6 +978,9 @@ def __init__( self._connections: List[Connection] = [] self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections) + self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) + if self._event_dispatcher is None: + self._event_dispatcher = EventDispatcher() def __repr__(self) -> str: return ( @@ -1082,10 +1097,38 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: return ret + async def re_auth_callback(self, token: TokenInterface): + tmp_queue = collections.deque() + while self._free: + conn = self._free.popleft() + await conn.retry.call_with_retry( + lambda: conn.send_command( + "AUTH", token.try_get("oid"), token.get_value() + ), + lambda error: self._mock(error), + ) + await conn.retry.call_with_retry( + lambda: conn.read_response(), lambda error: self._mock(error) + ) + tmp_queue.append(conn) + + while tmp_queue: + conn = tmp_queue.popleft() + self._free.append(conn) + + async def _mock(self, error: RedisError): + """ + Dummy functions, needs to be passed as error callback to retry object. 
+ :param error: + :return: + """ + pass + class NodesManager: __slots__ = ( "_moved_exception", + "_event_dispatcher", "connection_kwargs", "default_node", "nodes_cache", @@ -1102,6 +1145,7 @@ def __init__( require_full_coverage: bool, connection_kwargs: Dict[str, Any], address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, + event_dispatcher: Optional[EventDispatcher] = None, ) -> None: self.startup_nodes = {node.name: node for node in startup_nodes} self.require_full_coverage = require_full_coverage @@ -1113,6 +1157,10 @@ def __init__( self.slots_cache: Dict[int, List["ClusterNode"]] = {} self.read_load_balancer = LoadBalancer() self._moved_exception: MovedError = None + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher def get_node( self, @@ -1230,6 +1278,12 @@ async def initialize(self) -> None: try: # Make sure cluster mode is enabled on this node try: + self._event_dispatcher.dispatch( + AfterAsyncClusterInstantiationEvent( + self.nodes_cache, + self.connection_kwargs.get("credential_provider", None), + ) + ) cluster_slots = await startup_node.execute_command("CLUSTER SLOTS") except ResponseError: raise RedisClusterException( diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 8c3123ac04..4a743ff374 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -27,6 +27,8 @@ ) from urllib.parse import ParseResult, parse_qs, unquote, urlparse +from ..auth.token import TokenInterface +from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher from ..utils import format_error_message # the functionality is available in 3.11.x but has a major issue before @@ -148,6 +150,7 @@ def __init__( encoder_class: Type[Encoder] = Encoder, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, + event_dispatcher: Optional[EventDispatcher] = None, ): if (username or password) and credential_provider is not None: raise DataError( @@ -156,6 +159,10 @@ def __init__( "1. 'password' and (optional) 'username'\n" "2. 
'credential_provider'" ) + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher self.db = db self.client_name = client_name self.lib_name = lib_name @@ -195,6 +202,8 @@ def __init__( self.set_parser(parser_class) self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = [] self._buffer_cutoff = 6000 + self._re_auth_token: Optional[TokenInterface] = None + try: p = int(protocol) except TypeError: @@ -327,6 +336,9 @@ def _host_error(self) -> str: def _error_message(self, exception: BaseException) -> str: return format_error_message(self._host_error(), exception) + def get_protocol(self): + return self.protocol + async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" self._parser.on_connect(self) @@ -339,7 +351,8 @@ async def on_connect(self) -> None: self.credential_provider or UsernamePasswordCredentialProvider(self.username, self.password) ) - auth_args = cred_provider.get_credentials() + auth_args = await cred_provider.get_credentials_async() + # if resp version is specified and we have auth args, # we need to send them via HELLO if auth_args and self.protocol not in [2, "2"]: @@ -661,6 +674,19 @@ async def process_invalidation_messages(self): while not self._socket_is_empty(): await self.read_response(push_request=True) + def set_re_auth_token(self, token: TokenInterface): + self._re_auth_token = token + + async def re_auth(self): + if self._re_auth_token is not None: + await self.send_command( + "AUTH", + self._re_auth_token.try_get("oid"), + self._re_auth_token.get_value(), + ) + await self.read_response() + self._re_auth_token = None + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -1039,6 +1065,10 @@ def __init__( self._available_connections: List[AbstractConnection] = [] self._in_use_connections: Set[AbstractConnection] = set() self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) + self._lock = asyncio.Lock() + self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) + if self._event_dispatcher is None: + self._event_dispatcher = EventDispatcher() def __repr__(self): return ( @@ -1058,13 +1088,14 @@ def can_get_connection(self) -> bool: ) async def get_connection(self, command_name, *keys, **options): - """Get a connected connection from the pool""" - connection = self.get_available_connection() - try: - await self.ensure_connection(connection) - except BaseException: - await self.release(connection) - raise + async with self._lock: + """Get a connected connection from the pool""" + connection = self.get_available_connection() + try: + await self.ensure_connection(connection) + except BaseException: + await self.release(connection) + raise return connection @@ -1114,6 +1145,9 @@ async def release(self, connection: AbstractConnection): # not doing so is an error that will cause an exception here. 
self._in_use_connections.remove(connection) self._available_connections.append(connection) + await self._event_dispatcher.dispatch_async( + AsyncAfterConnectionReleasedEvent(connection) + ) async def disconnect(self, inuse_connections: bool = True): """ @@ -1147,6 +1181,29 @@ def set_retry(self, retry: "Retry") -> None: for conn in self._in_use_connections: conn.retry = retry + async def re_auth_callback(self, token: TokenInterface): + async with self._lock: + for conn in self._available_connections: + await conn.retry.call_with_retry( + lambda: conn.send_command( + "AUTH", token.try_get("oid"), token.get_value() + ), + lambda error: self._mock(error), + ) + await conn.retry.call_with_retry( + lambda: conn.read_response(), lambda error: self._mock(error) + ) + for conn in self._in_use_connections: + conn.set_re_auth_token(token) + + async def _mock(self, error: RedisError): + """ + Dummy functions, needs to be passed as error callback to retry object. + :param error: + :return: + """ + pass + class BlockingConnectionPool(ConnectionPool): """ diff --git a/redis/auth/__init__.py b/redis/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/auth/err.py b/redis/auth/err.py new file mode 100644 index 0000000000..743dab18fe --- /dev/null +++ b/redis/auth/err.py @@ -0,0 +1,31 @@ +from typing import Iterable + + +class RequestTokenErr(Exception): + """ + Represents an exception during token request. + """ + + def __init__(self, *args): + super().__init__(*args) + + +class InvalidTokenSchemaErr(Exception): + """ + Represents an exception related to invalid token schema. + """ + + def __init__(self, missing_fields: Iterable[str] = []): + super().__init__( + "Unexpected token schema. Following fields are missing: " + + ", ".join(missing_fields) + ) + + +class TokenRenewalErr(Exception): + """ + Represents an exception during token renewal process. + """ + + def __init__(self, *args): + super().__init__(*args) diff --git a/redis/auth/idp.py b/redis/auth/idp.py new file mode 100644 index 0000000000..0951d95641 --- /dev/null +++ b/redis/auth/idp.py @@ -0,0 +1,28 @@ +from abc import ABC, abstractmethod + +from redis.auth.token import TokenInterface + +""" +This interface is the facade of an identity provider +""" + + +class IdentityProviderInterface(ABC): + """ + Receive a token from the identity provider. + Receiving a token only works when being authenticated. + """ + + @abstractmethod + def request_token(self, force_refresh=False) -> TokenInterface: + pass + + +class IdentityProviderConfigInterface(ABC): + """ + Configuration class that provides a configured identity provider. 
+ """ + + @abstractmethod + def get_provider(self) -> IdentityProviderInterface: + pass diff --git a/redis/auth/token.py b/redis/auth/token.py new file mode 100644 index 0000000000..876e95c4fa --- /dev/null +++ b/redis/auth/token.py @@ -0,0 +1,126 @@ +from abc import ABC, abstractmethod +from datetime import datetime, timezone + +import jwt +from redis.auth.err import InvalidTokenSchemaErr + + +class TokenInterface(ABC): + @abstractmethod + def is_expired(self) -> bool: + pass + + @abstractmethod + def ttl(self) -> float: + pass + + @abstractmethod + def try_get(self, key: str) -> str: + pass + + @abstractmethod + def get_value(self) -> str: + pass + + @abstractmethod + def get_expires_at_ms(self) -> float: + pass + + @abstractmethod + def get_received_at_ms(self) -> float: + pass + + +class TokenResponse: + def __init__(self, token: TokenInterface): + self._token = token + + def get_token(self) -> TokenInterface: + return self._token + + def get_ttl_ms(self) -> float: + return self._token.get_expires_at_ms() - self._token.get_received_at_ms() + + +class SimpleToken(TokenInterface): + def __init__( + self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict + ) -> None: + self.value = value + self.expires_at = expires_at_ms + self.received_at = received_at_ms + self.claims = claims + + def ttl(self) -> float: + if self.expires_at == -1: + return -1 + + return self.expires_at - (datetime.now(timezone.utc).timestamp() * 1000) + + def is_expired(self) -> bool: + if self.expires_at == -1: + return False + + return self.ttl() <= 0 + + def try_get(self, key: str) -> str: + return self.claims.get(key) + + def get_value(self) -> str: + return self.value + + def get_expires_at_ms(self) -> float: + return self.expires_at + + def get_received_at_ms(self) -> float: + return self.received_at + + +class JWToken(TokenInterface): + + REQUIRED_FIELDS = {"exp"} + + def __init__(self, token: str): + self._value = token + self._decoded = jwt.decode( + self._value, + options={"verify_signature": False}, + algorithms=[jwt.get_unverified_header(self._value).get("alg")], + ) + self._validate_token() + + def is_expired(self) -> bool: + exp = self._decoded["exp"] + if exp == -1: + return False + + return ( + self._decoded["exp"] * 1000 <= datetime.now(timezone.utc).timestamp() * 1000 + ) + + def ttl(self) -> float: + exp = self._decoded["exp"] + if exp == -1: + return -1 + + return ( + self._decoded["exp"] * 1000 - datetime.now(timezone.utc).timestamp() * 1000 + ) + + def try_get(self, key: str) -> str: + return self._decoded.get(key) + + def get_value(self) -> str: + return self._value + + def get_expires_at_ms(self) -> float: + return float(self._decoded["exp"] * 1000) + + def get_received_at_ms(self) -> float: + return datetime.now(timezone.utc).timestamp() * 1000 + + def _validate_token(self): + actual_fields = {x for x in self._decoded.keys()} + + if len(self.REQUIRED_FIELDS - actual_fields) != 0: + raise InvalidTokenSchemaErr(self.REQUIRED_FIELDS - actual_fields) diff --git a/redis/auth/token_manager.py b/redis/auth/token_manager.py new file mode 100644 index 0000000000..dd8d16233d --- /dev/null +++ b/redis/auth/token_manager.py @@ -0,0 +1,370 @@ +import asyncio +import logging +import threading +from datetime import datetime, timezone +from time import sleep +from typing import Any, Awaitable, Callable, Union + +from redis.auth.err import RequestTokenErr, TokenRenewalErr +from redis.auth.idp import IdentityProviderInterface +from redis.auth.token import TokenResponse + +logger = 
logging.getLogger(__name__) + + +class CredentialsListener: + """ + Listeners that will be notified on events related to credentials. + Accepts callbacks and awaitable callbacks. + """ + + def __init__(self): + self._on_next = None + self._on_error = None + + @property + def on_next(self) -> Union[Callable[[Any], None], Awaitable]: + return self._on_next + + @on_next.setter + def on_next(self, callback: Union[Callable[[Any], None], Awaitable]) -> None: + self._on_next = callback + + @property + def on_error(self) -> Union[Callable[[Exception], None], Awaitable]: + return self._on_error + + @on_error.setter + def on_error(self, callback: Union[Callable[[Exception], None], Awaitable]) -> None: + self._on_error = callback + + +class RetryPolicy: + def __init__(self, max_attempts: int, delay_in_ms: float): + self.max_attempts = max_attempts + self.delay_in_ms = delay_in_ms + + def get_max_attempts(self) -> int: + """ + Retry attempts before exception will be thrown. + + :return: int + """ + return self.max_attempts + + def get_delay_in_ms(self) -> float: + """ + Delay between retries in seconds. + + :return: int + """ + return self.delay_in_ms + + +class TokenManagerConfig: + def __init__( + self, + expiration_refresh_ratio: float, + lower_refresh_bound_millis: int, + token_request_execution_timeout_in_ms: int, + retry_policy: RetryPolicy, + ): + self._expiration_refresh_ratio = expiration_refresh_ratio + self._lower_refresh_bound_millis = lower_refresh_bound_millis + self._token_request_execution_timeout_in_ms = ( + token_request_execution_timeout_in_ms + ) + self._retry_policy = retry_policy + + def get_expiration_refresh_ratio(self) -> float: + """ + Represents the ratio of a token's lifetime at which a refresh should be triggered. # noqa: E501 + For example, a value of 0.75 means the token should be refreshed + when 75% of its lifetime has elapsed (or when 25% of its lifetime remains). + + :return: float + """ + + return self._expiration_refresh_ratio + + def get_lower_refresh_bound_millis(self) -> int: + """ + Represents the minimum time in milliseconds before token expiration + to trigger a refresh, in milliseconds. + This value sets a fixed lower bound for when a token refresh should occur, + regardless of the token's total lifetime. + If set to 0 there will be no lower bound and the refresh will be triggered + based on the expirationRefreshRatio only. + + :return: int + """ + return self._lower_refresh_bound_millis + + def get_token_request_execution_timeout_in_ms(self) -> int: + """ + Represents the maximum time in milliseconds to wait + for a token request to complete. + + :return: int + """ + return self._token_request_execution_timeout_in_ms + + def get_retry_policy(self) -> RetryPolicy: + """ + Represents the retry policy for token requests. + + :return: RetryPolicy + """ + return self._retry_policy + + +class TokenManager: + def __init__( + self, identity_provider: IdentityProviderInterface, config: TokenManagerConfig + ): + self._idp = identity_provider + self._config = config + self._next_timer = None + self._listener = None + self._init_timer = None + self._retries = 0 + + def __del__(self): + logger.info("Token manager are disposed") + self.stop() + + def start( + self, + listener: CredentialsListener, + skip_initial: bool = False, + ) -> Callable[[], None]: + self._listener = listener + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + # Run loop in a separate thread to unblock main thread. 
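To make the moving parts of `token_manager.py` concrete, here is a sketch that wires `TokenManager` to a toy identity provider. `FakeProvider`, the one-hour expiry, and the printed callbacks are illustrative assumptions; in the rest of this pull request the listener callbacks are what drive connection re-authentication. The `TokenManager`, `TokenManagerConfig`, `RetryPolicy`, `CredentialsListener`, and `SimpleToken` APIs are the ones introduced in this change.

```python
from datetime import datetime, timezone

from redis.auth.idp import IdentityProviderInterface
from redis.auth.token import SimpleToken, TokenInterface
from redis.auth.token_manager import (
    CredentialsListener,
    RetryPolicy,
    TokenManager,
    TokenManagerConfig,
)


class FakeProvider(IdentityProviderInterface):
    def request_token(self, force_refresh=False) -> TokenInterface:
        now_ms = datetime.now(timezone.utc).timestamp() * 1000
        return SimpleToken(
            value="fake-token",
            expires_at_ms=now_ms + 3_600_000,   # pretend the token lives one hour
            received_at_ms=now_ms,
            claims={"oid": "fake-user"},
        )


config = TokenManagerConfig(
    expiration_refresh_ratio=0.8,               # refresh at 80% of the token TTL
    lower_refresh_bound_millis=0,               # no fixed lower bound
    token_request_execution_timeout_in_ms=1000,
    retry_policy=RetryPolicy(max_attempts=3, delay_in_ms=100),
)

listener = CredentialsListener()
listener.on_next = lambda token: print("new token:", token.get_value())
listener.on_error = lambda err: print("renewal failed:", err)

manager = TokenManager(FakeProvider(), config)
stop = manager.start(listener)                  # blocks until the first token arrives
stop()                                          # cancels the scheduled renewals
```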
+ loop = asyncio.new_event_loop() + thread = threading.Thread( + target=_start_event_loop_in_thread, args=(loop,), daemon=True + ) + thread.start() + + # Event to block for initial execution. + init_event = asyncio.Event() + self._init_timer = loop.call_later( + 0, self._renew_token, skip_initial, init_event + ) + logger.info("Token manager started") + + # Blocks in thread-safe manner. + asyncio.run_coroutine_threadsafe(init_event.wait(), loop).result() + return self.stop + + async def start_async( + self, + listener: CredentialsListener, + block_for_initial: bool = False, + initial_delay_in_ms: float = 0, + skip_initial: bool = False, + ) -> Callable[[], None]: + self._listener = listener + + loop = asyncio.get_running_loop() + init_event = asyncio.Event() + + # Wraps the async callback with async wrapper to schedule with loop.call_later() + wrapped = _async_to_sync_wrapper( + loop, self._renew_token_async, skip_initial, init_event + ) + self._init_timer = loop.call_later(initial_delay_in_ms / 1000, wrapped) + logger.info("Token manager started") + + if block_for_initial: + await init_event.wait() + + return self.stop + + def stop(self): + if self._init_timer is not None: + self._init_timer.cancel() + if self._next_timer is not None: + self._next_timer.cancel() + + def acquire_token(self, force_refresh=False) -> TokenResponse: + try: + token = self._idp.request_token(force_refresh) + except RequestTokenErr as e: + if self._retries < self._config.get_retry_policy().get_max_attempts(): + self._retries += 1 + sleep(self._config.get_retry_policy().get_delay_in_ms() / 1000) + return self.acquire_token(force_refresh) + else: + raise e + + self._retries = 0 + return TokenResponse(token) + + async def acquire_token_async(self, force_refresh=False) -> TokenResponse: + try: + token = self._idp.request_token(force_refresh) + except RequestTokenErr as e: + if self._retries < self._config.get_retry_policy().get_max_attempts(): + self._retries += 1 + await asyncio.sleep( + self._config.get_retry_policy().get_delay_in_ms() / 1000 + ) + return await self.acquire_token_async(force_refresh) + else: + raise e + + self._retries = 0 + return TokenResponse(token) + + def _calculate_renewal_delay(self, expire_date: float, issue_date: float) -> float: + delay_for_lower_refresh = self._delay_for_lower_refresh(expire_date) + delay_for_ratio_refresh = self._delay_for_ratio_refresh(expire_date, issue_date) + delay = min(delay_for_ratio_refresh, delay_for_lower_refresh) + + return 0 if delay < 0 else delay / 1000 + + def _delay_for_lower_refresh(self, expire_date: float): + return ( + expire_date + - self._config.get_lower_refresh_bound_millis() + - (datetime.now(timezone.utc).timestamp() * 1000) + ) + + def _delay_for_ratio_refresh(self, expire_date: float, issue_date: float): + token_ttl = expire_date - issue_date + refresh_before = token_ttl - ( + token_ttl * self._config.get_expiration_refresh_ratio() + ) + + return ( + expire_date + - refresh_before + - (datetime.now(timezone.utc).timestamp() * 1000) + ) + + def _renew_token( + self, skip_initial: bool = False, init_event: asyncio.Event = None + ): + """ + Task to renew token from identity provider. + Schedules renewal tasks based on token TTL. 
+ """ + + try: + token_res = self.acquire_token(force_refresh=True) + delay = self._calculate_renewal_delay( + token_res.get_token().get_expires_at_ms(), + token_res.get_token().get_received_at_ms(), + ) + + if token_res.get_token().is_expired(): + raise TokenRenewalErr("Requested token is expired") + + if self._listener.on_next is None: + logger.warning( + "No registered callback for token renewal task. Renewal cancelled" + ) + return + + if not skip_initial: + try: + self._listener.on_next(token_res.get_token()) + except Exception as e: + raise TokenRenewalErr(e) + + if delay <= 0: + return + + loop = asyncio.get_running_loop() + self._next_timer = loop.call_later(delay, self._renew_token) + logger.info(f"Next token renewal scheduled in {delay} seconds") + return token_res + except Exception as e: + if self._listener.on_error is None: + raise e + + self._listener.on_error(e) + finally: + if init_event: + init_event.set() + + async def _renew_token_async( + self, skip_initial: bool = False, init_event: asyncio.Event = None + ): + """ + Async task to renew tokens from identity provider. + Schedules renewal tasks based on token TTL. + """ + + try: + token_res = await self.acquire_token_async(force_refresh=True) + delay = self._calculate_renewal_delay( + token_res.get_token().get_expires_at_ms(), + token_res.get_token().get_received_at_ms(), + ) + + if token_res.get_token().is_expired(): + raise TokenRenewalErr("Requested token is expired") + + if self._listener.on_next is None: + logger.warning( + "No registered callback for token renewal task. Renewal cancelled" + ) + return + + if not skip_initial: + try: + await self._listener.on_next(token_res.get_token()) + except Exception as e: + raise TokenRenewalErr(e) + + if delay <= 0: + return + + loop = asyncio.get_running_loop() + wrapped = _async_to_sync_wrapper(loop, self._renew_token_async) + logger.info(f"Next token renewal scheduled in {delay} seconds") + loop.call_later(delay, wrapped) + except Exception as e: + if self._listener.on_error is None: + raise e + + await self._listener.on_error(e) + finally: + if init_event: + init_event.set() + + +def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs): + """ + Wraps an asynchronous function so it can be used with loop.call_later. + + :param loop: The event loop in which the coroutine will be executed. + :param coro_func: The coroutine function to wrap. + :param args: Positional arguments to pass to the coroutine function. + :param kwargs: Keyword arguments to pass to the coroutine function. + :return: A regular function suitable for loop.call_later. + """ + + def wrapped(): + # Schedule the coroutine in the event loop + asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop) + + return wrapped + + +def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop): + """ + Starts event loop in a thread. + Used to be able to schedule tasks using loop.call_later. 
+ + :param event_loop: + :return: + """ + asyncio.set_event_loop(event_loop) + event_loop.run_forever() diff --git a/redis/client.py b/redis/client.py index bf3432e7eb..a7c1364a10 100755 --- a/redis/client.py +++ b/redis/client.py @@ -27,6 +27,13 @@ UnixDomainSocketConnection, ) from redis.credentials import CredentialProvider +from redis.event import ( + AfterPooledConnectionsInstantiationEvent, + AfterPubSubConnectionInstantiationEvent, + AfterSingleConnectionInstantiationEvent, + ClientType, + EventDispatcher, +) from redis.exceptions import ( ConnectionError, ExecAbortError, @@ -213,6 +220,7 @@ def __init__( protocol: Optional[int] = 2, cache: Optional[CacheInterface] = None, cache_config: Optional[CacheConfig] = None, + event_dispatcher: Optional[EventDispatcher] = None, ) -> None: """ Initialize a new Redis client. @@ -227,6 +235,10 @@ def __init__( if `True`, connection pool is not used. In that case `Redis` instance use is not thread safe. """ + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher if not connection_pool: if charset is not None: warnings.warn( @@ -313,9 +325,19 @@ def __init__( } ) connection_pool = ConnectionPool(**kwargs) + self._event_dispatcher.dispatch( + AfterPooledConnectionsInstantiationEvent( + [connection_pool], ClientType.SYNC, credential_provider + ) + ) self.auto_close_connection_pool = True else: self.auto_close_connection_pool = False + self._event_dispatcher.dispatch( + AfterPooledConnectionsInstantiationEvent( + [connection_pool], ClientType.SYNC, credential_provider + ) + ) self.connection_pool = connection_pool @@ -325,9 +347,16 @@ def __init__( ]: raise RedisError("Client caching is only supported with RESP version 3") + self.single_connection_lock = threading.Lock() self.connection = None - if single_connection_client: + self._single_connection_client = single_connection_client + if self._single_connection_client: self.connection = self.connection_pool.get_connection("_") + self._event_dispatcher.dispatch( + AfterSingleConnectionInstantiationEvent( + self.connection, ClientType.SYNC, self.single_connection_lock + ) + ) self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks) @@ -500,7 +529,9 @@ def pubsub(self, **kwargs): subscribe to channels and listen for messages that get published to them. 
""" - return PubSub(self.connection_pool, **kwargs) + return PubSub( + self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs + ) def monitor(self): return Monitor(self.connection_pool) @@ -563,6 +594,9 @@ def _execute_command(self, *args, **options): pool = self.connection_pool command_name = args[0] conn = self.connection or pool.get_connection(command_name, **options) + + if self._single_connection_client: + self.single_connection_lock.acquire() try: return conn.retry.call_with_retry( lambda: self._send_command_parse_response( @@ -571,6 +605,8 @@ def _execute_command(self, *args, **options): lambda error: self._disconnect_raise(conn, error), ) finally: + if self._single_connection_client: + self.single_connection_lock.release() if not self.connection: pool.release(conn) @@ -691,6 +727,7 @@ def __init__( ignore_subscribe_messages: bool = False, encoder: Optional["Encoder"] = None, push_handler_func: Union[None, Callable[[str], None]] = None, + event_dispatcher: Optional["EventDispatcher"] = None, ): self.connection_pool = connection_pool self.shard_hint = shard_hint @@ -701,6 +738,11 @@ def __init__( # to lookup channel and pattern names for callback handlers. self.encoder = encoder self.push_handler_func = push_handler_func + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher + self._lock = threading.Lock() if self.encoder is None: self.encoder = self.connection_pool.get_encoder() self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE) @@ -791,11 +833,17 @@ def execute_command(self, *args): self.connection.register_connect_callback(self.on_connect) if self.push_handler_func is not None and not HIREDIS_AVAILABLE: self.connection._parser.set_pubsub_push_handler(self.push_handler_func) + self._event_dispatcher.dispatch( + AfterPubSubConnectionInstantiationEvent( + self.connection, self.connection_pool, ClientType.SYNC, self._lock + ) + ) connection = self.connection kwargs = {"check_health": not self.subscribed} if not self.subscribed: self.clean_health_check_responses() - self._execute(connection, connection.send_command, *args, **kwargs) + with self._lock: + self._execute(connection, connection.send_command, *args, **kwargs) def clean_health_check_responses(self) -> None: """ diff --git a/redis/cluster.py b/redis/cluster.py index 9dcbad7fc1..38bd5dde1a 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -15,6 +15,12 @@ from redis.commands.helpers import list_or_args from redis.connection import ConnectionPool, DefaultParser, parse_url from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot +from redis.event import ( + AfterPooledConnectionsInstantiationEvent, + AfterPubSubConnectionInstantiationEvent, + ClientType, + EventDispatcher, +) from redis.exceptions import ( AskError, AuthenticationError, @@ -505,6 +511,7 @@ def __init__( address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, cache: Optional[CacheInterface] = None, cache_config: Optional[CacheConfig] = None, + event_dispatcher: Optional[EventDispatcher] = None, **kwargs, ): """ @@ -638,6 +645,10 @@ def __init__( self.read_from_replicas = read_from_replicas self.reinitialize_counter = 0 self.reinitialize_steps = reinitialize_steps + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher self.nodes_manager = NodesManager( startup_nodes=startup_nodes, from_url=from_url, @@ -646,6 +657,7 @@ def __init__( 
address_remap=address_remap, cache=cache, cache_config=cache_config, + event_dispatcher=self._event_dispatcher, **kwargs, ) @@ -1332,6 +1344,7 @@ def __init__( cache: Optional[CacheInterface] = None, cache_config: Optional[CacheConfig] = None, cache_factory: Optional[CacheFactoryInterface] = None, + event_dispatcher: Optional[EventDispatcher] = None, **kwargs, ): self.nodes_cache = {} @@ -1353,6 +1366,13 @@ def __init__( if lock is None: lock = threading.Lock() self._lock = lock + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher + self._credential_provider = self.connection_kwargs.get( + "credential_provider", None + ) self.initialize() def get_node(self, host=None, port=None, node_name=None): @@ -1479,11 +1499,19 @@ def create_redis_connections(self, nodes): """ This function will create a redis connection to all nodes in :nodes: """ + connection_pools = [] for node in nodes: if node.redis_connection is None: node.redis_connection = self.create_redis_node( host=node.host, port=node.port, **self.connection_kwargs ) + connection_pools.append(node.redis_connection.connection_pool) + + self._event_dispatcher.dispatch( + AfterPooledConnectionsInstantiationEvent( + connection_pools, ClientType.SYNC, self._credential_provider + ) + ) def create_redis_node(self, host, port, **kwargs): if self.from_url: @@ -1698,6 +1726,7 @@ def __init__( host=None, port=None, push_handler_func=None, + event_dispatcher: Optional["EventDispatcher"] = None, **kwargs, ): """ @@ -1723,10 +1752,15 @@ def __init__( self.cluster = redis_cluster self.node_pubsub_mapping = {} self._pubsubs_generator = self._pubsubs_generator() + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher super().__init__( connection_pool=connection_pool, encoder=redis_cluster.encoder, push_handler_func=push_handler_func, + event_dispatcher=self._event_dispatcher, **kwargs, ) @@ -1813,6 +1847,11 @@ def execute_command(self, *args): self.connection.register_connect_callback(self.on_connect) if self.push_handler_func is not None and not HIREDIS_AVAILABLE: self.connection._parser.set_pubsub_push_handler(self.push_handler_func) + self._event_dispatcher.dispatch( + AfterPubSubConnectionInstantiationEvent( + self.connection, self.connection_pool, ClientType.SYNC, self._lock + ) + ) connection = self.connection self._execute(connection, connection.send_command, *args) diff --git a/redis/connection.py b/redis/connection.py index 40f2d29722..9d29b4aba6 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -22,8 +22,10 @@ ) from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser +from .auth.token import TokenInterface from .backoff import NoBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider +from .event import AfterConnectionReleasedEvent, EventDispatcher from .exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -151,6 +153,10 @@ def deregister_connect_callback(self, callback): def set_parser(self, parser_class): pass + @abstractmethod + def get_protocol(self): + pass + @abstractmethod def connect(self): pass @@ -202,6 +208,14 @@ def pack_commands(self, commands): def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: pass + @abstractmethod + def set_re_auth_token(self, token: TokenInterface): + pass + + @abstractmethod + def re_auth(self): + pass + class AbstractConnection(ConnectionInterface): 
"Manages communication to and from a Redis server" @@ -229,6 +243,7 @@ def __init__( credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, command_packer: Optional[Callable[[], None]] = None, + event_dispatcher: Optional[EventDispatcher] = None, ): """ Initialize a new Connection. @@ -244,6 +259,10 @@ def __init__( "1. 'password' and (optional) 'username'\n" "2. 'credential_provider'" ) + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher self.pid = os.getpid() self.db = db self.client_name = client_name @@ -283,6 +302,7 @@ def __init__( self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 + self._re_auth_token: Optional[TokenInterface] = None try: p = int(protocol) except TypeError: @@ -663,6 +683,19 @@ def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]): self._handshake_metadata = value + def set_re_auth_token(self, token: TokenInterface): + self._re_auth_token = token + + def re_auth(self): + if self._re_auth_token is not None: + self.send_command( + "AUTH", + self._re_auth_token.try_get("oid"), + self._re_auth_token.get_value(), + ) + self.read_response() + self._re_auth_token = None + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -750,6 +783,7 @@ def __init__( self.retry = self._conn.retry self.host = self._conn.host self.port = self._conn.port + self.credential_provider = conn.credential_provider self._pool_lock = pool_lock self._cache = cache self._cache_lock = threading.Lock() @@ -933,6 +967,15 @@ def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]] else: self._cache.delete_by_redis_keys(data[1]) + def get_protocol(self): + return self._conn.get_protocol() + + def set_re_auth_token(self, token: TokenInterface): + self._conn.set_re_auth_token(token) + + def re_auth(self): + self._conn.re_auth() + class SSLConnection(Connection): """Manages SSL connections to and from the Redis server(s). @@ -1318,6 +1361,10 @@ def __init__( connection_kwargs.pop("cache", None) connection_kwargs.pop("cache_config", None) + self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) + if self._event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + # a lock to protect the critical section in _checkpid(). # this lock is acquired when the process id changes, such as # after a fork. during this time, multiple threads in the child @@ -1475,6 +1522,9 @@ def release(self, connection: "Connection") -> None: if self.owns_connection(connection): self._available_connections.append(connection) + self._event_dispatcher.dispatch( + AfterConnectionReleasedEvent(connection) + ) else: # pool doesn't own this connection. 
do not add it back # to the pool and decrement the count so that another @@ -1517,6 +1567,29 @@ def set_retry(self, retry: "Retry") -> None: for conn in self._in_use_connections: conn.retry = retry + def re_auth_callback(self, token: TokenInterface): + with self._lock: + for conn in self._available_connections: + conn.retry.call_with_retry( + lambda: conn.send_command( + "AUTH", token.try_get("oid"), token.get_value() + ), + lambda error: self._mock(error), + ) + conn.retry.call_with_retry( + lambda: conn.read_response(), lambda error: self._mock(error) + ) + for conn in self._in_use_connections: + conn.set_re_auth_token(token) + + async def _mock(self, error: RedisError): + """ + Dummy functions, needs to be passed as error callback to retry object. + :param error: + :return: + """ + pass + class BlockingConnectionPool(ConnectionPool): """ diff --git a/redis/credentials.py b/redis/credentials.py index 7ba26dcde1..6e59454ed3 100644 --- a/redis/credentials.py +++ b/redis/credentials.py @@ -1,4 +1,8 @@ -from typing import Optional, Tuple, Union +import logging +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, Tuple, Union + +logger = logging.getLogger(__name__) class CredentialProvider: @@ -9,6 +13,38 @@ class CredentialProvider: def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]: raise NotImplementedError("get_credentials must be implemented") + async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]: + logger.warning( + "This method is added for backward compatability. " + "Please override it in your implementation." + ) + return self.get_credentials() + + +class StreamingCredentialProvider(CredentialProvider, ABC): + """ + Credential provider that streams credentials in the background. + """ + + @abstractmethod + def on_next(self, callback: Callable[[Any], None]): + """ + Specifies the callback that should be invoked + when the next credentials will be retrieved. + + :param callback: Callback with + :return: + """ + pass + + @abstractmethod + def on_error(self, callback: Callable[[Exception], None]): + pass + + @abstractmethod + def is_streaming(self) -> bool: + pass + class UsernamePasswordCredentialProvider(CredentialProvider): """ @@ -24,3 +60,6 @@ def get_credentials(self): if self.username: return self.username, self.password return (self.password,) + + async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]: + return self.get_credentials() diff --git a/redis/event.py b/redis/event.py new file mode 100644 index 0000000000..5cc6c0017c --- /dev/null +++ b/redis/event.py @@ -0,0 +1,394 @@ +import asyncio +import threading +from abc import ABC, abstractmethod +from enum import Enum +from typing import List, Optional, Union + +from redis.auth.token import TokenInterface +from redis.credentials import CredentialProvider, StreamingCredentialProvider + + +class EventListenerInterface(ABC): + """ + Represents a listener for given event object. + """ + + @abstractmethod + def listen(self, event: object): + pass + + +class AsyncEventListenerInterface(ABC): + """ + Represents an async listener for given event object. + """ + + @abstractmethod + async def listen(self, event: object): + pass + + +class EventDispatcherInterface(ABC): + """ + Represents a dispatcher that dispatches events to listeners + associated with given event. 
+ """ + + @abstractmethod + def dispatch(self, event: object): + pass + + @abstractmethod + async def dispatch_async(self, event: object): + pass + + +class EventException(Exception): + """ + Exception wrapper that adds an event object into exception context. + """ + + def __init__(self, exception: Exception, event: object): + self.exception = exception + self.event = event + super().__init__(exception) + + +class EventDispatcher(EventDispatcherInterface): + # TODO: Make dispatcher to accept external mappings. + def __init__(self): + """ + Mapping should be extended for any new events or listeners to be added. + """ + self._event_listeners_mapping = { + AfterConnectionReleasedEvent: [ + ReAuthConnectionListener(), + ], + AfterPooledConnectionsInstantiationEvent: [ + RegisterReAuthForPooledConnections() + ], + AfterSingleConnectionInstantiationEvent: [ + RegisterReAuthForSingleConnection() + ], + AfterPubSubConnectionInstantiationEvent: [RegisterReAuthForPubSub()], + AfterAsyncClusterInstantiationEvent: [RegisterReAuthForAsyncClusterNodes()], + AsyncAfterConnectionReleasedEvent: [ + AsyncReAuthConnectionListener(), + ], + } + + def dispatch(self, event: object): + listeners = self._event_listeners_mapping.get(type(event)) + + for listener in listeners: + listener.listen(event) + + async def dispatch_async(self, event: object): + listeners = self._event_listeners_mapping.get(type(event)) + + for listener in listeners: + await listener.listen(event) + + +class AfterConnectionReleasedEvent: + """ + Event that will be fired before each command execution. + """ + + def __init__(self, connection): + self._connection = connection + + @property + def connection(self): + return self._connection + + +class AsyncAfterConnectionReleasedEvent(AfterConnectionReleasedEvent): + pass + + +class ClientType(Enum): + SYNC = ("sync",) + ASYNC = ("async",) + + +class AfterPooledConnectionsInstantiationEvent: + """ + Event that will be fired after pooled connection instances was created. + """ + + def __init__( + self, + connection_pools: List, + client_type: ClientType, + credential_provider: Optional[CredentialProvider] = None, + ): + self._connection_pools = connection_pools + self._client_type = client_type + self._credential_provider = credential_provider + + @property + def connection_pools(self): + return self._connection_pools + + @property + def client_type(self) -> ClientType: + return self._client_type + + @property + def credential_provider(self) -> Union[CredentialProvider, None]: + return self._credential_provider + + +class AfterSingleConnectionInstantiationEvent: + """ + Event that will be fired after single connection instances was created. 
+ + :param connection_lock: For sync client thread-lock should be provided, + for async asyncio.Lock + """ + + def __init__( + self, + connection, + client_type: ClientType, + connection_lock: Union[threading.Lock, asyncio.Lock], + ): + self._connection = connection + self._client_type = client_type + self._connection_lock = connection_lock + + @property + def connection(self): + return self._connection + + @property + def client_type(self) -> ClientType: + return self._client_type + + @property + def connection_lock(self) -> Union[threading.Lock, asyncio.Lock]: + return self._connection_lock + + +class AfterPubSubConnectionInstantiationEvent: + def __init__( + self, + pubsub_connection, + connection_pool, + client_type: ClientType, + connection_lock: Union[threading.Lock, asyncio.Lock], + ): + self._pubsub_connection = pubsub_connection + self._connection_pool = connection_pool + self._client_type = client_type + self._connection_lock = connection_lock + + @property + def pubsub_connection(self): + return self._pubsub_connection + + @property + def connection_pool(self): + return self._connection_pool + + @property + def client_type(self) -> ClientType: + return self._client_type + + @property + def connection_lock(self) -> Union[threading.Lock, asyncio.Lock]: + return self._connection_lock + + +class AfterAsyncClusterInstantiationEvent: + """ + Event that will be fired after async cluster instance was created. + + Async cluster doesn't use connection pools, + instead ClusterNode object manages connections. + """ + + def __init__( + self, + nodes: dict, + credential_provider: Optional[CredentialProvider] = None, + ): + self._nodes = nodes + self._credential_provider = credential_provider + + @property + def nodes(self) -> dict: + return self._nodes + + @property + def credential_provider(self) -> Union[CredentialProvider, None]: + return self._credential_provider + + +class ReAuthConnectionListener(EventListenerInterface): + """ + Listener that performs re-authentication of given connection. + """ + + def listen(self, event: AfterConnectionReleasedEvent): + event.connection.re_auth() + + +class AsyncReAuthConnectionListener(AsyncEventListenerInterface): + """ + Async listener that performs re-authentication of given connection. + """ + + async def listen(self, event: AsyncAfterConnectionReleasedEvent): + await event.connection.re_auth() + + +class RegisterReAuthForPooledConnections(EventListenerInterface): + """ + Listener that registers a re-authentication callback for pooled connections. + Required by :class:`StreamingCredentialProvider`. 
+ """ + + def __init__(self): + self._event = None + + def listen(self, event: AfterPooledConnectionsInstantiationEvent): + if isinstance(event.credential_provider, StreamingCredentialProvider): + self._event = event + + if event.client_type == ClientType.SYNC: + event.credential_provider.on_next(self._re_auth) + event.credential_provider.on_error(self._raise_on_error) + else: + event.credential_provider.on_next(self._re_auth_async) + event.credential_provider.on_error(self._raise_on_error_async) + + def _re_auth(self, token): + for pool in self._event.connection_pools: + pool.re_auth_callback(token) + + async def _re_auth_async(self, token): + for pool in self._event.connection_pools: + await pool.re_auth_callback(token) + + def _raise_on_error(self, error: Exception): + raise EventException(error, self._event) + + async def _raise_on_error_async(self, error: Exception): + raise EventException(error, self._event) + + +class RegisterReAuthForSingleConnection(EventListenerInterface): + """ + Listener that registers a re-authentication callback for single connection. + Required by :class:`StreamingCredentialProvider`. + """ + + def __init__(self): + self._event = None + + def listen(self, event: AfterSingleConnectionInstantiationEvent): + if isinstance( + event.connection.credential_provider, StreamingCredentialProvider + ): + self._event = event + + if event.client_type == ClientType.SYNC: + event.connection.credential_provider.on_next(self._re_auth) + event.connection.credential_provider.on_error(self._raise_on_error) + else: + event.connection.credential_provider.on_next(self._re_auth_async) + event.connection.credential_provider.on_error( + self._raise_on_error_async + ) + + def _re_auth(self, token): + with self._event.connection_lock: + self._event.connection.send_command( + "AUTH", token.try_get("oid"), token.get_value() + ) + self._event.connection.read_response() + + async def _re_auth_async(self, token): + async with self._event.connection_lock: + await self._event.connection.send_command( + "AUTH", token.try_get("oid"), token.get_value() + ) + await self._event.connection.read_response() + + def _raise_on_error(self, error: Exception): + raise EventException(error, self._event) + + async def _raise_on_error_async(self, error: Exception): + raise EventException(error, self._event) + + +class RegisterReAuthForAsyncClusterNodes(EventListenerInterface): + def __init__(self): + self._event = None + + def listen(self, event: AfterAsyncClusterInstantiationEvent): + if isinstance(event.credential_provider, StreamingCredentialProvider): + self._event = event + event.credential_provider.on_next(self._re_auth) + event.credential_provider.on_error(self._raise_on_error) + + async def _re_auth(self, token: TokenInterface): + for key in self._event.nodes: + await self._event.nodes[key].re_auth_callback(token) + + async def _raise_on_error(self, error: Exception): + raise EventException(error, self._event) + + +class RegisterReAuthForPubSub(EventListenerInterface): + def __init__(self): + self._connection = None + self._connection_pool = None + self._client_type = None + self._connection_lock = None + self._event = None + + def listen(self, event: AfterPubSubConnectionInstantiationEvent): + if isinstance( + event.pubsub_connection.credential_provider, StreamingCredentialProvider + ) and event.pubsub_connection.get_protocol() in [3, "3"]: + self._event = event + self._connection = event.pubsub_connection + self._connection_pool = event.connection_pool + self._client_type = event.client_type + 
self._connection_lock = event.connection_lock + + if self._client_type == ClientType.SYNC: + self._connection.credential_provider.on_next(self._re_auth) + self._connection.credential_provider.on_error(self._raise_on_error) + else: + self._connection.credential_provider.on_next(self._re_auth_async) + self._connection.credential_provider.on_error( + self._raise_on_error_async + ) + + def _re_auth(self, token: TokenInterface): + with self._connection_lock: + self._connection.send_command( + "AUTH", token.try_get("oid"), token.get_value() + ) + self._connection.read_response() + + self._connection_pool.re_auth_callback(token) + + async def _re_auth_async(self, token: TokenInterface): + async with self._connection_lock: + await self._connection.send_command( + "AUTH", token.try_get("oid"), token.get_value() + ) + await self._connection.read_response() + + await self._connection_pool.re_auth_callback(token) + + def _raise_on_error(self, error: Exception): + raise EventException(error, self._event) + + async def _raise_on_error_async(self, error: Exception): + raise EventException(error, self._event) diff --git a/requirements.txt b/requirements.txt index 622f70b810..9760e5bb13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -async-timeout>=4.0.3 \ No newline at end of file +async-timeout>=4.0.3 +PyJWT~=2.9.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 8036b64066..ee3a7c2023 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ "redis", "redis._parsers", "redis.asyncio", + "redis.auth", "redis.commands", "redis.commands.bf", "redis.commands.json", diff --git a/tests/conftest.py b/tests/conftest.py index 7c65898856..a900cea8bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,22 @@ import argparse +import json +import os import random import time +from datetime import datetime, timezone +from enum import Enum from typing import Callable, TypeVar from unittest import mock from unittest.mock import Mock from urllib.parse import urlparse +import jwt import pytest import redis from packaging.version import Version from redis import Sentinel +from redis.auth.idp import IdentityProviderInterface +from redis.auth.token import JWToken from redis.backoff import NoBackoff from redis.cache import ( CacheConfig, @@ -19,8 +26,16 @@ EvictionPolicy, ) from redis.connection import Connection, ConnectionInterface, SSLConnection, parse_url +from redis.credentials import CredentialProvider from redis.exceptions import RedisClusterException from redis.retry import Retry +from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig +from redis_entraid.identity_provider import ( + ManagedIdentityIdType, + ManagedIdentityType, + create_provider_from_managed_identity, + create_provider_from_service_principal, +) from tests.ssl_utils import get_tls_certificates REDIS_INFO = {} @@ -36,6 +51,11 @@ _TestDecorator = Callable[[_DecoratedTest], _DecoratedTest] +class AuthType(Enum): + MANAGED_IDENTITY = "managed_identity" + SERVICE_PRINCIPAL = "service_principal" + + # Taken from python3.9 class BooleanOptionalAction(argparse.Action): def __init__( @@ -149,6 +169,13 @@ def pytest_addoption(parser): help="Name of the Redis master service that the sentinels are monitoring", ) + parser.addoption( + "--endpoint-name", + action="store", + default=None, + help="Name of the Redis endpoint the tests should be executed on", + ) + def _get_info(redis_url): client = redis.Redis.from_url(redis_url) @@ -296,12 +323,14 @@ def skip_ifnot_redis_enterprise() -> 
_TestDecorator: def skip_if_nocryptography() -> _TestDecorator: - try: - import cryptography # noqa - - return pytest.mark.skipif(False, reason="Cryptography dependency found") - except ImportError: - return pytest.mark.skipif(True, reason="No cryptography dependency") + # try: + # import cryptography # noqa + # + # return pytest.mark.skipif(False, reason="Cryptography dependency found") + # except ImportError: + # TODO: Because the JWT library depends on cryptography, + # this check is now always true and the affected tests should be fixed. + return pytest.mark.skipif(True, reason="No cryptography dependency") def skip_if_cryptography() -> _TestDecorator: @@ -575,6 +604,142 @@ def cache_key(request) -> CacheKey: return CacheKey(command, keys) +def mock_identity_provider() -> IdentityProviderInterface: + mock_provider = Mock(spec=IdentityProviderInterface) + token = {"exp": datetime.now(timezone.utc).timestamp() + 3600, "oid": "username"} + encoded = jwt.encode(token, "secret", algorithm="HS256") + jwt_token = JWToken(encoded) + mock_provider.request_token.return_value = jwt_token + return mock_provider + + +def identity_provider(request) -> IdentityProviderInterface: + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} + + if request.param.get("mock_idp", None) is not None: + return mock_identity_provider() + + auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + + if auth_type == "MANAGED_IDENTITY": + return _get_managed_identity_provider(request) + + return _get_service_principal_provider(request) + + +def _get_managed_identity_provider(request): + authority = os.getenv("AZURE_AUTHORITY") + resource = os.getenv("AZURE_RESOURCE") + id_value = os.getenv("AZURE_ID_VALUE", None) + + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} + + identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) + id_type = kwargs.pop("id_type", ManagedIdentityIdType.CLIENT_ID) + + return create_provider_from_managed_identity( + identity_type=identity_type, + resource=resource, + id_type=id_type, + id_value=id_value, + authority=authority, + **kwargs, + ) + + +def _get_service_principal_provider(request): + client_id = os.getenv("AZURE_CLIENT_ID") + client_credential = os.getenv("AZURE_CLIENT_SECRET") + authority = os.getenv("AZURE_AUTHORITY") + scopes = os.getenv("AZURE_REDIS_SCOPES", []) + + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + token_kwargs = request.param.get("token_kwargs", {}) + timeout = request.param.get("timeout", None) + else: + kwargs = {} + token_kwargs = {} + timeout = None + + if isinstance(scopes, str): + scopes = scopes.split(",") + + return create_provider_from_service_principal( + client_id=client_id, + client_credential=client_credential, + scopes=scopes, + timeout=timeout, + token_kwargs=token_kwargs, + authority=authority, + **kwargs, + ) + + +def get_credential_provider(request) -> CredentialProvider: + cred_provider_class = request.param.get("cred_provider_class") + cred_provider_kwargs = request.param.get("cred_provider_kwargs", {}) + + if cred_provider_class != EntraIdCredentialsProvider: + return cred_provider_class(**cred_provider_kwargs) + + idp = identity_provider(request) + initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0) + block_for_initial = cred_provider_kwargs.get("block_for_initial", False) + expiration_refresh_ratio = cred_provider_kwargs.get( + "expiration_refresh_ratio",
TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO + ) + lower_refresh_bound_millis = cred_provider_kwargs.get( + "lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS + ) + max_attempts = cred_provider_kwargs.get( + "max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS + ) + delay_in_ms = cred_provider_kwargs.get( + "delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS + ) + + auth_config = TokenAuthConfig(idp) + auth_config.expiration_refresh_ratio = expiration_refresh_ratio + auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis + auth_config.max_attempts = max_attempts + auth_config.delay_in_ms = delay_in_ms + + return EntraIdCredentialsProvider( + config=auth_config, + initial_delay_in_ms=initial_delay_in_ms, + block_for_initial=block_for_initial, + ) + + +@pytest.fixture() +def credential_provider(request) -> CredentialProvider: + return get_credential_provider(request) + + +def get_endpoint(endpoint_name: str): + endpoints_config = os.getenv("REDIS_ENDPOINTS_CONFIG_PATH", None) + + if not (endpoints_config and os.path.exists(endpoints_config)): + raise FileNotFoundError(f"Endpoints config file not found: {endpoints_config}") + + try: + with open(endpoints_config, "r") as f: + data = json.load(f) + db = data[endpoint_name] + return db["endpoints"][0] + except Exception as e: + raise ValueError( + f"Failed to load endpoints config file: {endpoints_config}" + ) from e + + def wait_for_command(client, monitor, command, key=None): # issue a command with a key name that's local to this process. # if we find a command with our key before the command we're waiting diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 41b47b2268..8833426af1 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -1,21 +1,41 @@ +import os import random from contextlib import asynccontextmanager as _asynccontextmanager +from datetime import datetime, timezone +from enum import Enum from typing import Union +import jwt import pytest import pytest_asyncio import redis.asyncio as redis +from mock.mock import Mock from packaging.version import Version from redis.asyncio import Sentinel from redis.asyncio.client import Monitor from redis.asyncio.connection import Connection, parse_url from redis.asyncio.retry import Retry +from redis.auth.idp import IdentityProviderInterface +from redis.auth.token import JWToken from redis.backoff import NoBackoff +from redis.credentials import CredentialProvider +from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig +from redis_entraid.identity_provider import ( + ManagedIdentityIdType, + ManagedIdentityType, + create_provider_from_managed_identity, + create_provider_from_service_principal, +) from tests.conftest import REDIS_INFO from .compat import mock +class AuthType(Enum): + MANAGED_IDENTITY = "managed_identity" + SERVICE_PRINCIPAL = "service_principal" + + async def _get_info(redis_url): client = redis.Redis.from_url(redis_url) info = await client.info() @@ -216,6 +236,125 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs): yield mocked +def mock_identity_provider() -> IdentityProviderInterface: + mock_provider = Mock(spec=IdentityProviderInterface) + token = {"exp": datetime.now(timezone.utc).timestamp() + 3600, "oid": "username"} + encoded = jwt.encode(token, "secret", algorithm="HS256") + jwt_token = JWToken(encoded) + mock_provider.request_token.return_value = jwt_token + return mock_provider + + +def identity_provider(request) -> 
IdentityProviderInterface: + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} + + if request.param.get("mock_idp", None) is not None: + return mock_identity_provider() + + auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + + if auth_type == "MANAGED_IDENTITY": + return _get_managed_identity_provider(request) + + return _get_service_principal_provider(request) + + +def _get_managed_identity_provider(request): + authority = os.getenv("AZURE_AUTHORITY") + resource = os.getenv("AZURE_RESOURCE") + id_value = os.getenv("AZURE_ID_VALUE", None) + + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} + + identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) + id_type = kwargs.pop("id_type", ManagedIdentityIdType.CLIENT_ID) + + return create_provider_from_managed_identity( + identity_type=identity_type, + resource=resource, + id_type=id_type, + id_value=id_value, + authority=authority, + **kwargs, + ) + + +def _get_service_principal_provider(request): + client_id = os.getenv("AZURE_CLIENT_ID") + client_credential = os.getenv("AZURE_CLIENT_SECRET") + authority = os.getenv("AZURE_AUTHORITY") + scopes = os.getenv("AZURE_REDIS_SCOPES", []) + + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + token_kwargs = request.param.get("token_kwargs", {}) + timeout = request.param.get("timeout", None) + else: + kwargs = {} + token_kwargs = {} + timeout = None + + if isinstance(scopes, str): + scopes = scopes.split(",") + + return create_provider_from_service_principal( + client_id=client_id, + client_credential=client_credential, + scopes=scopes, + timeout=timeout, + token_kwargs=token_kwargs, + authority=authority, + **kwargs, + ) + + +def get_credential_provider(request) -> CredentialProvider: + cred_provider_class = request.param.get("cred_provider_class") + cred_provider_kwargs = request.param.get("cred_provider_kwargs", {}) + + if cred_provider_class != EntraIdCredentialsProvider: + return cred_provider_class(**cred_provider_kwargs) + + idp = identity_provider(request) + initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0) + block_for_initial = cred_provider_kwargs.get("block_for_initial", False) + expiration_refresh_ratio = cred_provider_kwargs.get( + "expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO + ) + lower_refresh_bound_millis = cred_provider_kwargs.get( + "lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS + ) + max_attempts = cred_provider_kwargs.get( + "max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS + ) + delay_in_ms = cred_provider_kwargs.get( + "delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS + ) + + auth_config = TokenAuthConfig(idp) + auth_config.expiration_refresh_ratio = expiration_refresh_ratio + auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis + auth_config.max_attempts = max_attempts + auth_config.delay_in_ms = delay_in_ms + + return EntraIdCredentialsProvider( + config=auth_config, + initial_delay_in_ms=initial_delay_in_ms, + block_for_initial=block_for_initial, + ) + + +@pytest_asyncio.fixture() +async def credential_provider(request) -> CredentialProvider: + return get_credential_provider(request) + + async def wait_for_command( client: redis.Redis, monitor: Monitor, command: str, key: Union[str, None] = None ): diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 
477397dd5f..c95babf687 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -2893,12 +2893,13 @@ class TestSSL: appropriate port. """ - CLIENT_CERT, CLIENT_KEY, CA_CERT = get_tls_certificates("cluster") - @pytest_asyncio.fixture() def create_client(self, request: FixtureRequest) -> Callable[..., RedisCluster]: ssl_url = request.config.option.redis_ssl_url ssl_host, ssl_port = urlparse(ssl_url)[1].split(":") + self.client_cert, self.client_key, self.ca_cert = get_tls_certificates( + "cluster" + ) async def _create_client(mocked: bool = True, **kwargs: Any) -> RedisCluster: if mocked: @@ -3017,24 +3018,24 @@ async def test_validating_self_signed_certificate( ) -> None: async with await create_client( ssl=True, - ssl_ca_certs=self.CA_CERT, + ssl_ca_certs=self.ca_cert, ssl_cert_reqs="required", - ssl_certfile=self.CLIENT_CERT, - ssl_keyfile=self.CLIENT_KEY, + ssl_certfile=self.client_cert, + ssl_keyfile=self.client_key, ) as rc: assert await rc.ping() async def test_validating_self_signed_string_certificate( self, create_client: Callable[..., Awaitable[RedisCluster]] ) -> None: - with open(self.CA_CERT) as f: + with open(self.ca_cert) as f: cert_data = f.read() async with await create_client( ssl=True, ssl_ca_data=cert_data, ssl_cert_reqs="required", - ssl_certfile=self.CLIENT_CERT, - ssl_keyfile=self.CLIENT_KEY, + ssl_certfile=self.client_cert, + ssl_keyfile=self.client_key, ) as rc: assert await rc.ping() diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index 2f5bbfb621..83545b4ede 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -5,6 +5,7 @@ import pytest_asyncio import redis.asyncio as redis from redis.asyncio.connection import Connection, to_bool +from redis.auth.token import TokenInterface from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt from .compat import aclosing, mock @@ -106,6 +107,12 @@ async def disconnect(self): async def can_read_destructive(self, timeout: float = 0): return False + def set_re_auth_token(self, token: TokenInterface): + pass + + async def re_auth(self): + pass + class TestConnectionPool: @asynccontextmanager diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py index 4429f7453b..ca42d19090 100644 --- a/tests/test_asyncio/test_credentials.py +++ b/tests/test_asyncio/test_credentials.py @@ -1,15 +1,51 @@ import functools import random import string +from asyncio import Lock as AsyncLock +from asyncio import sleep as async_sleep from typing import Optional, Tuple, Union import pytest import pytest_asyncio import redis -from redis import AuthenticationError, DataError, ResponseError +from mock.mock import Mock, call +from redis import AuthenticationError, DataError, RedisError, ResponseError +from redis.asyncio import Connection, ConnectionPool, Redis +from redis.asyncio.retry import Retry +from redis.auth.err import RequestTokenErr +from redis.backoff import NoBackoff from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider +from redis.exceptions import ConnectionError from redis.utils import str_if_bytes -from tests.conftest import skip_if_redis_enterprise +from redis_entraid.cred_provider import EntraIdCredentialsProvider +from tests.conftest import get_endpoint, skip_if_redis_enterprise +from tests.test_asyncio.conftest import get_credential_provider + + +@pytest.fixture() +def endpoint(request): + endpoint_name = 
request.config.getoption("--endpoint-name") + + try: + return get_endpoint(endpoint_name) + except FileNotFoundError as e: + pytest.skip( + f"Skipping scenario test because endpoints file is missing: {str(e)}" + ) + + +@pytest_asyncio.fixture() +async def r_credential(request, create_redis, endpoint): + credential_provider = request.param.get("cred_provider_class", None) + + if credential_provider is not None: + credential_provider = get_credential_provider(request) + + kwargs = { + "credential_provider": credential_provider, + } + + return await create_redis(url=endpoint, **kwargs) @pytest_asyncio.fixture() @@ -281,3 +317,380 @@ async def test_user_pass_provider_only_password( r2 = await create_redis(flushdb=False, credential_provider=provider) assert await r2.auth(provider.password) is True assert await r2.ping() is True + + +@pytest.mark.asyncio +@pytest.mark.onlynoncluster +class TestStreamingCredentialProvider: + @pytest.mark.parametrize( + "credential_provider", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + "cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005}, + "mock_idp": True, + } + ], + indirect=True, + ) + async def test_async_re_auth_all_connections(self, credential_provider): + mock_connection = Mock(spec=Connection) + mock_connection.retry = Retry(NoBackoff(), 0) + mock_another_connection = Mock(spec=Connection) + mock_pool = Mock(spec=ConnectionPool) + mock_pool.connection_kwargs = { + "credential_provider": credential_provider, + } + mock_pool.get_connection.return_value = mock_connection + mock_pool._available_connections = [mock_connection, mock_another_connection] + mock_pool._lock = AsyncLock() + auth_token = None + + async def re_auth_callback(token): + nonlocal auth_token + auth_token = token + async with mock_pool._lock: + for conn in mock_pool._available_connections: + await conn.send_command( + "AUTH", token.try_get("oid"), token.get_value() + ) + await conn.read_response() + + mock_pool.re_auth_callback = re_auth_callback + + await Redis( + connection_pool=mock_pool, + credential_provider=credential_provider, + ) + + await credential_provider.get_credentials_async() + await async_sleep(0.5) + + mock_connection.send_command.assert_has_calls( + [call("AUTH", auth_token.try_get("oid"), auth_token.get_value())] + ) + mock_another_connection.send_command.assert_has_calls( + [call("AUTH", auth_token.try_get("oid"), auth_token.get_value())] + ) + + @pytest.mark.parametrize( + "credential_provider", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + "cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005}, + "mock_idp": True, + } + ], + indirect=True, + ) + async def test_async_re_auth_partial_connections(self, credential_provider): + mock_connection = Mock(spec=Connection) + mock_connection.retry = Retry(NoBackoff(), 3) + mock_another_connection = Mock(spec=Connection) + mock_another_connection.retry = Retry(NoBackoff(), 3) + mock_failed_connection = Mock(spec=Connection) + mock_failed_connection.read_response.side_effect = ConnectionError( + "Failed auth" + ) + mock_failed_connection.retry = Retry(NoBackoff(), 3) + mock_pool = Mock(spec=ConnectionPool) + mock_pool.connection_kwargs = { + "credential_provider": credential_provider, + } + mock_pool.get_connection.return_value = mock_connection + mock_pool._available_connections = [ + mock_connection, + mock_another_connection, + mock_failed_connection, + ] + mock_pool._lock = AsyncLock() + + async def _raise(error: RedisError): + pass + + async def re_auth_callback(token): + 
async with mock_pool._lock: + for conn in mock_pool._available_connections: + await conn.retry.call_with_retry( + lambda: conn.send_command( + "AUTH", token.try_get("oid"), token.get_value() + ), + lambda error: _raise(error), + ) + await conn.retry.call_with_retry( + lambda: conn.read_response(), lambda error: _raise(error) + ) + + mock_pool.re_auth_callback = re_auth_callback + + await Redis( + connection_pool=mock_pool, + credential_provider=credential_provider, + ) + + await credential_provider.get_credentials_async() + await async_sleep(0.5) + + mock_connection.read_response.assert_has_calls([call()]) + mock_another_connection.read_response.assert_has_calls([call()]) + mock_failed_connection.read_response.assert_has_calls([call(), call(), call()]) + + @pytest.mark.parametrize( + "credential_provider", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + "cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005}, + "mock_idp": True, + } + ], + indirect=True, + ) + async def test_re_auth_pub_sub_in_resp3(self, credential_provider): + mock_pubsub_connection = Mock(spec=Connection) + mock_pubsub_connection.get_protocol.return_value = 3 + mock_pubsub_connection.credential_provider = credential_provider + mock_pubsub_connection.retry = Retry(NoBackoff(), 3) + mock_another_connection = Mock(spec=Connection) + mock_another_connection.retry = Retry(NoBackoff(), 3) + + mock_pool = Mock(spec=ConnectionPool) + mock_pool.connection_kwargs = { + "credential_provider": credential_provider, + } + mock_pool.get_connection.side_effect = [ + mock_pubsub_connection, + mock_another_connection, + ] + mock_pool._available_connections = [mock_another_connection] + mock_pool._lock = AsyncLock() + auth_token = None + + async def re_auth_callback(token): + nonlocal auth_token + auth_token = token + async with mock_pool._lock: + for conn in mock_pool._available_connections: + await conn.send_command( + "AUTH", token.try_get("oid"), token.get_value() + ) + await conn.read_response() + + mock_pool.re_auth_callback = re_auth_callback + + r = Redis( + connection_pool=mock_pool, + credential_provider=credential_provider, + ) + p = r.pubsub() + await p.subscribe("test") + await credential_provider.get_credentials_async() + await async_sleep(0.5) + + mock_pubsub_connection.send_command.assert_has_calls( + [ + call("SUBSCRIBE", "test", check_health=True), + call("AUTH", auth_token.try_get("oid"), auth_token.get_value()), + ] + ) + mock_another_connection.send_command.assert_has_calls( + [call("AUTH", auth_token.try_get("oid"), auth_token.get_value())] + ) + + @pytest.mark.parametrize( + "credential_provider", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + "cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005}, + "mock_idp": True, + } + ], + indirect=True, + ) + async def test_do_not_re_auth_pub_sub_in_resp2(self, credential_provider): + mock_pubsub_connection = Mock(spec=Connection) + mock_pubsub_connection.get_protocol.return_value = 2 + mock_pubsub_connection.credential_provider = credential_provider + mock_pubsub_connection.retry = Retry(NoBackoff(), 3) + mock_another_connection = Mock(spec=Connection) + mock_another_connection.retry = Retry(NoBackoff(), 3) + + mock_pool = Mock(spec=ConnectionPool) + mock_pool.connection_kwargs = { + "credential_provider": credential_provider, + } + mock_pool.get_connection.side_effect = [ + mock_pubsub_connection, + mock_another_connection, + ] + mock_pool._available_connections = [mock_another_connection] + mock_pool._lock = AsyncLock() + 
auth_token = None + + async def re_auth_callback(token): + nonlocal auth_token + auth_token = token + async with mock_pool._lock: + for conn in mock_pool._available_connections: + await conn.send_command( + "AUTH", token.try_get("oid"), token.get_value() + ) + await conn.read_response() + + mock_pool.re_auth_callback = re_auth_callback + + r = Redis( + connection_pool=mock_pool, + credential_provider=credential_provider, + ) + p = r.pubsub() + await p.subscribe("test") + await credential_provider.get_credentials_async() + await async_sleep(0.5) + + mock_pubsub_connection.send_command.assert_has_calls( + [ + call("SUBSCRIBE", "test", check_health=True), + ] + ) + mock_another_connection.send_command.assert_has_calls( + [call("AUTH", auth_token.try_get("oid"), auth_token.get_value())] + ) + + @pytest.mark.parametrize( + "credential_provider", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + "cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005}, + "mock_idp": True, + } + ], + indirect=True, + ) + async def test_fails_on_token_renewal(self, credential_provider): + credential_provider._token_mgr._idp.request_token.side_effect = [ + RequestTokenErr, + RequestTokenErr, + RequestTokenErr, + RequestTokenErr, + ] + mock_connection = Mock(spec=Connection) + mock_connection.retry = Retry(NoBackoff(), 0) + mock_another_connection = Mock(spec=Connection) + mock_pool = Mock(spec=ConnectionPool) + mock_pool.connection_kwargs = { + "credential_provider": credential_provider, + } + mock_pool.get_connection.return_value = mock_connection + mock_pool._available_connections = [mock_connection, mock_another_connection] + + await Redis( + connection_pool=mock_pool, + credential_provider=credential_provider, + ) + + with pytest.raises(RequestTokenErr): + await credential_provider.get_credentials() + + +@pytest.mark.asyncio +@pytest.mark.onlynoncluster +@pytest.mark.cp_integration +class TestEntraIdCredentialsProvider: + @pytest.mark.parametrize( + "r_credential", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "cred_provider_kwargs": {"block_for_initial": True}, + }, + ], + ids=["blocked", "non-blocked"], + indirect=True, + ) + @pytest.mark.asyncio + @pytest.mark.onlynoncluster + @pytest.mark.cp_integration + async def test_async_auth_pool_with_credential_provider(self, r_credential: Redis): + assert await r_credential.ping() is True + + @pytest.mark.parametrize( + "r_credential", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "cred_provider_kwargs": {"block_for_initial": True}, + }, + ], + ids=["blocked", "non-blocked"], + indirect=True, + ) + @pytest.mark.asyncio + @pytest.mark.onlynoncluster + @pytest.mark.cp_integration + async def test_async_pipeline_with_credential_provider(self, r_credential: Redis): + pipe = r_credential.pipeline() + + await pipe.set("key", "value") + await pipe.get("key") + + assert await pipe.execute() == [True, b"value"] + + @pytest.mark.parametrize( + "r_credential", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + }, + ], + indirect=True, + ) + @pytest.mark.asyncio + @pytest.mark.onlynoncluster + @pytest.mark.cp_integration + async def test_async_auth_pubsub_with_credential_provider( + self, r_credential: Redis + ): + p = r_credential.pubsub() + await p.subscribe("entraid") + + await r_credential.publish("entraid", "test") + await r_credential.publish("entraid", "test") + + msg1 = await 
p.get_message() + + assert msg1["type"] == "subscribe" + + +@pytest.mark.asyncio +@pytest.mark.onlycluster +@pytest.mark.cp_integration +class TestClusterEntraIdCredentialsProvider: + @pytest.mark.parametrize( + "r_credential", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "cred_provider_kwargs": {"block_for_initial": True}, + }, + ], + ids=["blocked", "non-blocked"], + indirect=True, + ) + @pytest.mark.asyncio + @pytest.mark.onlycluster + @pytest.mark.cp_integration + async def test_async_auth_pool_with_credential_provider(self, r_credential: Redis): + assert await r_credential.ping() is True diff --git a/tests/test_auth/__init__.py b/tests/test_auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_auth/test_token.py b/tests/test_auth/test_token.py new file mode 100644 index 0000000000..978cc2ca8c --- /dev/null +++ b/tests/test_auth/test_token.py @@ -0,0 +1,76 @@ +from datetime import datetime, timezone + +import jwt +import pytest +from redis.auth.err import InvalidTokenSchemaErr +from redis.auth.token import JWToken, SimpleToken + + +class TestToken: + + def test_simple_token(self): + token = SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 1000, + (datetime.now(timezone.utc).timestamp() * 1000), + {"key": "value"}, + ) + + assert token.ttl() == pytest.approx(1000, 10) + assert token.is_expired() is False + assert token.try_get("key") == "value" + assert token.get_value() == "value" + assert token.get_expires_at_ms() == pytest.approx( + (datetime.now(timezone.utc).timestamp() * 1000) + 100, 10 + ) + assert token.get_received_at_ms() == pytest.approx( + (datetime.now(timezone.utc).timestamp() * 1000), 10 + ) + + token = SimpleToken( + "value", + -1, + (datetime.now(timezone.utc).timestamp() * 1000), + {"key": "value"}, + ) + + assert token.ttl() == -1 + assert token.is_expired() is False + assert token.get_expires_at_ms() == -1 + + def test_jwt_token(self): + token = { + "exp": datetime.now(timezone.utc).timestamp() + 100, + "iat": datetime.now(timezone.utc).timestamp(), + "key": "value", + } + encoded = jwt.encode(token, "secret", algorithm="HS256") + jwt_token = JWToken(encoded) + + assert jwt_token.ttl() == pytest.approx(100000, 10) + assert jwt_token.is_expired() is False + assert jwt_token.try_get("key") == "value" + assert jwt_token.get_value() == encoded + assert jwt_token.get_expires_at_ms() == pytest.approx( + (datetime.now(timezone.utc).timestamp() * 1000) + 100000, 10 + ) + assert jwt_token.get_received_at_ms() == pytest.approx( + (datetime.now(timezone.utc).timestamp() * 1000), 10 + ) + + token = { + "exp": -1, + "iat": datetime.now(timezone.utc).timestamp(), + "key": "value", + } + encoded = jwt.encode(token, "secret", algorithm="HS256") + jwt_token = JWToken(encoded) + + assert jwt_token.ttl() == -1 + assert jwt_token.is_expired() is False + assert jwt_token.get_expires_at_ms() == -1000 + + with pytest.raises(InvalidTokenSchemaErr): + token = {"key": "value"} + encoded = jwt.encode(token, "secret", algorithm="HS256") + JWToken(encoded) diff --git a/tests/test_auth/test_token_manager.py b/tests/test_auth/test_token_manager.py new file mode 100644 index 0000000000..bb396e246c --- /dev/null +++ b/tests/test_auth/test_token_manager.py @@ -0,0 +1,566 @@ +import asyncio +from datetime import datetime, timezone +from time import sleep +from unittest.mock import Mock + +import pytest +from redis.auth.err import RequestTokenErr, 
TokenRenewalErr +from redis.auth.idp import IdentityProviderInterface +from redis.auth.token import SimpleToken +from redis.auth.token_manager import ( + CredentialsListener, + RetryPolicy, + TokenManager, + TokenManagerConfig, +) + + +class TestTokenManager: + @pytest.mark.parametrize( + "exp_refresh_ratio,tokens_refreshed", + [ + (0.9, 2), + (0.28, 4), + ], + ids=[ + "Refresh ratio = 0.9, 2 tokens in 0.1 second", + "Refresh ratio = 0.28, 4 tokens in 0.1 second", + ], + ) + def test_success_token_renewal(self, exp_refresh_ratio, tokens_refreshed): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ), + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 130, + (datetime.now(timezone.utc).timestamp() * 1000) + 30, + {"oid": "test"}, + ), + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 160, + (datetime.now(timezone.utc).timestamp() * 1000) + 60, + {"oid": "test"}, + ), + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 190, + (datetime.now(timezone.utc).timestamp() * 1000) + 90, + {"oid": "test"}, + ), + ] + + def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(exp_refresh_ratio, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + mgr.start(mock_listener) + sleep(0.1) + + assert len(tokens) == tokens_refreshed + + @pytest.mark.parametrize( + "exp_refresh_ratio,tokens_refreshed", + [ + (0.9, 2), + (0.28, 4), + ], + ids=[ + "Refresh ratio = 0.9, 2 tokens in 0.1 second", + "Refresh ratio = 0.28, 4 tokens in 0.1 second", + ], + ) + @pytest.mark.asyncio + async def test_async_success_token_renewal( + self, exp_refresh_ratio, tokens_refreshed + ): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ), + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 130, + (datetime.now(timezone.utc).timestamp() * 1000) + 30, + {"oid": "test"}, + ), + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 160, + (datetime.now(timezone.utc).timestamp() * 1000) + 60, + {"oid": "test"}, + ), + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 190, + (datetime.now(timezone.utc).timestamp() * 1000) + 90, + {"oid": "test"}, + ), + ] + + async def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(exp_refresh_ratio, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, block_for_initial=True) + await asyncio.sleep(0.1) + + assert len(tokens) == tokens_refreshed + + @pytest.mark.parametrize( + "block_for_initial,tokens_acquired", + [ + (True, 1), + (False, 0), + ], + ids=[ + "Block for initial, callback will be triggered once", + "Non-blocking, callback won't be triggered", + ], + ) + @pytest.mark.asyncio + async def test_async_request_token_blocking_behaviour(
+ self, block_for_initial, tokens_acquired + ): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.return_value = SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ) + + async def on_next(token): + nonlocal tokens + sleep(0.1) + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, block_for_initial=block_for_initial) + + assert len(tokens) == tokens_acquired + + def test_token_renewal_with_skip_initial(self): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ), + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 120, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ), + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 140, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ), + ] + + def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(3, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + mgr.start(mock_listener, skip_initial=True) + # Should be less than 0.1, or it will be flaky due to + # additional token renewal. + sleep(0.2) + + assert len(tokens) == 2 + + @pytest.mark.asyncio + async def test_async_token_renewal_with_skip_initial(self): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ), + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 120, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ), + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 140, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ), + ] + + async def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(3, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, skip_initial=True) + # Should be less than 0.1, or it will be flaky + # due to additional token renewal.
+ await asyncio.sleep(0.2) + + assert len(tokens) == 2 + + def test_success_token_renewal_with_retry(self): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + RequestTokenErr, + RequestTokenErr, + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ), + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ), + ] + + def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(3, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + mgr.start(mock_listener) + # Should be less than 0.1, or it will be flaky + # due to additional token renewal. + sleep(0.08) + + assert mock_provider.request_token.call_count > 0 + assert len(tokens) > 0 + + @pytest.mark.asyncio + async def test_async_success_token_renewal_with_retry(self): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + RequestTokenErr, + RequestTokenErr, + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ), + SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ), + ] + + async def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + mock_listener.on_error = None + + retry_policy = RetryPolicy(3, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, block_for_initial=True) + # Should be less than 0.1, or it will be flaky + # due to additional token renewal.
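+ # The two leading RequestTokenErr side effects are expected to be retried away, + # so a token should still be delivered within the short wait below.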
+ await asyncio.sleep(0.08) + + assert mock_provider.request_token.call_count > 0 + assert len(tokens) > 0 + + def test_no_token_renewal_on_process_complete(self): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.return_value = SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 1000, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ) + + def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(0.9, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + mgr.start(mock_listener) + sleep(0.2) + + assert len(tokens) == 1 + + @pytest.mark.asyncio + async def test_async_no_token_renewal_on_process_complete(self): + tokens = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.return_value = SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 1000, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ) + + async def on_next(token): + nonlocal tokens + tokens.append(token) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(0.9, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, block_for_initial=True) + await asyncio.sleep(0.2) + + assert len(tokens) == 1 + + def test_failed_token_renewal_with_retry(self): + tokens = [] + exceptions = [] + + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + RequestTokenErr, + RequestTokenErr, + RequestTokenErr, + RequestTokenErr, + ] + + def on_next(token): + nonlocal tokens + tokens.append(token) + + def on_error(exception): + nonlocal exceptions + exceptions.append(exception) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + mock_listener.on_error = on_error + + retry_policy = RetryPolicy(3, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + mgr.start(mock_listener) + sleep(0.1) + + assert mock_provider.request_token.call_count == 4 + assert len(tokens) == 0 + assert len(exceptions) == 1 + + @pytest.mark.asyncio + async def test_async_failed_token_renewal_with_retry(self): + tokens = [] + exceptions = [] + + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.side_effect = [ + RequestTokenErr, + RequestTokenErr, + RequestTokenErr, + RequestTokenErr, + ] + + async def on_next(token): + nonlocal tokens + tokens.append(token) + + async def on_error(exception): + nonlocal exceptions + exceptions.append(exception) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + mock_listener.on_error = on_error + + retry_policy = RetryPolicy(3, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, block_for_initial=True) + sleep(0.1) + + assert mock_provider.request_token.call_count == 4 + assert len(tokens) == 0 + assert len(exceptions) == 1 + + def test_failed_renewal_on_expired_token(self): + errors = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.return_value = SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) - 100, + 
(datetime.now(timezone.utc).timestamp() * 1000) - 1000, + {"oid": "test"}, + ) + + def on_error(error: TokenRenewalErr): + nonlocal errors + errors.append(error) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_error = on_error + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + mgr.start(mock_listener) + + assert len(errors) == 1 + assert isinstance(errors[0], TokenRenewalErr) + assert str(errors[0]) == "Requested token is expired" + + @pytest.mark.asyncio + async def test_async_failed_renewal_on_expired_token(self): + errors = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.return_value = SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) - 100, + (datetime.now(timezone.utc).timestamp() * 1000) - 1000, + {"oid": "test"}, + ) + + async def on_error(error: TokenRenewalErr): + nonlocal errors + errors.append(error) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_error = on_error + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, block_for_initial=True) + + assert len(errors) == 1 + assert isinstance(errors[0], TokenRenewalErr) + assert str(errors[0]) == "Requested token is expired" + + def test_failed_renewal_on_callback_error(self): + errors = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.return_value = SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 1000, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ) + + def on_next(token): + raise Exception("Some exception") + + def on_error(error): + nonlocal errors + errors.append(error) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + mock_listener.on_error = on_error + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + mgr.start(mock_listener) + + assert len(errors) == 1 + assert isinstance(errors[0], TokenRenewalErr) + assert str(errors[0]) == "Some exception" + + @pytest.mark.asyncio + async def test_async_failed_renewal_on_callback_error(self): + errors = [] + mock_provider = Mock(spec=IdentityProviderInterface) + mock_provider.request_token.return_value = SimpleToken( + "value", + (datetime.now(timezone.utc).timestamp() * 1000) + 1000, + (datetime.now(timezone.utc).timestamp() * 1000), + {"oid": "test"}, + ) + + async def on_next(token): + raise Exception("Some exception") + + async def on_error(error): + nonlocal errors + errors.append(error) + + mock_listener = Mock(spec=CredentialsListener) + mock_listener.on_next = on_next + mock_listener.on_error = on_error + + retry_policy = RetryPolicy(1, 10) + config = TokenManagerConfig(1, 0, 1000, retry_policy) + mgr = TokenManager(mock_provider, config) + await mgr.start_async(mock_listener, block_for_initial=True) + + assert len(errors) == 1 + assert isinstance(errors[0], TokenRenewalErr) + assert str(errors[0]) == "Some exception" diff --git a/tests/test_connection.py b/tests/test_connection.py index a58703e3b5..7683a1416d 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -28,6 +28,7 @@ UnixDomainSocketConnection, parse_url, ) +from redis.credentials import UsernamePasswordCredentialProvider from redis.exceptions import ConnectionError, 
InvalidResponse, RedisError, TimeoutError from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE @@ -441,6 +442,7 @@ def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): mock_connection.retry = "mock" mock_connection.host = "mock" mock_connection.port = "mock" + mock_connection.credential_provider = UsernamePasswordCredentialProvider() proxy_connection = CacheProxyConnection( mock_connection, cache, threading.Lock() @@ -457,6 +459,7 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): mock_connection.retry = "mock" mock_connection.host = "mock" mock_connection.port = "mock" + mock_connection.credential_provider = UsernamePasswordCredentialProvider() mock_cache.is_cachable.return_value = True mock_cache.get.side_effect = [ @@ -541,6 +544,7 @@ def test_triggers_invalidation_processing_on_another_connection( mock_connection.retry = "mock" mock_connection.host = "mock" mock_connection.port = "mock" + mock_connection.credential_provider = UsernamePasswordCredentialProvider() another_conn = copy.deepcopy(mock_connection) another_conn.can_read.side_effect = [True, False] diff --git a/tests/test_credentials.py b/tests/test_credentials.py index aade04e082..b0b79d305f 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -1,14 +1,58 @@ import functools import random import string +import threading +from time import sleep from typing import Optional, Tuple, Union import pytest import redis -from redis import AuthenticationError, DataError, ResponseError +from mock.mock import Mock, call +from redis import AuthenticationError, DataError, Redis, ResponseError +from redis.auth.err import RequestTokenErr +from redis.backoff import NoBackoff +from redis.connection import ConnectionInterface, ConnectionPool from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider +from redis.exceptions import ConnectionError, RedisError +from redis.retry import Retry from redis.utils import str_if_bytes -from tests.conftest import _get_client, skip_if_redis_enterprise +from redis_entraid.cred_provider import EntraIdCredentialsProvider +from tests.conftest import ( + _get_client, + get_credential_provider, + get_endpoint, + skip_if_redis_enterprise, +) + + +@pytest.fixture() +def endpoint(request): + endpoint_name = request.config.getoption("--endpoint-name") + + try: + return get_endpoint(endpoint_name) + except FileNotFoundError as e: + pytest.skip( + f"Skipping scenario test because endpoints file is missing: {str(e)}" + ) + + +@pytest.fixture() +def r_entra(request, endpoint): + credential_provider = request.param.get("cred_provider_class", None) + single_connection = request.param.get("single_connection_client", False) + + if credential_provider is not None: + credential_provider = get_credential_provider(request) + + with _get_client( + redis.Redis, + request, + credential_provider=credential_provider, + single_connection_client=single_connection, + from_url=endpoint, + ) as client: + yield client class NoPassCredProvider(CredentialProvider): @@ -248,3 +292,368 @@ def test_user_pass_provider_only_password(self, r, request): ) assert r2.auth(provider.password) is True assert r2.ping() is True + + +@pytest.mark.onlynoncluster +class TestStreamingCredentialProvider: + @pytest.mark.parametrize( + "credential_provider", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + "cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005}, + "mock_idp": True, + } + ], + indirect=True, + ) + def 
test_re_auth_all_connections(self, credential_provider): + mock_connection = Mock(spec=ConnectionInterface) + mock_connection.retry = Retry(NoBackoff(), 0) + mock_another_connection = Mock(spec=ConnectionInterface) + mock_pool = Mock(spec=ConnectionPool) + mock_pool.connection_kwargs = { + "credential_provider": credential_provider, + } + mock_pool.get_connection.return_value = mock_connection + mock_pool._available_connections = [mock_connection, mock_another_connection] + mock_pool._lock = threading.Lock() + auth_token = None + + def re_auth_callback(token): + nonlocal auth_token + auth_token = token + with mock_pool._lock: + for conn in mock_pool._available_connections: + conn.send_command("AUTH", token.try_get("oid"), token.get_value()) + conn.read_response() + + mock_pool.re_auth_callback = re_auth_callback + + Redis( + connection_pool=mock_pool, + credential_provider=credential_provider, + ) + + credential_provider.get_credentials() + sleep(0.5) + + mock_connection.send_command.assert_has_calls( + [call("AUTH", auth_token.try_get("oid"), auth_token.get_value())] + ) + mock_another_connection.send_command.assert_has_calls( + [call("AUTH", auth_token.try_get("oid"), auth_token.get_value())] + ) + + @pytest.mark.parametrize( + "credential_provider", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + "cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005}, + "mock_idp": True, + } + ], + indirect=True, + ) + def test_re_auth_partial_connections(self, credential_provider): + mock_connection = Mock(spec=ConnectionInterface) + mock_connection.retry = Retry(NoBackoff(), 3) + mock_another_connection = Mock(spec=ConnectionInterface) + mock_another_connection.retry = Retry(NoBackoff(), 3) + mock_failed_connection = Mock(spec=ConnectionInterface) + mock_failed_connection.read_response.side_effect = ConnectionError( + "Failed auth" + ) + mock_failed_connection.retry = Retry(NoBackoff(), 3) + mock_pool = Mock(spec=ConnectionPool) + mock_pool.connection_kwargs = { + "credential_provider": credential_provider, + } + mock_pool.get_connection.return_value = mock_connection + mock_pool._available_connections = [ + mock_connection, + mock_another_connection, + mock_failed_connection, + ] + mock_pool._lock = threading.Lock() + + def _raise(error: RedisError): + pass + + def re_auth_callback(token): + with mock_pool._lock: + for conn in mock_pool._available_connections: + conn.retry.call_with_retry( + lambda: conn.send_command( + "AUTH", token.try_get("oid"), token.get_value() + ), + lambda error: _raise(error), + ) + conn.retry.call_with_retry( + lambda: conn.read_response(), lambda error: _raise(error) + ) + + mock_pool.re_auth_callback = re_auth_callback + + Redis( + connection_pool=mock_pool, + credential_provider=credential_provider, + ) + + credential_provider.get_credentials() + sleep(0.5) + + mock_connection.read_response.assert_has_calls([call()]) + mock_another_connection.read_response.assert_has_calls([call()]) + mock_failed_connection.read_response.assert_has_calls([call(), call(), call()]) + + @pytest.mark.parametrize( + "credential_provider", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + "cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005}, + "mock_idp": True, + } + ], + indirect=True, + ) + def test_re_auth_pub_sub_in_resp3(self, credential_provider): + mock_pubsub_connection = Mock(spec=ConnectionInterface) + mock_pubsub_connection.get_protocol.return_value = 3 + mock_pubsub_connection.credential_provider = credential_provider + 
mock_pubsub_connection.retry = Retry(NoBackoff(), 3) + mock_another_connection = Mock(spec=ConnectionInterface) + mock_another_connection.retry = Retry(NoBackoff(), 3) + + mock_pool = Mock(spec=ConnectionPool) + mock_pool.connection_kwargs = { + "credential_provider": credential_provider, + } + mock_pool.get_connection.side_effect = [ + mock_pubsub_connection, + mock_another_connection, + ] + mock_pool._available_connections = [mock_another_connection] + mock_pool._lock = threading.Lock() + auth_token = None + + def re_auth_callback(token): + nonlocal auth_token + auth_token = token + with mock_pool._lock: + for conn in mock_pool._available_connections: + conn.send_command("AUTH", token.try_get("oid"), token.get_value()) + conn.read_response() + + mock_pool.re_auth_callback = re_auth_callback + + r = Redis( + connection_pool=mock_pool, + credential_provider=credential_provider, + ) + p = r.pubsub() + p.subscribe("test") + credential_provider.get_credentials() + sleep(0.5) + + mock_pubsub_connection.send_command.assert_has_calls( + [ + call("SUBSCRIBE", "test", check_health=True), + call("AUTH", auth_token.try_get("oid"), auth_token.get_value()), + ] + ) + mock_another_connection.send_command.assert_has_calls( + [call("AUTH", auth_token.try_get("oid"), auth_token.get_value())] + ) + + @pytest.mark.parametrize( + "credential_provider", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + "cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005}, + "mock_idp": True, + } + ], + indirect=True, + ) + def test_do_not_re_auth_pub_sub_in_resp2(self, credential_provider): + mock_pubsub_connection = Mock(spec=ConnectionInterface) + mock_pubsub_connection.get_protocol.return_value = 2 + mock_pubsub_connection.credential_provider = credential_provider + mock_pubsub_connection.retry = Retry(NoBackoff(), 3) + mock_another_connection = Mock(spec=ConnectionInterface) + mock_another_connection.retry = Retry(NoBackoff(), 3) + + mock_pool = Mock(spec=ConnectionPool) + mock_pool.connection_kwargs = { + "credential_provider": credential_provider, + } + mock_pool.get_connection.side_effect = [ + mock_pubsub_connection, + mock_another_connection, + ] + mock_pool._available_connections = [mock_another_connection] + mock_pool._lock = threading.Lock() + auth_token = None + + def re_auth_callback(token): + nonlocal auth_token + auth_token = token + with mock_pool._lock: + for conn in mock_pool._available_connections: + conn.send_command("AUTH", token.try_get("oid"), token.get_value()) + conn.read_response() + + mock_pool.re_auth_callback = re_auth_callback + + r = Redis( + connection_pool=mock_pool, + credential_provider=credential_provider, + ) + p = r.pubsub() + p.subscribe("test") + credential_provider.get_credentials() + sleep(0.5) + + mock_pubsub_connection.send_command.assert_has_calls( + [ + call("SUBSCRIBE", "test", check_health=True), + ] + ) + mock_another_connection.send_command.assert_has_calls( + [call("AUTH", auth_token.try_get("oid"), auth_token.get_value())] + ) + + @pytest.mark.parametrize( + "credential_provider", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + "cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005}, + "mock_idp": True, + } + ], + indirect=True, + ) + def test_fails_on_token_renewal(self, credential_provider): + credential_provider._token_mgr._idp.request_token.side_effect = [ + RequestTokenErr, + RequestTokenErr, + RequestTokenErr, + RequestTokenErr, + ] + mock_connection = Mock(spec=ConnectionInterface) + mock_connection.retry = 
Retry(NoBackoff(), 0) + mock_another_connection = Mock(spec=ConnectionInterface) + mock_pool = Mock(spec=ConnectionPool) + mock_pool.connection_kwargs = { + "credential_provider": credential_provider, + } + mock_pool.get_connection.return_value = mock_connection + mock_pool._available_connections = [mock_connection, mock_another_connection] + mock_pool._lock = threading.Lock() + + Redis( + connection_pool=mock_pool, + credential_provider=credential_provider, + ) + + with pytest.raises(RequestTokenErr): + credential_provider.get_credentials() + + +@pytest.mark.onlynoncluster +@pytest.mark.cp_integration +class TestEntraIdCredentialsProvider: + @pytest.mark.parametrize( + "r_entra", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + "single_connection_client": False, + }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "single_connection_client": True, + }, + ], + ids=["pool", "single"], + indirect=True, + ) + @pytest.mark.onlynoncluster + @pytest.mark.cp_integration + def test_auth_pool_with_credential_provider(self, r_entra: redis.Redis): + assert r_entra.ping() is True + + @pytest.mark.parametrize( + "r_entra", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + "single_connection_client": False, + }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "single_connection_client": True, + }, + ], + ids=["pool", "single"], + indirect=True, + ) + @pytest.mark.onlynoncluster + @pytest.mark.cp_integration + def test_auth_pipeline_with_credential_provider(self, r_entra: redis.Redis): + pipe = r_entra.pipeline() + + pipe.set("key", "value") + pipe.get("key") + + assert pipe.execute() == [True, b"value"] + + @pytest.mark.parametrize( + "r_entra", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + }, + ], + indirect=True, + ) + @pytest.mark.onlynoncluster + @pytest.mark.cp_integration + def test_auth_pubsub_with_credential_provider(self, r_entra: redis.Redis): + p = r_entra.pubsub() + p.subscribe("entraid") + + r_entra.publish("entraid", "test") + r_entra.publish("entraid", "test") + + assert p.get_message()["type"] == "subscribe" + assert p.get_message()["type"] == "message" + + +@pytest.mark.onlycluster +@pytest.mark.cp_integration +class TestClusterEntraIdCredentialsProvider: + @pytest.mark.parametrize( + "r_entra", + [ + { + "cred_provider_class": EntraIdCredentialsProvider, + "single_connection_client": False, + }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "single_connection_client": True, + }, + ], + ids=["pool", "single"], + indirect=True, + ) + @pytest.mark.onlycluster + @pytest.mark.cp_integration + def test_auth_pool_with_credential_provider(self, r_entra: redis.Redis): + assert r_entra.ping() is True
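
Reviewer note: below is a minimal, hypothetical wiring sketch (not part of the patch) showing how the pieces introduced here are expected to fit together for an end user. It mirrors the service-principal flow used by the test fixtures above; the redis_entraid call signatures, the environment variable names, and the host/port values are assumptions taken from those fixtures and may differ in the released package.

import os

import redis
from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig
from redis_entraid.identity_provider import create_provider_from_service_principal

# Identity provider built from an Entra ID service principal, mirroring
# _get_service_principal_provider() in tests/conftest.py (omitted arguments are
# assumed to fall back to library defaults).
idp = create_provider_from_service_principal(
    client_id=os.getenv("AZURE_CLIENT_ID"),
    client_credential=os.getenv("AZURE_CLIENT_SECRET"),
    scopes=os.getenv("AZURE_REDIS_SCOPES", "").split(","),
    authority=os.getenv("AZURE_AUTHORITY"),
)

# The credential provider wraps the token manager that renews tokens in the
# background, configured via TokenAuthConfig as in get_credential_provider().
credential_provider = EntraIdCredentialsProvider(config=TokenAuthConfig(idp))

# When the client is constructed with a streaming credential provider, the
# re-auth listeners added in redis/event.py are expected to subscribe to the
# provider and re-AUTH pooled connections whenever a fresh token is emitted.
client = redis.Redis(
    host="my-redis-host",  # placeholder endpoint
    port=6379,
    credential_provider=credential_provider,
)
client.ping()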