Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StreamingCredentialProvider support #3445

Merged
merged 57 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
0a8f770
Added StreamingCredentialProvider interface
vladvildanov Nov 19, 2024
8272c73
StreamingCredentialProvider support
vladvildanov Nov 27, 2024
7021c7b
Removed debug statement
vladvildanov Nov 27, 2024
400ba2a
Changed an approach to handle multiple connection pools
vladvildanov Dec 3, 2024
9561032
Added support for RedisCluster
vladvildanov Dec 3, 2024
dfcd488
Merge branch 'master' of github.com:redis/redis-py into vv-tba-support
vladvildanov Dec 4, 2024
833968d
Added dispatching of custom connection pool
vladvildanov Dec 4, 2024
3848b57
Extended CredentialProvider interface with async API
vladvildanov Dec 5, 2024
fa9bc3c
Changed method implementation
vladvildanov Dec 5, 2024
1776679
Added support for async API
vladvildanov Dec 9, 2024
87a1ffa
Removed unused lock
vladvildanov Dec 9, 2024
24714ae
Added async API
vladvildanov Dec 10, 2024
0327f36
Merge branch 'master' of github.com:redis/redis-py into vv-tba-support
vladvildanov Dec 10, 2024
6dae71b
Added support for single connection client
vladvildanov Dec 11, 2024
32fc374
Added core functionality
vladvildanov Dec 11, 2024
c2eef78
Revert debug call
vladvildanov Dec 11, 2024
1a1b211
Added package to setup.py
vladvildanov Dec 11, 2024
974ad4f
Added handling of in-use connections
vladvildanov Dec 12, 2024
66a53ea
Added testing
vladvildanov Dec 12, 2024
2cad8b0
Changed fixture name
vladvildanov Dec 12, 2024
7eb6600
Added marker
vladvildanov Dec 12, 2024
5facdae
Marked tests with correct annotations
vladvildanov Dec 13, 2024
ee2ce1a
Added better cancelation handling
vladvildanov Dec 13, 2024
835ede7
Removed another annotation
vladvildanov Dec 16, 2024
e14d680
Added support for async cluster
vladvildanov Dec 16, 2024
90204e7
Added pipeline tests
vladvildanov Dec 17, 2024
0de0f4d
Added support for Pub/Sub
vladvildanov Dec 17, 2024
46e2f94
Added support for Pub/Sub in cluster
vladvildanov Dec 18, 2024
5488726
Added an option to parse endpoint from endpoints.json
vladvildanov Dec 18, 2024
76e9dea
Updated package names and ENV variables
vladvildanov Dec 18, 2024
b697e27
Moved SSL certificates code into context of class
vladvildanov Dec 19, 2024
c24ab17
Fixed fixtures for async
vladvildanov Dec 19, 2024
68ebdee
Fixed test
vladvildanov Dec 19, 2024
98fa92f
Added better endpoitns handling
vladvildanov Dec 20, 2024
e84d77a
Changed variable names
vladvildanov Dec 20, 2024
4ccd380
Added logging
vladvildanov Dec 20, 2024
6e7ad70
Fixed broken tests
vladvildanov Dec 20, 2024
a9c200c
Added TODO for SSL tests
vladvildanov Dec 20, 2024
4527bf0
Added error propagation to main thread
vladvildanov Dec 20, 2024
ac1164e
Added single connection lock
vladvildanov Dec 20, 2024
96aeb68
Codestyle fixes
vladvildanov Dec 20, 2024
9cada36
Added missing methods
vladvildanov Dec 20, 2024
92356bb
Removed wrong annotation
vladvildanov Dec 20, 2024
bd89ff8
Fixed tests
vladvildanov Dec 20, 2024
fcfdcb8
Codestyle fix
vladvildanov Dec 20, 2024
063f0d5
Updated EventListener instantiation inside of class
vladvildanov Dec 20, 2024
b15358b
Fixed variable name
vladvildanov Dec 20, 2024
e691162
Fixed variable names
vladvildanov Dec 20, 2024
ce1e10c
Fixed variable name
vladvildanov Dec 20, 2024
5de68a6
Added EventException
vladvildanov Dec 20, 2024
2851a7c
Codestyle fix
vladvildanov Dec 20, 2024
87c4e7e
Removed redundant code
vladvildanov Dec 20, 2024
d890193
Codestyle fix
vladvildanov Dec 20, 2024
04f3511
Updated test case
vladvildanov Dec 20, 2024
67f1d13
Fixed tests
vladvildanov Dec 20, 2024
c3d099d
Fixed test
vladvildanov Dec 20, 2024
a7233b0
Removed dependency
vladvildanov Dec 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ uvloop
vulture>=2.3.0
wheel>=0.30.0
numpy>=1.24.0
redispy-entraid-credentials @ git+https://github.com/redis-developer/redispy-entra-credentials.git/@main
18 changes: 18 additions & 0 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
list_or_args,
)
from redis.credentials import CredentialProvider
from redis.event import EventDispatcher, AfterPooledConnectionsInstantiationEvent, ClientType, \
AfterSingleConnectionInstantiationEvent
from redis.exceptions import (
ConnectionError,
ExecAbortError,
Expand Down Expand Up @@ -233,6 +235,7 @@ def __init__(
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
event_dispatcher: Optional[EventDispatcher] = EventDispatcher(),
vladvildanov marked this conversation as resolved.
Show resolved Hide resolved
):
"""
Initialize a new Redis client.
Expand Down Expand Up @@ -320,11 +323,22 @@ def __init__(
# This arg only used if no pool is passed in
self.auto_close_connection_pool = auto_close_connection_pool
connection_pool = ConnectionPool(**kwargs)
event_dispatcher.dispatch(AfterPooledConnectionsInstantiationEvent(
[connection_pool],
ClientType.ASYNC,
credential_provider
))
else:
# If a pool is passed in, do not close it
self.auto_close_connection_pool = False
event_dispatcher.dispatch(AfterPooledConnectionsInstantiationEvent(
[connection_pool],
ClientType.ASYNC,
credential_provider
))

self.connection_pool = connection_pool
self._event_dispatcher = event_dispatcher
self.single_connection_client = single_connection_client
self.connection: Optional[Connection] = None

Expand Down Expand Up @@ -354,6 +368,10 @@ async def initialize(self: _RedisT) -> _RedisT:
async with self._single_conn_lock:
if self.connection is None:
self.connection = await self.connection_pool.get_connection("_")

self._event_dispatcher.dispatch(
AfterSingleConnectionInstantiationEvent(self.connection, ClientType.ASYNC, self._single_conn_lock)
)
return self

def set_response_callback(self, command: str, callback: ResponseCallbackT):
Expand Down
36 changes: 27 additions & 9 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import socket
import ssl
import sys
import threading
import warnings
import weakref
from abc import abstractmethod
Expand All @@ -27,6 +28,7 @@
)
from urllib.parse import ParseResult, parse_qs, unquote, urlparse

from ..event import EventDispatcher, AsyncBeforeCommandExecutionEvent
from ..utils import format_error_message

# the functionality is available in 3.11.x but has a major issue before
Expand All @@ -39,7 +41,7 @@
from redis.asyncio.retry import Retry
from redis.backoff import NoBackoff
from redis.connection import DEFAULT_RESP_VERSION
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider, StreamingCredentialProvider
from redis.exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
Expand Down Expand Up @@ -148,6 +150,7 @@ def __init__(
encoder_class: Type[Encoder] = Encoder,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
event_dispatcher: Optional[EventDispatcher] = EventDispatcher()
):
if (username or password) and credential_provider is not None:
raise DataError(
Expand Down Expand Up @@ -195,6 +198,9 @@ def __init__(
self.set_parser(parser_class)
self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
self._buffer_cutoff = 6000
self._event_dispatcher = event_dispatcher
self._init_auth_args = None

try:
p = int(protocol)
except TypeError:
Expand Down Expand Up @@ -339,7 +345,9 @@ async def on_connect(self) -> None:
self.credential_provider
or UsernamePasswordCredentialProvider(self.username, self.password)
)
auth_args = cred_provider.get_credentials()
auth_args = await cred_provider.get_credentials_async()
self._init_auth_args = hash(auth_args)

# if resp version is specified and we have auth args,
# we need to send them via HELLO
if auth_args and self.protocol not in [2, "2"]:
Expand Down Expand Up @@ -1039,6 +1047,7 @@ def __init__(
self._available_connections: List[AbstractConnection] = []
self._in_use_connections: Set[AbstractConnection] = set()
self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
self._lock = asyncio.Lock()

def __repr__(self):
return (
Expand All @@ -1058,13 +1067,14 @@ def can_get_connection(self) -> bool:
)

async def get_connection(self, command_name, *keys, **options):
"""Get a connected connection from the pool"""
connection = self.get_available_connection()
try:
await self.ensure_connection(connection)
except BaseException:
await self.release(connection)
raise
async with self._lock:
"""Get a connected connection from the pool"""
connection = self.get_available_connection()
try:
await self.ensure_connection(connection)
except BaseException:
await self.release(connection)
raise

return connection

Expand Down Expand Up @@ -1147,6 +1157,14 @@ def set_retry(self, retry: "Retry") -> None:
for conn in self._in_use_connections:
conn.retry = retry

async def re_auth_callback(self, token):
async with self._lock:
for conn in self._available_connections:
await conn.send_command(
'AUTH', token.try_get('oid'), token.get_value()
)
await conn.read_response()


class BlockingConnectionPool(ConnectionPool):
"""
Expand Down
18 changes: 18 additions & 0 deletions redis/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import copy
import re
import threading
Expand Down Expand Up @@ -27,6 +28,8 @@
UnixDomainSocketConnection,
)
from redis.credentials import CredentialProvider
from redis.event import EventDispatcher, AfterPooledConnectionsInstantiationEvent, ClientType, \
AfterSingleConnectionInstantiationEvent
from redis.exceptions import (
ConnectionError,
ExecAbortError,
Expand Down Expand Up @@ -213,6 +216,7 @@ def __init__(
protocol: Optional[int] = 2,
cache: Optional[CacheInterface] = None,
cache_config: Optional[CacheConfig] = None,
event_dispatcher: Optional[EventDispatcher] = EventDispatcher(),
) -> None:
"""
Initialize a new Redis client.
Expand Down Expand Up @@ -313,9 +317,19 @@ def __init__(
}
)
connection_pool = ConnectionPool(**kwargs)
event_dispatcher.dispatch(AfterPooledConnectionsInstantiationEvent(
[connection_pool],
ClientType.SYNC,
credential_provider
))
self.auto_close_connection_pool = True
else:
self.auto_close_connection_pool = False
event_dispatcher.dispatch(AfterPooledConnectionsInstantiationEvent(
[connection_pool],
ClientType.SYNC,
credential_provider
))

self.connection_pool = connection_pool

Expand All @@ -325,9 +339,13 @@ def __init__(
]:
raise RedisError("Client caching is only supported with RESP version 3")

self._connection_lock = threading.Lock()
self.connection = None
if single_connection_client:
self.connection = self.connection_pool.get_connection("_")
event_dispatcher.dispatch(
AfterSingleConnectionInstantiationEvent(self.connection, ClientType.SYNC, self._connection_lock)
)

self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)

Expand Down
16 changes: 16 additions & 0 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from redis.commands.helpers import list_or_args
from redis.connection import ConnectionPool, DefaultParser, parse_url
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
from redis.event import EventDispatcher, EventDispatcherInterface, AfterPooledConnectionsInstantiationEvent, ClientType
from redis.exceptions import (
AskError,
AuthenticationError,
Expand Down Expand Up @@ -505,6 +506,7 @@ def __init__(
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
cache: Optional[CacheInterface] = None,
cache_config: Optional[CacheConfig] = None,
event_dispatcher: Optional[EventDispatcher] = EventDispatcher(),
**kwargs,
):
"""
Expand Down Expand Up @@ -646,6 +648,7 @@ def __init__(
address_remap=address_remap,
cache=cache,
cache_config=cache_config,
event_dispatcher=event_dispatcher,
**kwargs,
)

Expand Down Expand Up @@ -1332,6 +1335,7 @@ def __init__(
cache: Optional[CacheInterface] = None,
cache_config: Optional[CacheConfig] = None,
cache_factory: Optional[CacheFactoryInterface] = None,
event_dispatcher: Optional[EventDispatcher] = EventDispatcher(),
**kwargs,
):
self.nodes_cache = {}
Expand All @@ -1353,6 +1357,8 @@ def __init__(
if lock is None:
lock = threading.Lock()
self._lock = lock
self._event_dispatcher = event_dispatcher
self._credential_provider = self.connection_kwargs.get("credential_provider", None)
self.initialize()

def get_node(self, host=None, port=None, node_name=None):
Expand Down Expand Up @@ -1479,11 +1485,21 @@ def create_redis_connections(self, nodes):
"""
This function will create a redis connection to all nodes in :nodes:
"""
connection_pools = []
for node in nodes:
if node.redis_connection is None:
node.redis_connection = self.create_redis_node(
host=node.host, port=node.port, **self.connection_kwargs
)
connection_pools.append(node.redis_connection.connection_pool)

self._event_dispatcher.dispatch(
AfterPooledConnectionsInstantiationEvent(
connection_pools,
ClientType.SYNC,
self._credential_provider
)
)

def create_redis_node(self, host, port, **kwargs):
if self.from_url:
Expand Down
17 changes: 16 additions & 1 deletion redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@

from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
from .backoff import NoBackoff
from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
from .credentials import CredentialProvider, UsernamePasswordCredentialProvider, StreamingCredentialProvider
from .event import EventDispatcherInterface, EventDispatcher, BeforeCommandExecutionEvent
from .exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
Expand Down Expand Up @@ -229,6 +230,7 @@
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
command_packer: Optional[Callable[[], None]] = None,
event_dispatcher: Optional[EventDispatcher] = EventDispatcher()
):
"""
Initialize a new Connection.
Expand Down Expand Up @@ -283,6 +285,8 @@
self.set_parser(parser_class)
self._connect_callbacks = []
self._buffer_cutoff = 6000
self._event_dispatcher = event_dispatcher
self._init_auth_args = None
try:
p = int(protocol)
except TypeError:
Expand Down Expand Up @@ -408,6 +412,8 @@
or UsernamePasswordCredentialProvider(self.username, self.password)
)
auth_args = cred_provider.get_credentials()
self._init_auth_args = hash(auth_args)
print(auth_args)
Fixed Show fixed Hide fixed

# if resp version is specified and we have auth args,
# we need to send them via HELLO
Expand Down Expand Up @@ -1318,6 +1324,7 @@
connection_kwargs.pop("cache", None)
connection_kwargs.pop("cache_config", None)


# a lock to protect the critical section in _checkpid().
# this lock is acquired when the process id changes, such as
# after a fork. during this time, multiple threads in the child
Expand Down Expand Up @@ -1517,6 +1524,14 @@
for conn in self._in_use_connections:
conn.retry = retry

def re_auth_callback(self, token):
with self._lock:
for conn in self._available_connections:
conn.send_command(
'AUTH', token.try_get('oid'), token.get_value()
)
conn.read_response()


class BlockingConnectionPool(ConnectionPool):
"""
Expand Down
36 changes: 35 additions & 1 deletion redis/credentials.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import Optional, Tuple, Union
import logging
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Union, Callable, Any

logger = logging.getLogger(__name__)


class CredentialProvider:
Expand All @@ -9,6 +13,33 @@ class CredentialProvider:
def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
raise NotImplementedError("get_credentials must be implemented")

async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]:
logger.warning("This method is added for backward compatability. Please override it in your implementation.")
return self.get_credentials()


class StreamingCredentialProvider(CredentialProvider, ABC):
"""
Credential provider that streams credentials in the background.
"""
@abstractmethod
def on_next(self, callback: Callable[[Any], None]):
"""
Specifies the callback that should be invoked when the next credentials will be retrieved.

:param callback: Callback with
:return:
"""
pass

@abstractmethod
def on_error(self, callback: Callable[[Exception], None]):
pass

@abstractmethod
def is_streaming(self) -> bool:
pass


class UsernamePasswordCredentialProvider(CredentialProvider):
"""
Expand All @@ -24,3 +55,6 @@ def get_credentials(self):
if self.username:
return self.username, self.password
return (self.password,)

async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]:
return self.get_credentials()
Loading
Loading