Enhancement: refine the reader interface (apache#401)
ZENOTME authored Jun 21, 2024
1 parent 11d4221 commit 854171d
Showing 4 changed files with 188 additions and 97 deletions.
69 changes: 24 additions & 45 deletions crates/iceberg/src/arrow/reader.rs
@@ -44,27 +44,21 @@ use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor};
 use crate::expr::{BoundPredicate, BoundReference};
 use crate::io::{FileIO, FileMetadata, FileRead};
 use crate::scan::{ArrowRecordBatchStream, FileScanTaskStream};
-use crate::spec::{Datum, SchemaRef};
+use crate::spec::{Datum, Schema};
 use crate::{Error, ErrorKind};
 
 /// Builder to create ArrowReader
 pub struct ArrowReaderBuilder {
     batch_size: Option<usize>,
-    field_ids: Vec<usize>,
     file_io: FileIO,
-    schema: SchemaRef,
-    predicates: Option<BoundPredicate>,
 }
 
 impl ArrowReaderBuilder {
     /// Create a new ArrowReaderBuilder
-    pub fn new(file_io: FileIO, schema: SchemaRef) -> Self {
+    pub(crate) fn new(file_io: FileIO) -> Self {
         ArrowReaderBuilder {
             batch_size: None,
-            field_ids: vec![],
             file_io,
-            schema,
-            predicates: None,
         }
     }
 
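With the `schema` and `predicates` state gone, the builder carries only I/O and batching concerns; everything else now travels on each scan task, as the hunks below show. A minimal sketch of how the reader is assembled after this change — `new` is now `pub(crate)`, so this runs inside the crate, and the `with_batch_size` call and module paths are assumptions inferred from the `batch_size` field rather than shown in this diff:

```rust
use iceberg::arrow::{ArrowReader, ArrowReaderBuilder}; // assumed export paths
use iceberg::io::FileIO;

// Crate-internal sketch: only the scan machinery constructs the reader
// now. Batch size stays optional; all remaining scan context (schema,
// projection, predicate) arrives later, per FileScanTask.
fn build_reader(file_io: FileIO) -> ArrowReader {
    ArrowReaderBuilder::new(file_io)
        .with_batch_size(1024) // assumed builder method; unset falls back to the Parquet default
        .build()
}
```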
@@ -75,38 +69,20 @@ impl ArrowReaderBuilder {
         self
     }
 
-    /// Sets the desired column projection with a list of field ids.
-    pub fn with_field_ids(mut self, field_ids: impl IntoIterator<Item = usize>) -> Self {
-        self.field_ids = field_ids.into_iter().collect();
-        self
-    }
-
-    /// Sets the predicates to apply to the scan.
-    pub fn with_predicates(mut self, predicates: BoundPredicate) -> Self {
-        self.predicates = Some(predicates);
-        self
-    }
-
     /// Build the ArrowReader.
     pub fn build(self) -> ArrowReader {
         ArrowReader {
             batch_size: self.batch_size,
-            field_ids: self.field_ids,
-            schema: self.schema,
             file_io: self.file_io,
-            predicates: self.predicates,
         }
     }
 }
 
 /// Reads data from Parquet files
 #[derive(Clone)]
 pub struct ArrowReader {
     batch_size: Option<usize>,
-    field_ids: Vec<usize>,
-    #[allow(dead_code)]
-    schema: SchemaRef,
     file_io: FileIO,
-    predicates: Option<BoundPredicate>,
 }
 
 impl ArrowReader {
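The fields deleted from `ArrowReader` resurface per task: each `FileScanTask` now supplies the file path, projected field ids, snapshot schema, and optional predicate. A hedged sketch of those accessors — the names come from the call sites later in this diff, while the exact signatures and export path are assumptions:

```rust
use iceberg::expr::BoundPredicate;
use iceberg::scan::FileScanTask; // assumed export path
use iceberg::spec::Schema;

// Each task now describes its own read: which file, which columns,
// which schema it was planned against, and what filter to push down.
fn describe_task(task: &FileScanTask) {
    let path: &str = task.data_file_path();
    let field_ids: &[i32] = task.project_field_ids();
    let schema: &Schema = task.schema();
    let predicate: Option<&BoundPredicate> = task.predicate();
    println!(
        "{path}: {} projected field ids, schema {}, filtered: {}",
        field_ids.len(),
        schema.schema_id(),
        predicate.is_some()
    );
}
```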
@@ -115,16 +91,16 @@ impl ArrowReader {
     pub fn read(self, mut tasks: FileScanTaskStream) -> crate::Result<ArrowRecordBatchStream> {
         let file_io = self.file_io.clone();
 
-        // Collect Parquet column indices from field ids
-        let mut collector = CollectFieldIdVisitor {
-            field_ids: HashSet::default(),
-        };
-        if let Some(predicates) = &self.predicates {
-            visit(&mut collector, predicates)?;
-        }
-
         Ok(try_stream! {
             while let Some(Ok(task)) = tasks.next().await {
+                // Collect Parquet column indices from field ids
+                let mut collector = CollectFieldIdVisitor {
+                    field_ids: HashSet::default(),
+                };
+                if let Some(predicates) = task.predicate() {
+                    visit(&mut collector, predicates)?;
+                }
+
                 let parquet_file = file_io
                     .new_input(task.data_file_path())?;
                 let (parquet_metadata, parquet_reader) = try_join!(parquet_file.metadata(), parquet_file.reader())?;
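Moving the `CollectFieldIdVisitor` inside the loop is the point of this hunk: each task can now carry a different predicate, so the set of referenced field ids must be rebuilt per file rather than once per stream. A self-contained stand-in for what the visitor computes (the real one implements `BoundPredicateVisitor` over a `BoundPredicate`; this toy tree only shows the shape):

```rust
use std::collections::HashSet;

// Toy predicate tree standing in for BoundPredicate.
enum Pred {
    Field(i32),                // a column reference, by field id
    And(Box<Pred>, Box<Pred>), // conjunction of two sub-predicates
}

// Walk the tree and record every field id the predicate touches —
// the same information CollectFieldIdVisitor gathers for each task.
fn collect_field_ids(p: &Pred, out: &mut HashSet<i32>) {
    match p {
        Pred::Field(id) => {
            out.insert(*id);
        }
        Pred::And(l, r) => {
            collect_field_ids(l, out);
            collect_field_ids(r, out);
        }
    }
}

fn main() {
    let pred = Pred::And(Box::new(Pred::Field(1)), Box::new(Pred::Field(3)));
    let mut ids = HashSet::new();
    collect_field_ids(&pred, &mut ids);
    assert_eq!(ids, HashSet::from([1, 3]));
}
```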
@@ -135,11 +111,11 @@
 
                 let parquet_schema = batch_stream_builder.parquet_schema();
                 let arrow_schema = batch_stream_builder.schema();
-                let projection_mask = self.get_arrow_projection_mask(parquet_schema, arrow_schema)?;
+                let projection_mask = self.get_arrow_projection_mask(task.project_field_ids(), task.schema(), parquet_schema, arrow_schema)?;
                 batch_stream_builder = batch_stream_builder.with_projection(projection_mask);
 
                 let parquet_schema = batch_stream_builder.parquet_schema();
-                let row_filter = self.get_row_filter(parquet_schema, &collector)?;
+                let row_filter = self.get_row_filter(task.predicate(), parquet_schema, &collector)?;
 
                 if let Some(row_filter) = row_filter {
                     batch_stream_builder = batch_stream_builder.with_row_filter(row_filter);
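Both helpers now take their inputs from the task instead of `self`. For reference, once `get_arrow_projection_mask` has resolved the task's field ids to Parquet leaf indices, building the mask reduces to the `parquet` crate call below (a sketch of the underlying mechanism, not code from this diff):

```rust
use parquet::arrow::ProjectionMask;
use parquet::schema::types::SchemaDescriptor;

// An empty projection list means "read every column"; otherwise select
// exactly the resolved leaf columns.
fn projection_for(parquet_schema: &SchemaDescriptor, leaf_indices: Vec<usize>) -> ProjectionMask {
    if leaf_indices.is_empty() {
        ProjectionMask::all()
    } else {
        ProjectionMask::leaves(parquet_schema, leaf_indices)
    }
}
```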
@@ -161,10 +137,12 @@
 
     fn get_arrow_projection_mask(
         &self,
+        field_ids: &[i32],
+        iceberg_schema_of_task: &Schema,
         parquet_schema: &SchemaDescriptor,
         arrow_schema: &ArrowSchemaRef,
     ) -> crate::Result<ProjectionMask> {
-        if self.field_ids.is_empty() {
+        if field_ids.is_empty() {
             Ok(ProjectionMask::all())
         } else {
             // Build the map between field id and column index in Parquet schema.
@@ -184,11 +162,11 @@
                 }
                 let field_id = field_id.unwrap();
 
-                if !self.field_ids.contains(&(field_id as usize)) {
+                if !field_ids.contains(&field_id) {
                     return false;
                 }
 
-                let iceberg_field = self.schema.field_by_id(field_id);
+                let iceberg_field = iceberg_schema_of_task.field_by_id(field_id);
                 let parquet_iceberg_field = iceberg_schema.field_by_id(field_id);
 
                 if iceberg_field.is_none() || parquet_iceberg_field.is_none() {
@@ -203,19 +181,19 @@
                 true
             });
 
-            if column_map.len() != self.field_ids.len() {
+            if column_map.len() != field_ids.len() {
                 return Err(Error::new(
                     ErrorKind::DataInvalid,
                     format!(
                         "Parquet schema {} and Iceberg schema {} do not match.",
-                        iceberg_schema, self.schema
+                        iceberg_schema, iceberg_schema_of_task
                     ),
                 ));
             }
 
             let mut indices = vec![];
-            for field_id in &self.field_ids {
-                if let Some(col_idx) = column_map.get(&(*field_id as i32)) {
+            for field_id in field_ids {
+                if let Some(col_idx) = column_map.get(field_id) {
                     indices.push(*col_idx);
                 } else {
                     return Err(Error::new(
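Because `field_ids` is now `&[i32]`, the lookup no longer needs the `usize`-to-`i32` cast the old code carried. The resolution loop in isolation, as a runnable sketch over plain collections (error type simplified to `String`):

```rust
use std::collections::HashMap;

// Resolve every projected field id to a Parquet leaf column index,
// failing on the first id the file's schema does not contain.
fn to_leaf_indices(
    field_ids: &[i32],
    column_map: &HashMap<i32, usize>,
) -> Result<Vec<usize>, String> {
    field_ids
        .iter()
        .map(|id| {
            column_map
                .get(id)
                .copied()
                .ok_or_else(|| format!("field id {id} not found in Parquet schema"))
        })
        .collect()
}
```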
@@ -230,10 +208,11 @@
 
     fn get_row_filter(
         &self,
+        predicates: Option<&BoundPredicate>,
         parquet_schema: &SchemaDescriptor,
         collector: &CollectFieldIdVisitor,
     ) -> Result<Option<RowFilter>> {
-        if let Some(predicates) = &self.predicates {
+        if let Some(predicates) = predicates {
             let field_id_map = build_field_id_map(parquet_schema)?;
 
             // Collect Parquet column indices from field ids.
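Taken together, callers no longer configure the Arrow reader directly: projection and filtering are attached to the table scan, flow into the planned `FileScanTask`s, and reach the reader with no extra wiring. A hedged end-to-end sketch — the scan-builder method names (`select`, `with_filter`, `to_arrow`) reflect the surrounding crate around this commit and may differ in other releases:

```rust
use futures::TryStreamExt;
use iceberg::expr::Reference;
use iceberg::spec::Datum;
use iceberg::table::Table;

// Projection and predicate ride along on the scan; each planned
// FileScanTask carries them to the Arrow reader.
async fn scan_to_arrow(table: &Table) -> iceberg::Result<()> {
    let batches: Vec<_> = table
        .scan()
        .select(["id", "name"]) // becomes task.project_field_ids()
        .with_filter(Reference::new("id").less_than(Datum::long(100))) // becomes task.predicate()
        .build()?
        .to_arrow()
        .await?
        .try_collect()
        .await?;
    println!("read {} record batches", batches.len());
    Ok(())
}
```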