Skip to content

Commit

Permalink
Clean up sync message assembler.
Browse files Browse the repository at this point in the history
Remove support for control frames, which isn't actually used.
  • Loading branch information
aaugustin committed Jan 21, 2024
1 parent 2865bdc commit 908c7ba
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 87 deletions.
34 changes: 18 additions & 16 deletions src/websockets/sync/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
from typing import Iterator, List, Optional, cast

from ..frames import Frame, Opcode
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
from ..typing import Data


Expand All @@ -25,8 +25,11 @@ def __init__(self) -> None:
# primitives provided by the threading and queue modules.
self.mutex = threading.Lock()

# We create a latch with two events to ensure proper interleaving of
# writing and reading messages.
# We create a latch with two events to synchronize the production of
# frames and the consumption of messages (or frames) without a buffer.
# This design requires a switch between the library thread and the user
# thread for each message; that shouldn't be a performance bottleneck.

# put() sets this event to tell get() that a message can be fetched.
self.message_complete = threading.Event()
# get() sets this event to let put() that the message was fetched.
Expand Down Expand Up @@ -72,8 +75,10 @@ def get(self, timeout: Optional[float] = None) -> Data:
Raises:
EOFError: If the stream of frames has ended.
RuntimeError: If two threads run :meth:`get` or :meth:``get_iter`
RuntimeError: If two threads run :meth:`get` or :meth:`get_iter`
concurrently.
TimeoutError: If a timeout is provided and elapses before a
complete message is received.
"""
with self.mutex:
Expand Down Expand Up @@ -131,7 +136,7 @@ def get_iter(self) -> Iterator[Data]:
Raises:
EOFError: If the stream of frames has ended.
RuntimeError: If two threads run :meth:`get` or :meth:``get_iter`
RuntimeError: If two threads run :meth:`get` or :meth:`get_iter`
concurrently.
"""
Expand Down Expand Up @@ -159,11 +164,10 @@ def get_iter(self) -> Iterator[Data]:
self.get_in_progress = True

# Locking with get_in_progress ensures only one thread can get here.
yield from chunks
while True:
chunk = self.chunks_queue.get()
if chunk is None:
break
chunk: Optional[Data]
for chunk in chunks:
yield chunk
while (chunk := self.chunks_queue.get()) is not None:
yield chunk

with self.mutex:
Expand Down Expand Up @@ -205,15 +209,12 @@ def put(self, frame: Frame) -> None:
if self.put_in_progress:
raise RuntimeError("put is already running")

if frame.opcode is Opcode.TEXT:
if frame.opcode is OP_TEXT:
self.decoder = UTF8Decoder(errors="strict")
elif frame.opcode is Opcode.BINARY:
elif frame.opcode is OP_BINARY:
self.decoder = None
elif frame.opcode is Opcode.CONT:
pass
else:
# Ignore control frames.
return
assert frame.opcode is OP_CONT

data: Data
if self.decoder is not None:
Expand Down Expand Up @@ -242,6 +243,7 @@ def put(self, frame: Frame) -> None:
self.put_in_progress = True

# Release the lock to allow get() to run and eventually set the event.
# Locking with put_in_progress ensures only one coroutine can get here.
self.message_fetched.wait()

with self.mutex:
Expand Down
113 changes: 42 additions & 71 deletions tests/sync/test_messages.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time

from websockets.frames import OP_BINARY, OP_CONT, OP_PING, OP_PONG, OP_TEXT, Frame
from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
from websockets.sync.messages import *

from ..utils import MS
Expand Down Expand Up @@ -350,76 +350,6 @@ def test_get_with_timeout_times_out(self):
with self.assertRaises(TimeoutError):
self.assembler.get(MS)

# Test control frames

def test_control_frame_before_message_is_ignored(self):
"""get ignores control frames between messages."""

def putter():
self.assembler.put(Frame(OP_PING, b""))
self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))

with self.run_in_thread(putter):
message = self.assembler.get()

self.assertEqual(message, "café")

def test_control_frame_in_fragmented_message_is_ignored(self):
"""get ignores control frames within fragmented messages."""

def putter():
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
self.assembler.put(Frame(OP_PING, b""))
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
self.assembler.put(Frame(OP_PONG, b""))
self.assembler.put(Frame(OP_CONT, b"a"))

with self.run_in_thread(putter):
message = self.assembler.get()

self.assertEqual(message, b"tea")

# Test concurrency

def test_get_fails_when_get_is_running(self):
"""get cannot be called concurrently with itself."""
with self.run_in_thread(self.assembler.get):
with self.assertRaises(RuntimeError):
self.assembler.get()
self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread

def test_get_fails_when_get_iter_is_running(self):
"""get cannot be called concurrently with get_iter."""
with self.run_in_thread(lambda: list(self.assembler.get_iter())):
with self.assertRaises(RuntimeError):
self.assembler.get()
self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread

def test_get_iter_fails_when_get_is_running(self):
"""get_iter cannot be called concurrently with get."""
with self.run_in_thread(self.assembler.get):
with self.assertRaises(RuntimeError):
list(self.assembler.get_iter())
self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread

def test_get_iter_fails_when_get_iter_is_running(self):
"""get_iter cannot be called concurrently with itself."""
with self.run_in_thread(lambda: list(self.assembler.get_iter())):
with self.assertRaises(RuntimeError):
list(self.assembler.get_iter())
self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread

def test_put_fails_when_put_is_running(self):
"""put cannot be called concurrently with itself."""

def putter():
self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))

with self.run_in_thread(putter):
with self.assertRaises(RuntimeError):
self.assembler.put(Frame(OP_BINARY, b"tea"))
self.assembler.get() # unblock other thread

# Test termination

def test_get_fails_when_interrupted_by_close(self):
Expand Down Expand Up @@ -477,3 +407,44 @@ def test_close_is_idempotent(self):
"""close can be called multiple times safely."""
self.assembler.close()
self.assembler.close()

# Test (non-)concurrency

def test_get_fails_when_get_is_running(self):
"""get cannot be called concurrently with itself."""
with self.run_in_thread(self.assembler.get):
with self.assertRaises(RuntimeError):
self.assembler.get()
self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread

def test_get_fails_when_get_iter_is_running(self):
"""get cannot be called concurrently with get_iter."""
with self.run_in_thread(lambda: list(self.assembler.get_iter())):
with self.assertRaises(RuntimeError):
self.assembler.get()
self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread

def test_get_iter_fails_when_get_is_running(self):
"""get_iter cannot be called concurrently with get."""
with self.run_in_thread(self.assembler.get):
with self.assertRaises(RuntimeError):
list(self.assembler.get_iter())
self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread

def test_get_iter_fails_when_get_iter_is_running(self):
"""get_iter cannot be called concurrently with itself."""
with self.run_in_thread(lambda: list(self.assembler.get_iter())):
with self.assertRaises(RuntimeError):
list(self.assembler.get_iter())
self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread

def test_put_fails_when_put_is_running(self):
"""put cannot be called concurrently with itself."""

def putter():
self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))

with self.run_in_thread(putter):
with self.assertRaises(RuntimeError):
self.assembler.put(Frame(OP_BINARY, b"tea"))
self.assembler.get() # unblock other thread

0 comments on commit 908c7ba

Please sign in to comment.