diff --git a/repro.py b/repro.py new file mode 100644 index 00000000..8fda0e9c --- /dev/null +++ b/repro.py @@ -0,0 +1,84 @@ +import asyncio +import logging +import json +from websockets.asyncio.client import connect + +# import debugpy + +# # Allow VS Code to attach +# debugpy.listen(("0.0.0.0", 5678)) # Use the port you've specified +# print("Waiting for debugger to attach...") +# debugpy.wait_for_client() + +logging.getLogger(__name__) +logging.basicConfig( + format="%(asctime)s %(module)s:%(lineno)d %(levelname)8s | %(message)s", + datefmt="%Y/%m/%d %H:%M:%S", + level=logging.DEBUG, +) + +class MyClient: + + def __init__(self): + self.keep_alive = True + + async def run(self): + + async with connect( + f"wss://ws.kraken.com/v2", + ping_interval=30, + # max_queue=None, # having this enabled doesn't cause problems + ) as socket: + await socket.send( + json.dumps( + { + "method": "subscribe", + "params": { + "channel": "book", + "symbol": [ + "BTC/USD", + "DOT/USD", + "ETH/USD", + "MATIC/USD", + "BTC/EUR", + "DOT/EUR", + "ETH/EUR", + "XLM/USD", + "XLM/EUR", + ], + "depth": 100 + }, + } + ) + ) + + while self.keep_alive: + try: + _message = await asyncio.wait_for(socket.recv(), timeout=10) + except TimeoutError: + pass + except asyncio.CancelledError: + self.keep_alive = False + else: + try: + message = json.loads(_message) + except ValueError: + pass + + async def __aenter__(self): + self.task: asyncio.Task = asyncio.create_task(self.run()) + return self + + async def __aexit__(self, *args, **kwargs): + self.keep_alive = False + if hasattr(self, "task") and not self.task.done(): + await self.task + + +async def main(): + async with MyClient(): + await asyncio.sleep(3) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index e5c350fe..c947f5a4 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -913,6 +913,7 @@ async def send_context( if wait_for_close: try: async with asyncio_timeout_at(self.close_deadline): + self.recv_messages.prepare_close() await asyncio.shield(self.connection_lost_waiter) except TimeoutError: # There's no risk to overwrite another error because diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index e6d1d31c..936adf63 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -113,6 +113,9 @@ def __init__( # pragma: no cover # This flag prevents concurrent calls to get() by user code. self.get_in_progress = False + # This flag marks a soon end of the connection. + self.closing = False + # This flag marks the end of the connection. self.closed = False @@ -255,7 +258,9 @@ def put(self, frame: Frame) -> None: raise EOFError("stream of frames ended") self.frames.put(frame) - self.maybe_pause() + + if not self.closing: + self.maybe_pause() def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" @@ -279,11 +284,22 @@ def maybe_resume(self) -> None: self.paused = False self.resume() + def prepare_close(self) -> None: + """ + Prepare to close by ensuring that no more messages will be processed. + """ + self.closing = True + + # Resuming the writer to avoid deadlocks + if self.paused: + self.paused = False + self.resume() + def close(self) -> None: """ End the stream of frames. - Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, or :meth:`put` is safe. They will raise :exc:`EOFError`. """ diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 98490797..7d21a4fa 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -291,7 +291,7 @@ def close(self) -> None: """ End the stream of frames. - Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, or :meth:`put` is safe. They will raise :exc:`EOFError`. """ diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 5a0b61bf..95f4e56a 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -806,6 +806,27 @@ async def test_close_preserves_queued_messages(self): self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) + async def test_close_preserves_queued_messages_gt_max_queue(self): + """ + close preserves messages buffered in the assembler, even if they + exceed the default buffer size. + """ + + for _ in range(100): + await self.remote_connection.send("😀") + + await self.connection.close() + + for _ in range(100): + self.assertEqual(await self.connection.recv(), "😀") + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + async def test_close_idempotency(self): """close does nothing if the connection is already closed.""" await self.connection.close()