diff --git a/asgiref/local.py b/asgiref/local.py index a8b9459b..eab1dc10 100644 --- a/asgiref/local.py +++ b/asgiref/local.py @@ -1,43 +1,46 @@ -import asyncio -import contextlib import contextvars import threading -from typing import Any, Dict, Union +from typing import Any, Union class _CVar: """Storage utility for Local.""" def __init__(self) -> None: - self._data: "contextvars.ContextVar[Dict[str, Any]]" = contextvars.ContextVar( - "asgiref.local" - ) + self._thread_lock = threading.RLock() + self._data: dict[str, contextvars.ContextVar[Any]] = {} + + def __getattr__(self, key: str) -> Any: + with self._thread_lock: + try: + var = self._data[key] + except KeyError: + raise AttributeError(f"{self!r} object has no attribute {key!r}") - def __getattr__(self, key): - storage_object = self._data.get({}) try: - return storage_object[key] - except KeyError: + return var.get() + except LookupError: raise AttributeError(f"{self!r} object has no attribute {key!r}") def __setattr__(self, key: str, value: Any) -> None: - if key == "_data": + if key in ("_data", "_thread_lock"): return super().__setattr__(key, value) - storage_object = self._data.get({}) - storage_object[key] = value - self._data.set(storage_object) + with self._thread_lock: + var = self._data.get(key) + if var is None: + self._data[key] = var = contextvars.ContextVar(key) + var.set(value) def __delattr__(self, key: str) -> None: - storage_object = self._data.get({}) - if key in storage_object: - del storage_object[key] - self._data.set(storage_object) - else: - raise AttributeError(f"{self!r} object has no attribute {key!r}") + with self._thread_lock: + if key in self._data: + del self._data[key] + else: + raise AttributeError(f"{self!r} object has no attribute {key!r}") -class Local: +def Local(thread_critical: bool = False) -> Union[threading.local, _CVar]: """Local storage for async tasks. This is a namespace object (similar to `threading.local`) where data is @@ -64,65 +67,7 @@ class Local: Unlike plain `contextvars` objects, this utility is threadsafe. """ - - def __init__(self, thread_critical: bool = False) -> None: - self._thread_critical = thread_critical - self._thread_lock = threading.RLock() - - self._storage: "Union[threading.local, _CVar]" - - if thread_critical: - # Thread-local storage - self._storage = threading.local() - else: - # Contextvar storage - self._storage = _CVar() - - @contextlib.contextmanager - def _lock_storage(self): - # Thread safe access to storage - if self._thread_critical: - try: - # this is a test for are we in a async or sync - # thread - will raise RuntimeError if there is - # no current loop - asyncio.get_running_loop() - except RuntimeError: - # We are in a sync thread, the storage is - # just the plain thread local (i.e, "global within - # this thread" - it doesn't matter where you are - # in a call stack you see the same storage) - yield self._storage - else: - # We are in an async thread - storage is still - # local to this thread, but additionally should - # behave like a context var (is only visible with - # the same async call stack) - - # Ensure context exists in the current thread - if not hasattr(self._storage, "cvar"): - self._storage.cvar = _CVar() - - # self._storage is a thread local, so the members - # can't be accessed in another thread (we don't - # need any locks) - yield self._storage.cvar - else: - # Lock for thread_critical=False as other threads - # can access the exact same storage object - with self._thread_lock: - yield self._storage - - def __getattr__(self, key): - with self._lock_storage() as storage: - return getattr(storage, key) - - def __setattr__(self, key, value): - if key in ("_local", "_storage", "_thread_critical", "_thread_lock"): - return super().__setattr__(key, value) - with self._lock_storage() as storage: - setattr(storage, key, value) - - def __delattr__(self, key): - with self._lock_storage() as storage: - delattr(storage, key) + if thread_critical: + return threading.local() + else: + return _CVar() diff --git a/tests/test_local.py b/tests/test_local.py index d50cba21..cdcbd280 100644 --- a/tests/test_local.py +++ b/tests/test_local.py @@ -1,6 +1,7 @@ import asyncio import gc import threading +from threading import Thread import pytest @@ -338,3 +339,39 @@ async def async_function(): # inner value was set inside a new async context, meaning that # we do not see it, as context vars don't propagate up the stack assert not hasattr(test_local_not_tc, "test_value") + + +def test_visibility_thread_asgiref() -> None: + """Check visibility with subthreads.""" + test_local = Local() + test_local.value = 0 + + def _test() -> None: + # Local() is cleared when changing thread + assert not hasattr(test_local, "value") + setattr(test_local, "value", 1) + assert test_local.value == 1 + + thread = Thread(target=_test) + thread.start() + thread.join() + + assert test_local.value == 0 + + +@pytest.mark.asyncio +async def test_visibility_task() -> None: + """Check visibility with asyncio tasks.""" + test_local = Local() + test_local.value = 0 + + async def _test() -> None: + # Local is inherited when changing task + assert test_local.value == 0 + test_local.value = 1 + assert test_local.value == 1 + + await asyncio.create_task(_test()) + + # Changes should not leak to the caller + assert test_local.value == 0