Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(connector): file scan use correct path #19793

Merged
merged 4 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/batch/executors/src/executor/s3_file_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ use futures_async_stream::try_stream;
use futures_util::stream::StreamExt;
use risingwave_common::array::DataChunk;
use risingwave_common::catalog::{Field, Schema};
use risingwave_connector::source::iceberg::{new_s3_operator, read_parquet_file};
use risingwave_connector::source::iceberg::{
extract_bucket_and_file_name, new_s3_operator, read_parquet_file,
};
use risingwave_pb::batch_plan::file_scan_node;
use risingwave_pb::batch_plan::file_scan_node::StorageType;
use risingwave_pb::batch_plan::plan_node::NodeBody;
Expand Down Expand Up @@ -82,13 +84,15 @@ impl S3FileScanExecutor {
async fn do_execute(self: Box<Self>) {
assert_eq!(self.file_format, FileFormat::Parquet);
for file in self.file_location {
let (bucket, file_name) = extract_bucket_and_file_name(&file)?;
let op = new_s3_operator(
self.s3_region.clone(),
self.s3_access_key.clone(),
self.s3_secret_key.clone(),
file.clone(),
bucket.clone(),
)?;
let chunk_stream = read_parquet_file(op, file, None, None, self.batch_size, 0).await?;
let chunk_stream =
read_parquet_file(op, file_name, None, None, self.batch_size, 0).await?;
#[for_await]
for stream_chunk in chunk_stream {
let stream_chunk = stream_chunk?;
Expand Down
115 changes: 58 additions & 57 deletions src/connector/src/source/iceberg/parquet_file_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,9 @@ pub fn new_s3_operator(
s3_region: String,
s3_access_key: String,
s3_secret_key: String,
location: String,
bucket: String,
) -> ConnectorResult<Operator> {
// Create s3 builder.
let bucket = extract_bucket(&location);
let mut builder = S3::default().bucket(&bucket).region(&s3_region);
builder = builder.secret_access_key(&s3_access_key);
builder = builder.secret_access_key(&s3_secret_key);
Expand All @@ -120,8 +119,6 @@ pub fn new_s3_operator(
bucket, s3_region
));

builder = builder.disable_config_load();

let op: Operator = Operator::new(builder)?
.layer(LoggingLayer::default())
.layer(RetryLayer::default())
Expand All @@ -130,13 +127,20 @@ pub fn new_s3_operator(
Ok(op)
}

fn extract_bucket(location: &str) -> String {
let prefix = "s3://";
let start = prefix.len();
let end = location[start..]
.find('/')
.unwrap_or(location.len() - start);
location[start..start + end].to_string()
pub fn extract_bucket_and_file_name(location: &str) -> ConnectorResult<(String, String)> {
let url = Url::parse(location)?;
let bucket = url
.host_str()
.ok_or_else(|| {
Error::new(
ErrorKind::DataInvalid,
format!("Invalid s3 url: {}, missing bucket", location),
)
})?
.to_owned();
let prefix = format!("s3://{}/", bucket);
let file_name = location[prefix.len()..].to_string();
Ok((bucket, file_name))
}

pub async fn list_s3_directory(
Expand All @@ -145,27 +149,24 @@ pub async fn list_s3_directory(
s3_secret_key: String,
dir: String,
) -> Result<Vec<String>, anyhow::Error> {
let url = Url::parse(&dir)?;
let bucket = url.host_str().ok_or_else(|| {
Error::new(
ErrorKind::DataInvalid,
format!("Invalid s3 url: {}, missing bucket", dir),
)
})?;

let (bucket, file_name) = extract_bucket_and_file_name(&dir)?;
let prefix = format!("s3://{}/", bucket);
if dir.starts_with(&prefix) {
let mut builder = S3::default();
builder = builder
.region(&s3_region)
.access_key_id(&s3_access_key)
.secret_access_key(&s3_secret_key)
.bucket(bucket);
.bucket(&bucket);
builder = builder.endpoint(&format!(
"https://{}.s3.{}.amazonaws.com",
bucket, s3_region
));
let op = Operator::new(builder)?
.layer(RetryLayer::default())
.finish();

op.list(&dir[prefix.len()..])
op.list(&file_name)
.await
.map_err(|e| anyhow!(e))
.map(|list| {
Expand Down Expand Up @@ -197,44 +198,39 @@ pub async fn list_s3_directory(
/// Parquet file schema that match the requested schema. If an error occurs during processing,
/// it returns an appropriate error.
pub fn extract_valid_column_indices(
columns: Option<Vec<Column>>,
rw_columns: Vec<Column>,
metadata: &FileMetaData,
) -> ConnectorResult<Vec<usize>> {
match columns {
Some(rw_columns) => {
let parquet_column_names = metadata
.schema_descr()
.columns()
.iter()
.map(|c| c.name())
.collect_vec();
let parquet_column_names = metadata
.schema_descr()
.columns()
.iter()
.map(|c| c.name())
.collect_vec();

let converted_arrow_schema =
parquet_to_arrow_schema(metadata.schema_descr(), metadata.key_value_metadata())
.map_err(anyhow::Error::from)?;
let converted_arrow_schema =
parquet_to_arrow_schema(metadata.schema_descr(), metadata.key_value_metadata())
.map_err(anyhow::Error::from)?;

let valid_column_indices: Vec<usize> = rw_columns
.iter()
.filter_map(|column| {
parquet_column_names
.iter()
.position(|&name| name == column.name)
.and_then(|pos| {
let arrow_data_type: &risingwave_common::array::arrow::arrow_schema_udf::DataType = converted_arrow_schema.field(pos).data_type();
let rw_data_type: &risingwave_common::types::DataType = &column.data_type;
let valid_column_indices: Vec<usize> = rw_columns
.iter()
.filter_map(|column| {
parquet_column_names
.iter()
.position(|&name| name == column.name)
.and_then(|pos| {
let arrow_data_type: &risingwave_common::array::arrow::arrow_schema_udf::DataType = converted_arrow_schema.field(pos).data_type();
let rw_data_type: &risingwave_common::types::DataType = &column.data_type;

if is_parquet_schema_match_source_schema(arrow_data_type, rw_data_type) {
Some(pos)
} else {
None
}
})
})
.collect();
Ok(valid_column_indices)
}
None => Ok(vec![]),
}
if is_parquet_schema_match_source_schema(arrow_data_type, rw_data_type) {
Some(pos)
} else {
None
}
})
})
.collect();
Comment on lines +216 to +232
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fmt?

Copy link
Contributor Author

@wcy-fdu wcy-fdu Dec 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strange, how did these lines escape cargo fmt escape. I think it's already formatted.

Ok(valid_column_indices)
}

/// Reads a specified Parquet file and converts its content into a stream of chunks.
Expand All @@ -258,8 +254,14 @@ pub async fn read_parquet_file(
let parquet_metadata = reader.get_metadata().await.map_err(anyhow::Error::from)?;

let file_metadata = parquet_metadata.file_metadata();
let column_indices = extract_valid_column_indices(rw_columns, file_metadata)?;
let projection_mask = ProjectionMask::leaves(file_metadata.schema_descr(), column_indices);
let projection_mask = match rw_columns {
Some(columns) => {
let column_indices = extract_valid_column_indices(columns, file_metadata)?;
ProjectionMask::leaves(file_metadata.schema_descr(), column_indices)
}
None => ProjectionMask::all(),
};

// For the Parquet format, we directly convert from a record batch to a stream chunk.
// Therefore, the offset of the Parquet file represents the current position in terms of the number of rows read from the file.
let record_batch_stream = ParquetRecordBatchStreamBuilder::new(reader)
Expand Down Expand Up @@ -289,7 +291,6 @@ pub async fn read_parquet_file(
})
.collect(),
};

let parquet_parser = ParquetParser::new(columns, file_name, offset)?;
let msg_stream: Pin<
Box<dyn Stream<Item = Result<StreamChunk, crate::error::ConnectorError>> + Send>,
Expand Down
22 changes: 9 additions & 13 deletions src/frontend/src/expr/table_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use mysql_async::prelude::*;
use risingwave_common::array::arrow::IcebergArrowConvert;
use risingwave_common::types::{DataType, ScalarImpl, StructType};
use risingwave_connector::source::iceberg::{
get_parquet_fields, list_s3_directory, new_s3_operator,
extract_bucket_and_file_name, get_parquet_fields, list_s3_directory, new_s3_operator,
};
pub use risingwave_pb::expr::table_function::PbType as TableFunctionType;
use risingwave_pb::expr::PbTableFunction;
Expand Down Expand Up @@ -177,23 +177,19 @@ impl TableFunction {

let schema = tokio::task::block_in_place(|| {
FRONTEND_RUNTIME.block_on(async {
let location = match files.as_ref() {
Some(files) => files[0].clone(),
None => eval_args[5].clone(),
};
let (bucket, file_name) = extract_bucket_and_file_name(&location)?;
let op = new_s3_operator(
eval_args[2].clone(),
eval_args[3].clone(),
eval_args[4].clone(),
match files.as_ref() {
Some(files) => files[0].clone(),
None => eval_args[5].clone(),
},
bucket.clone(),
)?;
let fields = get_parquet_fields(
op,
match files.as_ref() {
Some(files) => files[0].clone(),
None => eval_args[5].clone(),
},
)
.await?;

let fields = get_parquet_fields(op, file_name).await?;

let mut rw_types = vec![];
for field in &fields {
Expand Down
Loading