From c2294d8c2c7c98f26ee9a420cb8081b4633346c7 Mon Sep 17 00:00:00 2001 From: everpcpc Date: Thu, 19 Dec 2024 20:03:42 +0800 Subject: [PATCH] feat(query): add config jwks_refresh_interval & jwks_refresh_timeout (#17087) * feat(query): add config jwks_refresh_interval & jwks_refresh_timeout * fix: remove force reload when key not found * z * z * z * z * z * z * z * z * z --- .github/actions/setup_build_tool/action.yml | 1 + src/query/config/src/config.rs | 13 +++- src/query/config/src/inner.rs | 4 ++ src/query/service/src/auth.rs | 2 + .../storages/testdata/configs_table_basic.txt | 2 + src/query/users/src/jwt/authenticator.rs | 22 +++++- src/query/users/src/jwt/jwk.rs | 72 +++++++++---------- src/query/users/tests/it/jwt/authenticator.rs | 34 +-------- 8 files changed, 77 insertions(+), 73 deletions(-) diff --git a/.github/actions/setup_build_tool/action.yml b/.github/actions/setup_build_tool/action.yml index 1fdd39e9796a..68b616fcbe63 100644 --- a/.github/actions/setup_build_tool/action.yml +++ b/.github/actions/setup_build_tool/action.yml @@ -36,6 +36,7 @@ runs: EOF RUNNER_PROVIDER="${RUNNER_PROVIDER:-github}" + export SCCACHE_IDLE_TIMEOUT=0 case ${RUNNER_PROVIDER} in aws) echo "setting up sccache for AWS S3..." diff --git a/src/query/config/src/config.rs b/src/query/config/src/config.rs index facc0ac17584..3b3766acb323 100644 --- a/src/query/config/src/config.rs +++ b/src/query/config/src/config.rs @@ -1528,7 +1528,14 @@ pub struct QueryConfig { #[clap(long, value_name = "VALUE", default_value_t)] pub jwt_key_file: String, - /// If there are multiple trusted jwt provider put it into additional_jwt_key_files configuration + /// Interval in seconds to refresh jwks + #[clap(long, value_name = "VALUE", default_value = "600")] + pub jwks_refresh_interval: u64, + + /// Timeout in seconds to refresh jwks + #[clap(long, value_name = "VALUE", default_value = "10")] + pub jwks_refresh_timeout: u64, + #[clap(skip)] pub jwt_key_files: Vec, @@ -1754,6 +1761,8 @@ impl TryInto for QueryConfig { max_storage_io_requests: self.max_storage_io_requests, jwt_key_file: self.jwt_key_file, jwt_key_files: self.jwt_key_files, + jwks_refresh_interval: self.jwks_refresh_interval, + jwks_refresh_timeout: self.jwks_refresh_timeout, default_storage_format: self.default_storage_format, default_compression: self.default_compression, builtin: BuiltInConfig { @@ -1845,6 +1854,8 @@ impl From for QueryConfig { max_storage_io_requests: inner.max_storage_io_requests, jwt_key_file: inner.jwt_key_file, jwt_key_files: inner.jwt_key_files, + jwks_refresh_interval: inner.jwks_refresh_interval, + jwks_refresh_timeout: inner.jwks_refresh_timeout, default_storage_format: inner.default_storage_format, default_compression: inner.default_compression, users: inner.builtin.users, diff --git a/src/query/config/src/inner.rs b/src/query/config/src/inner.rs index f94d33602a93..829f7abd338d 100644 --- a/src/query/config/src/inner.rs +++ b/src/query/config/src/inner.rs @@ -215,6 +215,8 @@ pub struct QueryConfig { pub jwt_key_file: String, pub jwt_key_files: Vec, + pub jwks_refresh_interval: u64, + pub jwks_refresh_timeout: u64, pub default_storage_format: String, pub default_compression: String, pub builtin: BuiltInConfig, @@ -301,6 +303,8 @@ impl Default for QueryConfig { max_storage_io_requests: None, jwt_key_file: "".to_string(), jwt_key_files: Vec::new(), + jwks_refresh_interval: 600, + jwks_refresh_timeout: 10, default_storage_format: "auto".to_string(), default_compression: "auto".to_string(), builtin: BuiltInConfig::default(), diff --git a/src/query/service/src/auth.rs b/src/query/service/src/auth.rs index ab9c3d0e7b97..2561c8232c92 100644 --- a/src/query/service/src/auth.rs +++ b/src/query/service/src/auth.rs @@ -85,6 +85,8 @@ impl AuthMgr { jwt_auth: JwtAuthenticator::create( cfg.query.jwt_key_file.clone(), cfg.query.jwt_key_files.clone(), + cfg.query.jwks_refresh_interval, + cfg.query.jwks_refresh_timeout, ), }) } diff --git a/src/query/service/tests/it/storages/testdata/configs_table_basic.txt b/src/query/service/tests/it/storages/testdata/configs_table_basic.txt index ca92eb184070..35ea94d3b6d3 100644 --- a/src/query/service/tests/it/storages/testdata/configs_table_basic.txt +++ b/src/query/service/tests/it/storages/testdata/configs_table_basic.txt @@ -100,6 +100,8 @@ DB.Table: 'system'.'configs', Table: configs-table_id:1, ver:0, Engine: SystemCo | 'query' | 'http_handler_tls_server_root_ca_cert' | '' | '' | | 'query' | 'internal_enable_sandbox_tenant' | 'false' | '' | | 'query' | 'internal_merge_on_read_mutation' | 'false' | '' | +| 'query' | 'jwks_refresh_interval' | '600' | '' | +| 'query' | 'jwks_refresh_timeout' | '10' | '' | | 'query' | 'jwt_key_file' | '' | '' | | 'query' | 'jwt_key_files' | '' | '' | | 'query' | 'management_mode' | 'false' | '' | diff --git a/src/query/users/src/jwt/authenticator.rs b/src/query/users/src/jwt/authenticator.rs index 1a3a6884f2d9..c3802f6e9783 100644 --- a/src/query/users/src/jwt/authenticator.rs +++ b/src/query/users/src/jwt/authenticator.rs @@ -78,14 +78,30 @@ impl CustomClaims { } impl JwtAuthenticator { - pub fn create(jwt_key_file: String, jwt_key_files: Vec) -> Option { + pub fn create( + jwt_key_file: String, + jwt_key_files: Vec, + jwks_refresh_interval: u64, + jwks_refresh_timeout: u64, + ) -> Option { if jwt_key_file.is_empty() && jwt_key_files.is_empty() { return None; } // init a vec of key store - let mut key_stores = vec![jwk::JwkKeyStore::new(jwt_key_file)]; + let mut key_stores = vec![]; + if !jwt_key_file.is_empty() { + key_stores.push( + jwk::JwkKeyStore::new(jwt_key_file) + .with_refresh_interval(jwks_refresh_interval) + .with_refresh_timeout(jwks_refresh_timeout), + ); + } for u in jwt_key_files { - key_stores.push(jwk::JwkKeyStore::new(u)) + key_stores.push( + jwk::JwkKeyStore::new(u) + .with_refresh_interval(jwks_refresh_interval) + .with_refresh_timeout(jwks_refresh_timeout), + ); } Some(JwtAuthenticator { key_stores }) } diff --git a/src/query/users/src/jwt/jwk.rs b/src/query/users/src/jwt/jwk.rs index 4baf5bc0b168..00495492ef9f 100644 --- a/src/query/users/src/jwt/jwk.rs +++ b/src/query/users/src/jwt/jwk.rs @@ -33,7 +33,8 @@ use serde::Serialize; use super::PubKey; -const JWK_REFRESH_INTERVAL: u64 = 15; +const JWKS_REFRESH_TIMEOUT: u64 = 10; +const JWKS_REFRESH_INTERVAL: u64 = 600; #[derive(Debug, Serialize, Deserialize)] pub struct JwkKey { @@ -99,17 +100,17 @@ pub struct JwkKeyStore { cached_keys: Arc>>, pub(crate) last_refreshed_at: RwLock>, pub(crate) refresh_interval: Duration, + pub(crate) refresh_timeout: Duration, pub(crate) load_keys_func: Option HashMap + Send + Sync>>, } impl JwkKeyStore { pub fn new(url: String) -> Self { - let refresh_interval = Duration::from_secs(JWK_REFRESH_INTERVAL * 60); - let keys = Arc::new(RwLock::new(HashMap::new())); Self { url, - cached_keys: keys, - refresh_interval, + cached_keys: Arc::new(RwLock::new(HashMap::new())), + refresh_interval: Duration::from_secs(JWKS_REFRESH_INTERVAL), + refresh_timeout: Duration::from_secs(JWKS_REFRESH_TIMEOUT), last_refreshed_at: RwLock::new(None), load_keys_func: None, } @@ -124,6 +125,16 @@ impl JwkKeyStore { self } + pub fn with_refresh_interval(mut self, interval: u64) -> Self { + self.refresh_interval = Duration::from_secs(interval); + self + } + + pub fn with_refresh_timeout(mut self, timeout: u64) -> Self { + self.refresh_timeout = Duration::from_secs(timeout); + self + } + pub fn url(&self) -> String { self.url.clone() } @@ -136,12 +147,19 @@ impl JwkKeyStore { return Ok(load_keys_func()); } - let response = reqwest::get(&self.url).await.map_err(|e| { + let client = reqwest::Client::builder() + .timeout(self.refresh_timeout) + .build() + .map_err(|e| { + ErrorCode::InvalidConfig(format!("Failed to create jwks client: {}", e)) + })?; + let response = client.get(&self.url).send().await.map_err(|e| { ErrorCode::AuthenticateFailure(format!("Could not download JWKS: {}", e)) })?; - let body = response.text().await.unwrap(); - let jwk_keys = serde_json::from_str::(&body) - .map_err(|e| ErrorCode::InvalidConfig(format!("Failed to parse keys: {}", e)))?; + let jwk_keys: JwkKeys = response + .json() + .await + .map_err(|e| ErrorCode::InvalidConfig(format!("Failed to parse JWKS: {}", e)))?; let mut new_keys: HashMap = HashMap::new(); for k in &jwk_keys.keys { new_keys.insert(k.kid.to_string(), k.get_public_key()?); @@ -166,6 +184,7 @@ impl JwkKeyStore { let new_keys = match self.load_keys().await { Ok(new_keys) => new_keys, Err(err) => { + warn!("Failed to load JWKS: {}", err); if !old_keys.is_empty() { return Ok(old_keys); } @@ -177,9 +196,9 @@ impl JwkKeyStore { if !new_keys.keys().eq(old_keys.keys()) { info!("JWKS keys changed."); } - *self.cached_keys.write() = new_keys; + *self.cached_keys.write() = new_keys.clone(); self.last_refreshed_at.write().replace(Instant::now()); - Ok(old_keys) + Ok(new_keys) } #[async_backtrace::framed] @@ -200,31 +219,12 @@ impl JwkKeyStore { } }; - // happy path: the key_id is found in the store - if let Some(key) = keys.get(&key_id) { - return Ok(key.clone()); + match keys.get(&key_id) { + None => Err(ErrorCode::AuthenticateFailure(format!( + "key id {} not found in jwk store", + key_id + ))), + Some(key) => Ok(key.clone()), } - - // if the key_id is not set here, it might because the JWKS has been rotated, we need to refresh it. - warn!( - "key_id {} not found in jwks store, try to reload keys", - key_id - ); - let keys = self - .load_keys_with_cache(true) - .await - .map_err(|e| e.add_message("failed to reload JWKS keys"))?; - - let key = match keys.get(&key_id) { - None => { - return Err(ErrorCode::AuthenticateFailure(format!( - "key id {} not found in jwk store", - key_id - ))); - } - Some(key) => key.clone(), - }; - - Ok(key) } } diff --git a/src/query/users/tests/it/jwt/authenticator.rs b/src/query/users/tests/it/jwt/authenticator.rs index 5e86c82ac9d6..dc94bf5d1489 100644 --- a/src/query/users/tests/it/jwt/authenticator.rs +++ b/src/query/users/tests/it/jwt/authenticator.rs @@ -12,18 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering; -use std::sync::Arc; - use base64::engine::general_purpose; use base64::prelude::*; use databend_common_base::base::tokio; use databend_common_exception::Result; -use databend_common_users::JwkKeyStore; use databend_common_users::JwtAuthenticator; -use databend_common_users::PubKey; use jwt_simple::prelude::*; use wiremock::matchers::method; use wiremock::matchers::path; @@ -60,7 +53,7 @@ async fn test_parse_non_custom_claim() -> Result<()> { .mount(&server) .await; let first_url = format!("http://{}{}", server.address(), json_path); - let auth = JwtAuthenticator::create(first_url, vec![]).unwrap(); + let auth = JwtAuthenticator::create(first_url, vec![], 600, 10).unwrap(); let user_name = "test-user2"; let my_additional_data = MyAdditionalData { user_is_admin: false, @@ -74,28 +67,3 @@ async fn test_parse_non_custom_claim() -> Result<()> { assert_eq!(res.custom.role, None); Ok(()) } - -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn test_jwk_key_store_retry_on_key_not_found() -> Result<()> { - let func_calls = Arc::new(AtomicUsize::new(0)); - let func_calls_cloned = func_calls.clone(); - - let mock_load_keys = Arc::new(move || -> HashMap { - let mut keys_map = HashMap::new(); - keys_map.insert( - "key1".to_string(), - PubKey::RSA256(RS256KeyPair::generate(2048).unwrap().public_key().into()), - ); - func_calls_cloned.fetch_add(1, Ordering::SeqCst); - keys_map - }); - let store = JwkKeyStore::new("".to_string()).with_load_keys_func(mock_load_keys); - - let r = store.get_key(Some("key2".to_string())).await; - assert_eq!( - r.unwrap_err().message(), - "key id key2 not found in jwk store" - ); - assert_eq!(func_calls.load(Ordering::SeqCst), 2); - Ok(()) -}