Skip to content

Commit

Permalink
Introduce pluggable backend storage for the HTTP layer.
Browse files Browse the repository at this point in the history
Fixes #478
  • Loading branch information
crodas committed Dec 7, 2024
1 parent 9cb684e commit e7efec3
Show file tree
Hide file tree
Showing 17 changed files with 558 additions and 129 deletions.
178 changes: 129 additions & 49 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions crates/cdk-axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ serde_json = "1"
paste = "1.0.15"
serde = { version = "1.0.210", features = ["derive"] }
uuid = { version = "1", features = ["v4", "serde"] }
sha2 = "0.10.8"
redis = { version = "0.23.3", features = ["tokio-rustls-comp"] }

[features]
swagger = ["cdk/swagger", "dep:utoipa"]
38 changes: 38 additions & 0 deletions crates/cdk-axum/src/cache/backend/memory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use std::time::Duration;

use moka::future::Cache;

use crate::cache::{HttpCacheKey, HttpCacheStorage};

/// In memory cache storage for the HTTP cache.
///
/// This is the default cache storage backend, which is used if no other storage
/// backend is provided, or if the provided storage backend is `None`.
///
/// The cache is limited to 10,000 entries and it is not shared between
/// instances nor persisted.
pub struct InMemoryHttpCache(pub Cache<HttpCacheKey, Vec<u8>>);

#[async_trait::async_trait]
impl HttpCacheStorage for InMemoryHttpCache {
fn new(cache_ttl: Duration, cache_tti: Duration) -> Self
where
Self: Sized,
{
InMemoryHttpCache(
Cache::builder()
.max_capacity(10_000)
.time_to_live(cache_ttl)
.time_to_idle(cache_tti)
.build(),
)
}

async fn get(&self, key: &HttpCacheKey) -> Option<Vec<u8>> {
self.0.get(key)
}

async fn set(&self, key: HttpCacheKey, value: Vec<u8>) {
self.0.insert(key, value).await;
}
}
5 changes: 5 additions & 0 deletions crates/cdk-axum/src/cache/backend/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod memory;
mod redis;

pub use self::memory::InMemoryHttpCache;
pub use self::redis::{Config as RedisConfig, HttpCacheRedis};
103 changes: 103 additions & 0 deletions crates/cdk-axum/src/cache/backend/redis.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
use std::time::Duration;

use redis::AsyncCommands;
use serde::{Deserialize, Serialize};

use crate::cache::{HttpCacheKey, HttpCacheStorage};

/// Redis cache storage for the HTTP cache.
///
/// This cache storage backend uses Redis to store the cache.
pub struct HttpCacheRedis {
cache_ttl: Duration,
prefix: Option<Vec<u8>>,
client: Option<redis::Client>,
}

/// Configuration for the Redis cache storage.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Config {
/// Commong key prefix
pub key_prefix: Option<String>,

/// Connection string to the Redis server.
pub connection_string: String,
}

impl HttpCacheRedis {
/// Create a new Redis cache.
pub fn set_client(mut self, client: redis::Client) -> Self {
self.client = Some(client);
self
}

/// Set a prefix for the cache keys.
///
/// This is useful to have all the HTTP cache keys under a common prefix,
/// some sort of namespace, to make management of the database easier.
pub fn set_prefix(mut self, prefix: Vec<u8>) -> Self {
self.prefix = Some(prefix);
self
}
}

#[async_trait::async_trait]
impl HttpCacheStorage for HttpCacheRedis {
fn new(cache_ttl: Duration, _cache_tti: Duration) -> Self {
Self {
cache_ttl,
prefix: None,
client: None,
}
}

async fn get(&self, key: &HttpCacheKey) -> Option<Vec<u8>> {
let mut con = match self
.client
.as_ref()
.expect("A client must be set with set_client()")
.get_multiplexed_tokio_connection()
.await
{
Ok(con) => con,
Err(err) => {
tracing::error!("Failed to get redis connection: {:?}", err);
return None;
}
};

let mut db_key = self.prefix.clone().unwrap_or_default();
db_key.extend(&**key);

con.get(db_key)
.await
.map_err(|err| {
tracing::error!("Failed to get value from redis: {:?}", err);
err
})
.ok()
}

async fn set(&self, key: HttpCacheKey, value: Vec<u8>) {
let mut db_key = self.prefix.clone().unwrap_or_default();
db_key.extend(&*key);

let mut con = match self
.client
.as_ref()
.expect("A client must be set with set_client()")
.get_multiplexed_tokio_connection()
.await
{
Ok(con) => con,
Err(err) => {
tracing::error!("Failed to get redis connection: {:?}", err);
return;
}
};

let _: Result<(), _> = con
.set_ex(db_key, value, self.cache_ttl.as_secs() as usize)
.await;
}
}
25 changes: 25 additions & 0 deletions crates/cdk-axum/src/cache/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(tag = "backend")]
#[serde(rename_all = "lowercase")]
pub enum Backend {
#[default]
Memory,
Redis(super::backend::RedisConfig),
}

/// Cache configuration.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Config {
/// Cache backend.
#[serde(default)]
#[serde(flatten)]
pub backend: Backend,

/// Time to live for the cache entries.
pub ttl: Option<u64>,

