Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
btschwertfeger committed Nov 28, 2024
1 parent 8add187 commit d51831c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
4 changes: 1 addition & 3 deletions src/websockets/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,9 +913,7 @@ async def send_context(
if wait_for_close:
try:
async with asyncio_timeout_at(self.close_deadline):
self.recv_messages.cancelling = True
if self.recv_messages.paused:
self.recv_messages.resume()
self.recv_messages.prepare_close()
await asyncio.shield(self.connection_lost_waiter)
except TimeoutError:
# There's no risk to overwrite another error because
Expand Down
21 changes: 15 additions & 6 deletions src/websockets/asyncio/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,12 @@ def __init__( # pragma: no cover
# This flag prevents concurrent calls to get() by user code.
self.get_in_progress = False

# This flag marks a soon cancellation.
self.cancelling = False
# This flag marks a soon end of the connection.
self.closing = False

# This flag marks the end of the connection.
self.closed = False


async def get(self, decode: bool | None = None) -> Data:
"""
Read the next message.
Expand All @@ -142,7 +141,7 @@ async def get(self, decode: bool | None = None) -> Data:
:meth:`get_iter` concurrently.
"""
if self.cancelling:
if self.closing:
return
if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")
Expand Down Expand Up @@ -207,7 +206,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
:meth:`get_iter` concurrently.
"""
if self.cancelling:
if self.closing:
return
if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")
Expand Down Expand Up @@ -259,7 +258,7 @@ def put(self, frame: Frame) -> None:
EOFError: If the stream of frames has ended.
"""
if self.cancelling:
if self.closing:
return
if self.closed:
raise EOFError("stream of frames ended")
Expand Down Expand Up @@ -289,6 +288,16 @@ def maybe_resume(self) -> None:
self.paused = False
self.resume()

def prepare_close(self) -> None:
"""
Prepare to close by ensuring that no more messages will be processed.
"""
self.closing = True

# Resuming the writer to avoid deadlocks
if self.paused:
self.resume()

def close(self) -> None:
"""
End the stream of frames.
Expand Down

0 comments on commit d51831c

Please sign in to comment.