diff --git a/Cargo.lock b/Cargo.lock
index f1d53fe8d51cb..d400abdc80445 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -9586,6 +9586,7 @@ dependencies = [
  "itertools 0.12.1",
  "jni",
  "jsonschema-transpiler",
+ "jsonwebtoken 9.3.0",
  "madsim-rdkafka",
  "madsim-tokio",
  "madsim-tonic",
diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml
index 7b43c39961ed7..0a1ba523a9ff6 100644
--- a/src/connector/Cargo.toml
+++ b/src/connector/Cargo.toml
@@ -72,6 +72,7 @@ icelake = { workspace = true }
 indexmap = { version = "1.9.3", features = ["serde"] }
 itertools = { workspace = true }
 jni = { version = "0.21.1", features = ["invocation"] }
+jsonwebtoken = "9.2.0"
 jst = { package = 'jsonschema-transpiler', git = "https://github.com/mozilla/jsonschema-transpiler", rev = "c1a89d720d118843d8bcca51084deb0ed223e4b4" }
 maplit = "1.0.2"
 moka = { version = "0.12", features = ["future"] }
diff --git a/src/connector/src/sink/mod.rs b/src/connector/src/sink/mod.rs
index 9facd476bf34c..c430b4303f1e9 100644
--- a/src/connector/src/sink/mod.rs
+++ b/src/connector/src/sink/mod.rs
@@ -33,6 +33,8 @@ pub mod nats;
 pub mod pulsar;
 pub mod redis;
 pub mod remote;
+pub mod snowflake;
+pub mod snowflake_connector;
 pub mod starrocks;
 pub mod test_sink;
 pub mod trivial;
@@ -91,6 +93,7 @@ macro_rules! for_all_sinks {
         { HttpJava, $crate::sink::remote::HttpJavaSink },
         { Doris, $crate::sink::doris::DorisSink },
         { Starrocks, $crate::sink::starrocks::StarrocksSink },
+        { Snowflake, $crate::sink::snowflake::SnowflakeSink },
         { DeltaLake, $crate::sink::deltalake::DeltaLakeSink },
         { BigQuery, $crate::sink::big_query::BigQuerySink },
         { Test, $crate::sink::test_sink::TestSink },
@@ -538,6 +541,8 @@ pub enum SinkError {
     ),
     #[error("Starrocks error: {0}")]
     Starrocks(String),
+    #[error("Snowflake error: {0}")]
+    Snowflake(String),
     #[error("Pulsar error: {0}")]
     Pulsar(
         #[source]
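The new `SinkError::Snowflake(String)` variant is how the sink surfaces Snowflake-specific failures through the connector crate's shared `Result` type. As a rough sketch of the intended usage (a hypothetical helper, not part of this patch), a fallible step such as parsing the user-supplied `snowflake.max_batch_row_num` string could be mapped onto it like this:

```rust
use crate::sink::{Result, SinkError};

/// Hypothetical helper, not part of this patch: map a parse failure onto the
/// new `SinkError::Snowflake` variant instead of panicking inside the writer.
fn parse_max_batch_row_num(raw: &str) -> Result<u32> {
    raw.parse::<u32>()
        .map_err(|e| SinkError::Snowflake(format!("invalid `snowflake.max_batch_row_num`: {e}")))
}
```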
diff --git a/src/connector/src/sink/snowflake.rs b/src/connector/src/sink/snowflake.rs
new file mode 100644
index 0000000000000..ba0973a0b0145
--- /dev/null
+++ b/src/connector/src/sink/snowflake.rs
@@ -0,0 +1,337 @@
+// Copyright 2024 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use anyhow::anyhow;
+use async_trait::async_trait;
+use risingwave_common::array::{Op, StreamChunk};
+use risingwave_common::buffer::Bitmap;
+use risingwave_common::catalog::Schema;
+use serde::Deserialize;
+use serde_json::Value;
+use serde_with::serde_as;
+use uuid::Uuid;
+use with_options::WithOptions;
+
+use super::encoder::{
+    JsonEncoder, RowEncoder, TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode,
+};
+use super::snowflake_connector::{SnowflakeHttpClient, SnowflakeS3Client};
+use super::writer::LogSinkerOf;
+use super::{SinkError, SinkParam};
+use crate::sink::writer::SinkWriterExt;
+use crate::sink::{DummySinkCommitCoordinator, Result, Sink, SinkWriter, SinkWriterParam};
+
+pub const SNOWFLAKE_SINK: &str = "snowflake";
+
+#[derive(Deserialize, Debug, Clone, WithOptions)]
+pub struct SnowflakeCommon {
+    /// The snowflake database used for sinking
+    #[serde(rename = "snowflake.database")]
+    pub database: String,
+
+    /// The corresponding schema where the sink table exists
+    #[serde(rename = "snowflake.schema")]
+    pub schema: String,
+
+    /// The created pipe object; it will be used as the `insertFiles` target
+    #[serde(rename = "snowflake.pipe")]
+    pub pipe: String,
+
+    /// The unique, snowflake provided `account_identifier`
+    /// NOTE: please use the form `<orgname>-<account_name>`
+    /// For detailed guidance, see the Snowflake documentation on account identifiers
+    #[serde(rename = "snowflake.account_identifier")]
+    pub account_identifier: String,
+
+    /// The user that owns the table to be sinked
+    /// NOTE: the user should've been granted the corresponding *role*
+    /// (see the Snowflake access control documentation)
+    #[serde(rename = "snowflake.user")]
+    pub user: String,
+
+    /// The public key fingerprint used when generating the custom `jwt_token`
+    /// (see the Snowflake key-pair authentication documentation)
+    #[serde(rename = "snowflake.rsa_public_key_fp")]
+    pub rsa_public_key_fp: String,
+
+    /// The RSA PEM private key *without* encryption
+    #[serde(rename = "snowflake.private_key")]
+    pub private_key: String,
+
+    /// The s3 bucket where intermediate sink files will be stored
+    #[serde(rename = "snowflake.s3_bucket")]
+    pub s3_bucket: String,
+
+    /// The optional s3 path to be specified;
+    /// if this field is specified by the user, the actual file location will be
+    /// `s3://<s3_bucket>/<s3_path>/<file_name>`,
+    /// otherwise it will be `s3://<s3_bucket>/<file_name>`
+    #[serde(rename = "snowflake.s3_path")]
+    pub s3_path: Option<String>,
+
+    /// s3 credentials
+    #[serde(rename = "snowflake.aws_access_key_id")]
+    pub aws_access_key_id: String,
+
+    /// s3 credentials
+    #[serde(rename = "snowflake.aws_secret_access_key")]
+    pub aws_secret_access_key: String,
+
+    /// The s3 region, e.g., us-east-2
+    #[serde(rename = "snowflake.aws_region")]
+    pub aws_region: String,
+
+    /// The configurable maximum number of rows to batch,
+    /// which must be *explicitly* specified by the user
+    #[serde(rename = "snowflake.max_batch_row_num")]
+    pub max_batch_row_num: String,
+}
+
+#[serde_as]
+#[derive(Clone, Debug, Deserialize, WithOptions)]
+pub struct SnowflakeConfig {
+    #[serde(flatten)]
+    pub common: SnowflakeCommon,
+}
+
+impl SnowflakeConfig {
+    pub fn from_hashmap(properties: HashMap<String, String>) -> Result<Self> {
+        let config =
+            serde_json::from_value::<SnowflakeConfig>(serde_json::to_value(properties).unwrap())
+                .map_err(|e| SinkError::Config(anyhow!(e)))?;
+        Ok(config)
+    }
+}
+
+#[derive(Debug)]
+pub struct SnowflakeSink {
+    pub config: SnowflakeConfig,
+    schema: Schema,
+    pk_indices: Vec<usize>,
+    is_append_only: bool,
+}
+
+impl Sink for SnowflakeSink {
+    type Coordinator = DummySinkCommitCoordinator;
+    type LogSinker = LogSinkerOf<SnowflakeSinkWriter>;
+
+    const SINK_NAME: &'static str = SNOWFLAKE_SINK;
+
+    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
+        Ok(SnowflakeSinkWriter::new(
+            self.config.clone(),
+            self.schema.clone(),
+            self.pk_indices.clone(),
+            self.is_append_only,
+        )
+        .await
+        .into_log_sinker(writer_param.sink_metrics))
+    }
+
+    async fn validate(&self) -> Result<()> {
+        if !self.is_append_only {
+            return Err(SinkError::Config(
+                anyhow!("SnowflakeSink only supports append-only mode at present, please change the query to append-only, or use `force_append_only = 'true'`")
+            ));
+        }
+        Ok(())
+    }
+}
+
+impl TryFrom<SinkParam> for SnowflakeSink {
+    type Error = SinkError;
+
+    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
+        let schema = param.schema();
+        let config = SnowflakeConfig::from_hashmap(param.properties)?;
+        Ok(SnowflakeSink {
+            config,
+            schema,
+            pk_indices: param.downstream_pk,
+            is_append_only: param.sink_type.is_append_only(),
+        })
+    }
+}
+
+pub struct SnowflakeSinkWriter {
+    config: SnowflakeConfig,
+    schema: Schema,
+    pk_indices: Vec<usize>,
+    is_append_only: bool,
+    /// the client used to send the `insertFiles` post request
+    http_client: SnowflakeHttpClient,
+    /// the client used to upload intermediate files to external storage (i.e., s3)
+    s3_client: SnowflakeS3Client,
+    row_encoder: JsonEncoder,
+    row_counter: u32,
+    payload: String,
+    /// the threshold for sinking to s3
+    max_batch_row_num: u32,
+    /// The current epoch, used when naming the sink files;
+    /// mainly for debugging purposes
+    epoch: u64,
+}
+
+impl SnowflakeSinkWriter {
+    pub async fn new(
+        config: SnowflakeConfig,
+        schema: Schema,
+        pk_indices: Vec<usize>,
+        is_append_only: bool,
+    ) -> Self {
+        let http_client = SnowflakeHttpClient::new(
+            config.common.account_identifier.clone(),
+            config.common.user.clone(),
+            config.common.database.clone(),
+            config.common.schema.clone(),
+            config.common.pipe.clone(),
+            config.common.rsa_public_key_fp.clone(),
+            config.common.private_key.clone(),
+            HashMap::new(),
+            config.common.s3_path.clone(),
+        );
+
+        let s3_client = SnowflakeS3Client::new(
+            config.common.s3_bucket.clone(),
+            config.common.s3_path.clone(),
+            config.common.aws_access_key_id.clone(),
+            config.common.aws_secret_access_key.clone(),
+            config.common.aws_region.clone(),
+        )
+        .await;
+
+        let max_batch_row_num = config
+            .common
+            .max_batch_row_num
+            .clone()
+            .parse::<u32>()
+            .expect("failed to parse `snowflake.max_batch_row_num` as a `u32`");
+
+        Self {
+            config,
+            schema: schema.clone(),
+            pk_indices,
+            is_append_only,
+            http_client,
+            s3_client,
+            row_encoder: JsonEncoder::new(
+                schema,
+                None,
+                super::encoder::DateHandlingMode::String,
+                TimestampHandlingMode::String,
+                TimestamptzHandlingMode::UtcString,
+                TimeHandlingMode::String,
+            ),
+            row_counter: 0,
+            payload: String::new(),
+            max_batch_row_num,
+            // initial value of `epoch` will start from 0
+            epoch: 0,
+        }
+    }
+
+    /// reset the `payload` and `row_counter`.
+    /// shall *only* be called after a successful sink.
+    fn reset(&mut self) {
+        self.payload.clear();
+        self.row_counter = 0;
+    }
+
+    fn at_sink_threshold(&self) -> bool {
+        self.row_counter >= self.max_batch_row_num
+    }
+
+    fn append_only(&mut self, chunk: StreamChunk) -> Result<()> {
+        for (op, row) in chunk.rows() {
+            assert_eq!(op, Op::Insert, "expect all `op(s)` to be `Op::Insert`");
+            let row_json_string = Value::Object(self.row_encoder.encode(row)?).to_string();
+            self.payload.push_str(&row_json_string);
+            self.row_counter += 1;
+        }
+        Ok(())
+    }
+
+    fn update_epoch(&mut self, epoch: u64) {
+        self.epoch = epoch;
+    }
+
+    /// generate a *globally unique* uuid,
+    /// which is the key to the uniqueness of the file suffix.
+    fn gen_uuid() -> Uuid {
+        Uuid::new_v4()
+    }
+
+    /// construct the *globally unique* file suffix for the sink.
+    /// note: this is unique even across multiple parallel writer(s).
+    fn file_suffix(&self) -> String {
+        // the format of the suffix will be `<epoch>_<uuid>`
+        format!("{}_{}", self.epoch, Self::gen_uuid())
+    }
+
+    /// sink `payload` to s3, then trigger the corresponding `insertFiles` post request
+    /// to snowflake, to finish the overall sinking pipeline.
+    async fn sink_payload(&mut self) -> Result<()> {
+        if self.payload.is_empty() {
+            return Ok(());
+        }
+        // generate the file suffix exactly once, so that the uploaded s3 object and
+        // the subsequent `insertFiles` request refer to the same file name
+        let file_suffix = self.file_suffix();
+        // todo: change this to streaming upload
+        // first sink to the external stage provided by user (i.e., s3)
+        self.s3_client
+            .sink_to_s3(self.payload.clone().into(), file_suffix.clone())
+            .await?;
+        // then trigger `insertFiles` post request to snowflake
+        self.http_client.send_request(file_suffix).await?;
+        // reset `payload` & `row_counter`
+        self.reset();
+        Ok(())
+    }
+}
+
+#[async_trait]
+impl SinkWriter for SnowflakeSinkWriter {
+    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
+        self.update_epoch(epoch);
+        Ok(())
+    }
+
+    async fn abort(&mut self) -> Result<()> {
+        Ok(())
+    }
+
+    async fn update_vnode_bitmap(&mut self, _vnode_bitmap: Arc<Bitmap>) -> Result<()> {
+        Ok(())
+    }
+
+    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Self::CommitMetadata> {
+        if is_checkpoint {
+            // sink all the rows currently batched in `self.payload`
+            self.sink_payload().await?;
+        }
+        Ok(())
+    }
+
+    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
+        self.append_only(chunk)?;
+
+        // when the number of batched rows reaches `max_batch_row_num`, flush to s3
+        if self.at_sink_threshold() {
+            self.sink_payload().await?;
+        }
+
+        Ok(())
+    }
+}
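Since `SnowflakeConfig` is deserialized from the flattened `snowflake.*` property map via serde renaming, a small config-parsing test is a convenient way to see which options are required. The following is an illustrative sketch only (not part of this patch; all values are placeholders) exercising `SnowflakeConfig::from_hashmap`:

```rust
#[cfg(test)]
mod config_example {
    use std::collections::HashMap;

    use super::SnowflakeConfig;

    #[test]
    fn parse_snowflake_options() {
        // every value below is a placeholder; only the option names come from the patch
        let mut props = HashMap::new();
        props.insert("snowflake.database".to_string(), "RW_DB".to_string());
        props.insert("snowflake.schema".to_string(), "PUBLIC".to_string());
        props.insert("snowflake.pipe".to_string(), "RW_PIPE".to_string());
        props.insert("snowflake.account_identifier".to_string(), "myorg-myaccount".to_string());
        props.insert("snowflake.user".to_string(), "RW_USER".to_string());
        props.insert("snowflake.rsa_public_key_fp".to_string(), "SHA256:placeholder".to_string());
        props.insert("snowflake.private_key".to_string(), "-----BEGIN PRIVATE KEY-----placeholder".to_string());
        props.insert("snowflake.s3_bucket".to_string(), "rw-sink-bucket".to_string());
        props.insert("snowflake.aws_access_key_id".to_string(), "placeholder".to_string());
        props.insert("snowflake.aws_secret_access_key".to_string(), "placeholder".to_string());
        props.insert("snowflake.aws_region".to_string(), "us-east-2".to_string());
        props.insert("snowflake.max_batch_row_num".to_string(), "1000".to_string());

        let config = SnowflakeConfig::from_hashmap(props).expect("config should deserialize");
        assert_eq!(config.common.database, "RW_DB");
        // `snowflake.s3_path` is optional and was not provided
        assert_eq!(config.common.s3_path, None);
    }
}
```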
diff --git a/src/connector/src/sink/snowflake_connector.rs b/src/connector/src/sink/snowflake_connector.rs
new file mode 100644
index 0000000000000..e5e37deb14652
--- /dev/null
+++ b/src/connector/src/sink/snowflake_connector.rs
@@ -0,0 +1,257 @@
+// Copyright 2024 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::collections::HashMap;
+use std::time::{SystemTime, UNIX_EPOCH};
+
+use aws_config;
+use aws_config::meta::region::RegionProviderChain;
+use aws_sdk_s3::config::Credentials;
+use aws_sdk_s3::primitives::ByteStream;
+use aws_sdk_s3::Client as S3Client;
+use aws_types::region::Region;
+use bytes::Bytes;
+use http::header;
+use http::request::Builder;
+use hyper::body::Body;
+use hyper::client::HttpConnector;
+use hyper::{Client, Request, StatusCode};
+use hyper_tls::HttpsConnector;
+use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
+use serde::{Deserialize, Serialize};
+
+use super::doris_starrocks_connector::POOL_IDLE_TIMEOUT;
+use super::{Result, SinkError};
+
+const SNOWFLAKE_HOST_ADDR: &str = "snowflakecomputing.com";
+const SNOWFLAKE_REQUEST_ID: &str = "RW_SNOWFLAKE_SINK";
+const S3_INTERMEDIATE_FILE_NAME: &str = "RW_SNOWFLAKE_S3_SINK_FILE";
+
+/// The helper function to generate the *globally unique* s3 file name.
+fn generate_s3_file_name(s3_path: Option<String>, suffix: String) -> String {
+    match s3_path {
+        Some(path) => format!("{}/{}_{}", path, S3_INTERMEDIATE_FILE_NAME, suffix),
+        None => format!("{}_{}", S3_INTERMEDIATE_FILE_NAME, suffix),
+    }
+}
+
+/// Claims is used when constructing the `jwt_token`,
+/// with the payload specified below.
+/// (see the Snowflake key-pair authentication documentation)
+#[derive(Debug, Serialize, Deserialize)]
+struct Claims {
+    iss: String,
+    sub: String,
+    iat: usize,
+    exp: usize,
+}
+
+#[derive(Debug)]
+pub struct SnowflakeHttpClient {
+    url: String,
+    rsa_public_key_fp: String,
+    account: String,
+    user: String,
+    private_key: String,
+    header: HashMap<String, String>,
+    s3_path: Option<String>,
+}
+
+impl SnowflakeHttpClient {
+    pub fn new(
+        account: String,
+        user: String,
+        db: String,
+        schema: String,
+        pipe: String,
+        rsa_public_key_fp: String,
+        private_key: String,
+        header: HashMap<String, String>,
+        s3_path: Option<String>,
+    ) -> Self {
+        // todo: determine whether we need the user to *explicitly* provide the `request_id`;
+        // currently it seems that this is not important.
+        // (see the Snowpipe REST API reference for the `insertFiles` endpoint)
+        let url = format!(
+            "https://{}.{}/v1/data/pipes/{}.{}.{}/insertFiles?requestId={}",
+            account.clone(),
+            SNOWFLAKE_HOST_ADDR,
+            db,
+            schema,
+            pipe,
+            SNOWFLAKE_REQUEST_ID
+        );
+
+        Self {
+            url,
+            rsa_public_key_fp,
+            account,
+            user,
+            private_key,
+            header,
+            s3_path,
+        }
+    }
+
+    /// Generate a `jwt_token` that is valid for 59 minutes, used to authenticate with Snowflake.
+    /// Note that we do NOT strictly track the expiration of the `jwt_token`,
+    /// which essentially means that this method should be called *every time* we want
+    /// to send an `insertFiles` request to the Snowflake server.
+    fn generate_jwt_token(&self) -> Result<String> {
+        let header = Header::new(Algorithm::RS256);
+        let now = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .unwrap()
+            .as_secs() as usize;
+        let lifetime = 59 * 60;
+
+        // Ensure the account and username are uppercase
+        let account = self.account.to_uppercase();
+        let user = self.user.to_uppercase();
+
+        // Construct the fully qualified username
+        let qualified_username = format!("{}.{}", account, user);
+
+        let claims = Claims {
+            iss: format!("{}.{}", qualified_username.clone(), self.rsa_public_key_fp),
+            sub: qualified_username,
+            iat: now,
+            exp: now + lifetime,
+        };
+
+        let jwt_token = encode(
+            &header,
+            &claims,
+            &EncodingKey::from_rsa_pem(self.private_key.as_ref()).map_err(|err| {
+                SinkError::Snowflake(format!(
+                    "failed to encode from provided rsa pem key, error: {}",
+                    err
+                ))
+            })?,
+        )
+        .map_err(|err| {
+            SinkError::Snowflake(format!("failed to encode jwt_token, error: {}", err))
+        })?;
+        Ok(jwt_token)
+    }
+
+    fn build_request_and_client(&self) -> (Builder, Client<HttpsConnector<HttpConnector>>) {
+        let builder = Request::post(self.url.clone());
+
+        let connector = HttpsConnector::new();
+        let client = Client::builder()
+            .pool_idle_timeout(POOL_IDLE_TIMEOUT)
+            .build(connector);
+
+        (builder, client)
+    }
+
+    /// NOTE: this function should ONLY be called *after*
+    /// uploading files to the remote external staged storage, i.e., AWS S3
+    pub async fn send_request(&self, file_suffix: String) -> Result<()> {
+        let (builder, client) = self.build_request_and_client();
+
+        // Generate the jwt_token
+        let jwt_token = self.generate_jwt_token()?;
+        let builder = builder
+            .header(header::CONTENT_TYPE, "text/plain")
+            .header("Authorization", format!("Bearer {}", jwt_token))
+            .header(
+                "X-Snowflake-Authorization-Token-Type".to_string(),
+                "KEYPAIR_JWT",
+            );
+
+        let request = builder
+            .body(Body::from(generate_s3_file_name(
+                self.s3_path.clone(),
+                file_suffix,
+            )))
+            .map_err(|err| SinkError::Snowflake(err.to_string()))?;
+
+        let response = client
+            .request(request)
+            .await
+            .map_err(|err| SinkError::Snowflake(err.to_string()))?;
+
+        if response.status() != StatusCode::OK {
+            return Err(SinkError::Snowflake(format!(
+                "failed to make http request, error code: {}\ndetailed response: {:#?}",
+                response.status(),
+                response,
+            )));
+        }
+
+        Ok(())
+    }
+}
+
+/// todo: refactor this part after the s3 sink is available
+pub struct SnowflakeS3Client {
+    s3_bucket: String,
+    s3_path: Option<String>,
+    s3_client: S3Client,
+}
+
+impl SnowflakeS3Client {
+    pub async fn new(
+        s3_bucket: String,
+        s3_path: Option<String>,
+        aws_access_key_id: String,
+        aws_secret_access_key: String,
+        aws_region: String,
+    ) -> Self {
+        let credentials = Credentials::new(
+            aws_access_key_id,
+            aws_secret_access_key,
+            // we don't allow temporary credentials
+            None,
+            None,
+            "rw_sink_to_s3_credentials",
+        );
+
+        let region = RegionProviderChain::first_try(Region::new(aws_region)).or_default_provider();
+
+        let config = aws_config::from_env()
+            .credentials_provider(credentials)
+            .region(region)
+            .load()
+            .await;
+
+        // create the brand new s3 client used to sink files to s3
+        let s3_client = S3Client::new(&config);
+
+        Self {
+            s3_bucket,
+            s3_path,
+            s3_client,
+        }
+    }
+
+    pub async fn sink_to_s3(&self, data: Bytes, file_suffix: String) -> Result<()> {
+        self.s3_client
+            .put_object()
+            .bucket(self.s3_bucket.clone())
+            .key(generate_s3_file_name(self.s3_path.clone(), file_suffix))
+            .body(ByteStream::from(data))
+            .send()
+            .await
+            .map_err(|err| {
+                SinkError::Snowflake(format!("failed to sink data to S3, error: {}", err))
+            })?;
+
+        Ok(())
+    }
+}
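The intermediate object name is derived purely from `S3_INTERMEDIATE_FILE_NAME`, the optional `snowflake.s3_path`, and the writer's `<epoch>_<uuid>` suffix built by `SnowflakeSinkWriter::file_suffix`. A small illustrative test (a sketch only, not part of this patch) pinning down that naming layout could look like this:

```rust
#[cfg(test)]
mod name_scheme_example {
    use super::generate_s3_file_name;

    // Illustrative only: the intermediate file is `RW_SNOWFLAKE_S3_SINK_FILE_<suffix>`,
    // optionally nested under the user-provided `snowflake.s3_path`; the suffix value
    // "42_uuid" below is a placeholder for `<epoch>_<uuid>`.
    #[test]
    fn s3_file_name_layout() {
        assert_eq!(
            generate_s3_file_name(Some("nightly".into()), "42_uuid".into()),
            "nightly/RW_SNOWFLAKE_S3_SINK_FILE_42_uuid"
        );
        assert_eq!(
            generate_s3_file_name(None, "42_uuid".into()),
            "RW_SNOWFLAKE_S3_SINK_FILE_42_uuid"
        );
    }
}
```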
diff --git a/src/connector/with_options_sink.yaml b/src/connector/with_options_sink.yaml
index b287bcd6aa4b4..07da6a36a0e3a 100644
--- a/src/connector/with_options_sink.yaml
+++ b/src/connector/with_options_sink.yaml
@@ -528,6 +528,60 @@ RedisConfig:
   fields:
   - name: redis.url
     field_type: String
     required: true
+SnowflakeConfig:
+  fields:
+  - name: snowflake.database
+    field_type: String
+    comments: The snowflake database used for sinking
+    required: true
+  - name: snowflake.schema
+    field_type: String
+    comments: The corresponding schema where the sink table exists
+    required: true
+  - name: snowflake.pipe
+    field_type: String
+    comments: The created pipe object; it will be used as the `insertFiles` target
+    required: true
+  - name: snowflake.account_identifier
+    field_type: String
+    comments: 'The unique, snowflake provided `account_identifier` NOTE: please use the form `<orgname>-<account_name>` For detailed guidance, see the Snowflake documentation on account identifiers'
+    required: true
+  - name: snowflake.user
+    field_type: String
+    comments: 'The user that owns the table to be sinked NOTE: the user should''ve been granted the corresponding *role* (see the Snowflake access control documentation)'
+    required: true
+  - name: snowflake.rsa_public_key_fp
+    field_type: String
+    comments: The public key fingerprint used when generating the custom `jwt_token` (see the Snowflake key-pair authentication documentation)
+    required: true
+  - name: snowflake.private_key
+    field_type: String
+    comments: The RSA PEM private key *without* encryption
+    required: true
+  - name: snowflake.s3_bucket
+    field_type: String
+    comments: The s3 bucket where intermediate sink files will be stored
+    required: true
+  - name: snowflake.s3_path
+    field_type: String
+    comments: The optional s3 path to be specified; if this field is specified by the user, the actual file location will be `s3://<s3_bucket>/<s3_path>/<file_name>`, otherwise it will be `s3://<s3_bucket>/<file_name>`
+    required: false
+  - name: snowflake.aws_access_key_id
+    field_type: String
+    comments: s3 credentials
+    required: true
+  - name: snowflake.aws_secret_access_key
+    field_type: String
+    comments: s3 credentials
+    required: true
+  - name: snowflake.aws_region
+    field_type: String
+    comments: The s3 region, e.g., us-east-2
+    required: true
+  - name: snowflake.max_batch_row_num
+    field_type: String
+    comments: The configurable maximum number of rows to batch, which must be *explicitly* specified by the user
+    required: true
 StarrocksConfig:
   fields:
   - name: starrocks.host
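For reviewers, the `insertFiles` call that `SnowflakeHttpClient::send_request` ends up issuing is, in outline, the following (a sketch assembled from the URL, headers, and body constructed in this patch; every angle-bracketed identifier is a placeholder):

```text
POST https://<account_identifier>.snowflakecomputing.com/v1/data/pipes/<db>.<schema>.<pipe>/insertFiles?requestId=RW_SNOWFLAKE_SINK
Content-Type: text/plain
Authorization: Bearer <jwt_token>
X-Snowflake-Authorization-Token-Type: KEYPAIR_JWT

<s3_path>/RW_SNOWFLAKE_S3_SINK_FILE_<epoch>_<uuid>
```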