Skip to content

Commit

Permalink
Fix run_in_subprocess (#299)
Browse files Browse the repository at this point in the history
1. The `queue.Full` exception is never raised, because `put` is called without
a timeout and simply blocks, so catching it does not make sense.
The exception that can actually happen is a pickling error, but it is safer to
catch all exceptions instead.
2. When the child process puts a message, the parent process does not
process the data that remains in the queue. This commit adds a drain step
so the remaining items are yielded before the message is handled.
3. React to `KeyboardInterrupt` in the main process.
  • Loading branch information
mthrok authored Dec 26, 2024
1 parent dabc1a1 commit e1d3594
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions src/spdl/dataloader/_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
# Message from worker to the parent
_MSG_GENERATOR_FAILED = "GENERATOR_FAILED_TO_INITIALIZE"
_MSG_ITERATION_FINISHED = "ITERATION_FINISHED"
_MSG_DATA_QUEUE_FULL = "DATA_QUEUE_FULL"
_MSG_DATA_QUEUE_FAILED = "DATA_QUEUE_FAILED"


def _execute_iterator(
Expand All @@ -63,7 +63,7 @@ def _execute_iterator(
else:
if msg == _MSG_PARENT_REQUEST_STOP:
return
raise ValueError(f"Unexpected message redeived: {msg}")
raise ValueError(f"[INTERNAL ERROR] Unexpected message received: {msg}")

try:
item = next(gen)
Expand All @@ -76,8 +76,8 @@ def _execute_iterator(

try:
data_queue.put(item)
except queue.Full:
msg_queue.put(_MSG_DATA_QUEUE_FULL)
except Exception:
msg_queue.put(_MSG_DATA_QUEUE_FAILED)
return


Expand Down Expand Up @@ -111,7 +111,11 @@ def run_in_subprocess(
"""
ctx = mp.get_context(mp_context)
msg_q = ctx.Queue()
data_q = ctx.Queue(maxsize=queue_size)
data_q: mp.Queue = ctx.Queue(maxsize=queue_size)

def _drain() -> Iterator[T]:
while not data_q.empty():
yield data_q.get_nowait()

process = ctx.Process(
target=_execute_iterator,
Expand All @@ -127,18 +131,21 @@ def run_in_subprocess(
except queue.Empty:
pass
else:
# When a message is found, the child process stopped putting data.
yield from _drain()

if msg == _MSG_ITERATION_FINISHED:
return
if msg == _MSG_GENERATOR_FAILED:
raise RuntimeError(
"The worker process quit because the generator failed."
)
if msg == _MSG_DATA_QUEUE_FULL:
if msg == _MSG_DATA_QUEUE_FAILED:
raise RuntimeError(
"The worker process quit because the data queue is full for too long."
"The worker process quit because it failed at passing the data."
)

raise ValueError(f"Unexpected message received: {msg}")
raise ValueError(f"[INTERNAL ERROR] Unexpected message received: {msg}")

try:
yield data_q.get(timeout=1)
Expand All @@ -153,12 +160,11 @@ def run_in_subprocess(
f"The worker process did not produce any data for {elapsed:.2f} seconds."
)

except Exception:
except (Exception, KeyboardInterrupt):
msg_q.put(_MSG_PARENT_REQUEST_STOP)
raise
finally:
while not data_q.empty():
data_q.get_nowait()
yield from _drain()
process.join(3)

if process.exitcode is None:
Expand Down

0 comments on commit e1d3594

Please sign in to comment.