feat(parquet): Add next_row_group API for ParquetRecordBatchStream (#6907)

* feat(parquet): Add next_row_group API for ParquetRecordBatchStream

Signed-off-by: Xuanwo <[email protected]>

* chore: Returning error instead of using unreachable

Signed-off-by: Xuanwo <[email protected]>

---------

Signed-off-by: Xuanwo <[email protected]>
Xuanwo authored Dec 24, 2024
1 parent 2c84f24 commit 10cf03c
Showing 1 changed file with 132 additions and 0 deletions.
parquet/src/arrow/async_reader/mod.rs: 132 additions, 0 deletions
@@ -613,6 +613,9 @@ impl<T> std::fmt::Debug for StreamState<T> {

/// An asynchronous [`Stream`](https://docs.rs/futures/latest/futures/stream/trait.Stream.html) of [`RecordBatch`]
/// for a parquet file that can be constructed using [`ParquetRecordBatchStreamBuilder`].
///
/// `ParquetRecordBatchStream` also provides [`ParquetRecordBatchStream::next_row_group`] for fetching row groups,
/// allowing users to decode record batches separately from I/O.
pub struct ParquetRecordBatchStream<T> {
    metadata: Arc<ParquetMetaData>,

@@ -654,6 +657,70 @@ impl<T> ParquetRecordBatchStream<T> {
    }
}

impl<T> ParquetRecordBatchStream<T>
where
    T: AsyncFileReader + Unpin + Send + 'static,
{
    /// Fetches the next row group from the stream.
    ///
    /// Users can continue to call this function to get row groups and decode them concurrently.
    ///
    /// ## Notes
    ///
    /// ParquetRecordBatchStream should be used either as a `Stream` or with `next_row_group`; they should not be used simultaneously.
    ///
    /// ## Returns
    ///
    /// - `Ok(None)` if the stream has ended.
    /// - `Err(error)` if the stream has errored. All subsequent calls will return `Ok(None)`.
    /// - `Ok(Some(reader))` which holds all the data for the row group.
    pub async fn next_row_group(&mut self) -> Result<Option<ParquetRecordBatchReader>> {
        loop {
            match &mut self.state {
                StreamState::Decoding(_) | StreamState::Reading(_) => {
                    return Err(ParquetError::General(
                        "Cannot combine the use of next_row_group with the Stream API".to_string(),
                    ))
                }
                StreamState::Init => {
                    let row_group_idx = match self.row_groups.pop_front() {
                        Some(idx) => idx,
                        None => return Ok(None),
                    };

                    let row_count = self.metadata.row_group(row_group_idx).num_rows() as usize;

                    let selection = self.selection.as_mut().map(|s| s.split_off(row_count));

                    let reader_factory = self.reader.take().expect("lost reader");

                    let (reader_factory, maybe_reader) = reader_factory
                        .read_row_group(
                            row_group_idx,
                            selection,
                            self.projection.clone(),
                            self.batch_size,
                        )
                        .await
                        .map_err(|err| {
                            self.state = StreamState::Error;
                            err
                        })?;
                    self.reader = Some(reader_factory);

                    if let Some(reader) = maybe_reader {
                        return Ok(Some(reader));
                    } else {
                        // All rows skipped, read next row group
                        continue;
                    }
                }
                StreamState::Error => return Ok(None), // An error has already occurred; end the stream.
            }
        }
    }
}

impl<T> Stream for ParquetRecordBatchStream<T>
where
    T: AsyncFileReader + Unpin + Send + 'static,
@@ -1020,6 +1087,71 @@ mod tests {
        );
    }

    #[tokio::test]
    async fn test_async_reader_with_next_row_group() {
        let testdata = arrow::util::test_util::parquet_test_data();
        let path = format!("{testdata}/alltypes_plain.parquet");
        let data = Bytes::from(std::fs::read(path).unwrap());

        let metadata = ParquetMetaDataReader::new()
            .parse_and_finish(&data)
            .unwrap();
        let metadata = Arc::new(metadata);

        assert_eq!(metadata.num_row_groups(), 1);

        let async_reader = TestReader {
            data: data.clone(),
            metadata: metadata.clone(),
            requests: Default::default(),
        };

        let requests = async_reader.requests.clone();
        let builder = ParquetRecordBatchStreamBuilder::new(async_reader)
            .await
            .unwrap();

        let mask = ProjectionMask::leaves(builder.parquet_schema(), vec![1, 2]);
        let mut stream = builder
            .with_projection(mask.clone())
            .with_batch_size(1024)
            .build()
            .unwrap();

        let mut readers = vec![];
        while let Some(reader) = stream.next_row_group().await.unwrap() {
            readers.push(reader);
        }

        let async_batches: Vec<_> = readers
            .into_iter()
            .flat_map(|r| r.map(|v| v.unwrap()).collect::<Vec<_>>())
            .collect();

        let sync_batches = ParquetRecordBatchReaderBuilder::try_new(data)
            .unwrap()
            .with_projection(mask)
            .with_batch_size(104)
            .build()
            .unwrap()
            .collect::<ArrowResult<Vec<_>>>()
            .unwrap();

        assert_eq!(async_batches, sync_batches);

        let requests = requests.lock().unwrap();
        let (offset_1, length_1) = metadata.row_group(0).column(1).byte_range();
        let (offset_2, length_2) = metadata.row_group(0).column(2).byte_range();

        assert_eq!(
            &requests[..],
            &[
                offset_1 as usize..(offset_1 + length_1) as usize,
                offset_2 as usize..(offset_2 + length_2) as usize
            ]
        );
    }

    #[tokio::test]
    async fn test_async_reader_with_index() {
        let testdata = arrow::util::test_util::parquet_test_data();
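
A minimal usage sketch of the new API (not part of this commit): each row group is fetched with `next_row_group`, and the CPU-bound decoding of the returned `ParquetRecordBatchReader` is pushed onto a blocking thread, so decoding can overlap with fetching the next row group. It assumes a tokio runtime (with the fs and macros features), the parquet crate's "async" feature (which provides the blanket `AsyncFileReader` impl used here for `tokio::fs::File`), and a hypothetical input file `data.parquet`; the batch size is arbitrary.

use parquet::arrow::async_reader::ParquetRecordBatchStreamBuilder;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical input; any type implementing AsyncFileReader works here.
    let file = tokio::fs::File::open("data.parquet").await?;

    let mut stream = ParquetRecordBatchStreamBuilder::new(file)
        .await?
        .with_batch_size(8192)
        .build()?;

    // Fetch row groups one at a time. Each returned reader owns all the bytes
    // for its row group, so decoding can run on a blocking thread while this
    // task goes back to fetching the next row group.
    let mut tasks = Vec::new();
    while let Some(reader) = stream.next_row_group().await? {
        tasks.push(tokio::task::spawn_blocking(move || {
            // ParquetRecordBatchReader is a synchronous iterator of record batches.
            let mut rows = 0;
            for batch in reader {
                rows += batch.expect("decode error").num_rows();
            }
            rows
        }));
    }

    let mut total = 0;
    for task in tasks {
        total += task.await?;
    }
    println!("decoded {total} rows");
    Ok(())
}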