/// Time for the cache entries to be idle.
pub tti: Option<u64>,
}
176 changes: 176 additions & 0 deletions crates/cdk-axum/src/cache/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
//! HTTP cache.
//!
//! This is mod defines a common trait to define custom backends for the HTTP cache.
//!
//! The HTTP cache is a layer to cache responses from HTTP requests, to avoid hitting
//! the same endpoint multiple times, which can be expensive and slow, or to provide
//! idempotent operations.
//!
//! This mod also provides common backend implementations as well, such as In
//! Memory (default) and Redis.
use std::ops::Deref;
use std::sync::Arc;
use std::time::Duration;

use serde::de::DeserializeOwned;
use serde::Serialize;
use sha2::{Digest, Sha256};

mod backend;
mod config;

pub use self::backend::*;
pub use self::config::Config;

#[async_trait::async_trait]
/// Cache storage for the HTTP cache.
pub trait HttpCacheStorage {
/// Create a new cache storage instance
fn new(cache_ttl: Duration, cache_tti: Duration) -> Self
where
Self: Sized;

/// Get a value from the cache.
async fn get(&self, key: &HttpCacheKey) -> Option<Vec<u8>>;

/// Set a value in the cache.
async fn set(&self, key: HttpCacheKey, value: Vec<u8>);
}

/// Http cache with a pluggable storage backend.
pub struct HttpCache {
/// Time to live for the cache.
pub ttl: Duration,
/// Time to idle for the cache.
pub tti: Duration,
storage: Arc<dyn HttpCacheStorage + Send + Sync>,
}

impl Default for HttpCache {
fn default() -> Self {
Self::new(
Duration::from_secs(DEFAULT_TTL_SECS),
Duration::from_secs(DEFAULT_TTI_SECS),
None,
)
}
}

/// Max payload size for the cache key.
///
/// This is a trade-off between security and performance. A large payload can be used to
/// perform a CPU attack.
const MAX_PAYLOAD_SIZE: usize = 10 * 1024 * 1024;

/// Default TTL for the cache.
const DEFAULT_TTL_SECS: u64 = 60;

/// Default TTI for the cache.
const DEFAULT_TTI_SECS: u64 = 60;

/// Http cache key.
///
/// This type ensures no Vec<u8> is used as a key, which is error-prone.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct HttpCacheKey([u8; 32]);

impl Deref for HttpCacheKey {
type Target = [u8; 32];

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl From<config::Config> for HttpCache {
fn from(config: config::Config) -> Self {
match config.backend {
config::Backend::Memory => Self::new(
Duration::from_secs(config.ttl.unwrap_or(DEFAULT_TTL_SECS)),
Duration::from_secs(config.tti.unwrap_or(DEFAULT_TTI_SECS)),
None,
),
config::Backend::Redis(redis_config) => {
let client = redis::Client::open(redis_config.connection_string)
.expect("Failed to create Redis client");
let storage = HttpCacheRedis::new(
Duration::from_secs(config.ttl.unwrap_or(60)),
Duration::from_secs(config.tti.unwrap_or(60)),
)
.set_client(client)
.set_prefix(
redis_config
.key_prefix
.unwrap_or_default()
.as_bytes()
.to_vec(),
);
Self::new(
Duration::from_secs(config.ttl.unwrap_or(DEFAULT_TTL_SECS)),
Duration::from_secs(config.tti.unwrap_or(DEFAULT_TTI_SECS)),
Some(Arc::new(storage)),
)
}
}
}
}

impl HttpCache {
/// Create a new HTTP cache.
pub fn new(
ttl: Duration,
tti: Duration,
storage: Option<Arc<dyn HttpCacheStorage + Send + Sync + 'static>>,
) -> Self {
Self {
ttl,
tti,
storage: storage.unwrap_or_else(|| Arc::new(InMemoryHttpCache::new(ttl, tti))),
}
}

/// Calculate a cache key from a serializable value.
///
/// Usually the input is the request body or query parameters.
///
/// The result is an optional cache key. If the key cannot be calculated, it
/// will be None, meaning the value cannot be cached, therefore the entire
/// caching mechanism should be skipped.
///
/// Instead of using the entire serialized input as the key, the key is a
/// double hash to have a predictable key size, although it may open the
/// window for CPU attacks with large payloads, but it is a trade-off.
/// Perhaps upper layer have a protection against large payloads.
pub fn calculate_key<K: Serialize>(&self, key: &K) -> Option<HttpCacheKey> {
let json_value = match serde_json::to_vec(key) {
Ok(value) => value,
Err(err) => {
tracing::warn!("Failed to serialize key: {:?}", err);
return None;
}
};

if json_value.len() > MAX_PAYLOAD_SIZE {
tracing::warn!("Key size is too large: {}", json_value.len());
return None;
}

let first_hash = Sha256::digest(json_value);
let second_hash = Sha256::digest(first_hash);
Some(HttpCacheKey(second_hash.into()))
}

/// Get a value from the cache.
pub async fn get<V: DeserializeOwned>(self: &Arc<Self>, key: &HttpCacheKey) -> Option<V> {
self.storage
.get(key)
.await
.map(|value| serde_json::from_slice(&value).unwrap())
}

/// Set a value in the cache.
pub async fn set<V: Serialize>(self: &Arc<Self>, key: HttpCacheKey, value: &V) {
let value = serde_json::to_vec(value).unwrap();
self.storage.set(key, value).await;
}
}
Loading

0 comments on commit e7efec3

Please sign in to comment.