From ab6c0e3872aa4fdbd3cb6d172ecdeedcbb2bbd1d Mon Sep 17 00:00:00 2001 From: Samantha Hughes Date: Sat, 25 Sep 2021 12:29:34 -0700 Subject: [PATCH] typing for sync/async decorators --- asgiref/sync.py | 62 ++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/asgiref/sync.py b/asgiref/sync.py index de9204eb..e4b5d2c9 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -7,7 +7,9 @@ import warnings import weakref from concurrent.futures import Future, ThreadPoolExecutor -from typing import Any, Callable, Dict, Optional, overload +from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar, overload + +from typing_extensions import ParamSpec from .compatibility import current_task, get_running_loop from .current_thread_executor import CurrentThreadExecutor @@ -98,7 +100,11 @@ async def __aexit__(self, exc, value, tb): pass -class AsyncToSync: +a_cls_params = ParamSpec("a_cls_params") +a_cls_return = TypeVar("a_cls_return") + + +class AsyncToSync(Generic[a_cls_params, a_cls_return]): """ Utility class which turns an awaitable that only works on the thread with the event loop into a synchronous callable that works in a subthread. @@ -118,7 +124,11 @@ class AsyncToSync: # Local, not a threadlocal, so that tasks can work out what their parent used. executors = Local() - def __init__(self, awaitable, force_new_loop=False): + def __init__( + self, + awaitable: Callable[a_cls_params, Awaitable[a_cls_return]], + force_new_loop=False, + ): if not callable(awaitable) or not _iscoroutinefunction_or_partial(awaitable): # Python does not have very reliable detection of async functions # (lots of false negatives) so this is just a warning. @@ -149,7 +159,9 @@ def __init__(self, awaitable, force_new_loop=False): else: self.main_event_loop = None - def __call__(self, *args, **kwargs): + def __call__( + self, *args: a_cls_params.args, **kwargs: a_cls_params.kwargs + ) -> a_cls_return: # You can't call AsyncToSync from a thread with a running event loop try: event_loop = get_running_loop() @@ -170,7 +182,7 @@ def __call__(self, *args, **kwargs): context = None # Make a future for the return information - call_result = Future() + call_result: Future[a_cls_return] = Future() # Get the source thread source_thread = threading.current_thread() # Make a CurrentThreadExecutor we'll use to idle in this thread - we @@ -269,7 +281,13 @@ def __get__(self, parent, objtype): return functools.update_wrapper(func, self.awaitable) async def main_wrap( - self, args, kwargs, call_result, source_thread, exc_info, context + self, + args, + kwargs, + call_result: a_cls_return, + source_thread, + exc_info, + context, ): """ Wraps the awaitable with something that puts the result into the @@ -301,7 +319,11 @@ async def main_wrap( context[0] = contextvars.copy_context() -class SyncToAsync: +s_cls_params = ParamSpec("s_cls_params") +s_cls_return = TypeVar("s_cls_return") + + +class SyncToAsync(Generic[s_cls_params, s_cls_return]): """ Utility class which turns a synchronous callable into an awaitable that runs in a threadpool. It also sets a threadlocal inside the thread so @@ -367,7 +389,7 @@ class SyncToAsync: def __init__( self, - func: Callable[..., Any], + func: Callable[s_cls_params, s_cls_return], thread_sensitive: bool = True, executor: Optional["ThreadPoolExecutor"] = None, ) -> None: @@ -385,7 +407,9 @@ def __init__( except AttributeError: pass - async def __call__(self, *args, **kwargs): + async def __call__( + self, *args: s_cls_params.args, **kwargs: s_cls_params.kwargs + ) -> s_cls_return: loop = get_running_loop() # Work out what thread to run the code in @@ -459,7 +483,15 @@ def __get__(self, parent, objtype): """ return functools.partial(self.__call__, parent) - def thread_handler(self, loop, source_task, exc_info, func, *args, **kwargs): + def thread_handler( + self, + loop, + source_task, + exc_info, + func: Callable[s_cls_params, s_cls_return], + *args: s_cls_params.args, + **kwargs: s_cls_params.kwargs + ): """ Wraps the sync application with exception handling. """ @@ -509,21 +541,25 @@ def get_current_task(): async_to_sync = AsyncToSync +s_params = ParamSpec("s_params") +s_return = TypeVar("s_return") + + @overload def sync_to_async( func: None = None, thread_sensitive: bool = True, executor: Optional["ThreadPoolExecutor"] = None, -) -> Callable[[Callable[..., Any]], SyncToAsync]: +) -> Callable[[Callable[s_params, s_return]], SyncToAsync[s_params, s_return]]: ... @overload def sync_to_async( - func: Callable[..., Any], + func: Callable[s_params, s_return], thread_sensitive: bool = True, executor: Optional["ThreadPoolExecutor"] = None, -) -> SyncToAsync: +) -> SyncToAsync[s_params, s_return]: ...