Skip to content

Commit

Permalink
Make all the DataLoader attributes private (#300)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored Dec 26, 2024
1 parent e1d3594 commit 6e11733
Showing 1 changed file with 21 additions and 23 deletions.
44 changes: 21 additions & 23 deletions src/spdl/dataloader/_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ class DataLoader(Generic[Source, Output]):
passed to the aggregator in order of completion. If 'input', then they are passed
to the aggregator in the order of the source input.
:ivar src: The source object provided in the constructor.
Examples:
>>> import spdl.io
>>> from spdl.io import ImageFrames
Expand Down Expand Up @@ -194,37 +192,37 @@ def __init__(
timeout: float | None = None,
output_order: str = "completion",
) -> None:
self.src = src
self.preprocessor = preprocessor
self.aggregator = aggregator
self._src = src
self._preprocessor = preprocessor
self._aggregator = aggregator

self.batch_size = batch_size
self.drop_last = drop_last
self.buffer_size = buffer_size
self.num_threads = num_threads
self.timeout = timeout
self.output_order = output_order
self._batch_size = batch_size
self._drop_last = drop_last
self._buffer_size = buffer_size
self._num_threads = num_threads
self._timeout = timeout
self._output_order = output_order

def _get_pipeline(self) -> Pipeline:
builder = PipelineBuilder()
builder.add_source(self.src)
builder.add_source(self._src)

if self.preprocessor:
if self._preprocessor:
builder.pipe(
self.preprocessor,
concurrency=self.num_threads,
output_order=self.output_order,
self._preprocessor,
concurrency=self._num_threads,
output_order=self._output_order,
)

if self.batch_size:
builder.aggregate(self.batch_size, drop_last=self.drop_last)
if self._batch_size:
builder.aggregate(self._batch_size, drop_last=self._drop_last)

if self.aggregator:
builder.pipe(self.aggregator)
if self._aggregator:
builder.pipe(self._aggregator)

builder.add_sink(self.buffer_size)
builder.add_sink(self._buffer_size)

return builder.build(num_threads=self.num_threads)
return builder.build(num_threads=self._num_threads)

def __iter__(self) -> Iterable[Output]:
"""Run the data loading pipeline in background.
Expand All @@ -235,5 +233,5 @@ def __iter__(self) -> Iterable[Output]:
pipeline = self._get_pipeline()

with pipeline.auto_stop():
for item in pipeline.get_iterator(timeout=self.timeout):
for item in pipeline.get_iterator(timeout=self._timeout):
yield item

0 comments on commit 6e11733

Please sign in to comment.