Skip to content

Commit

Permalink
feat!: Add Blockstore::close method
Browse files Browse the repository at this point in the history
  • Loading branch information
oblique committed Sep 9, 2024
1 parent 58994e6 commit d5a8c74
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 20 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ redb = { version = "2", optional = true }
# Those can be restored by migrating between versions:
# https://docs.rs/sled/latest/sled/struct.Db.html#examples-1
sled = { version = "0.34.7", optional = true }
tokio = { version = "1.29.0", features = ["macros", "rt"], optional = true }
tokio = { version = "1.29.0", features = ["macros", "rt", "sync"], optional = true }

[target.'cfg(target_arch = "wasm32")'.dependencies]
js-sys = { version = "0.3.68", optional = true }
Expand All @@ -36,7 +36,7 @@ wasm-bindgen = { version = "0.2.91", optional = true }

[dev-dependencies]
rstest = "0.22.0"
tokio = { version = "1.29.0", features = ["macros", "rt"] }
tokio = { version = "1.29.0", features = ["macros", "rt", "time"] }
tempfile = "3.10"

# doc-tests
Expand Down
89 changes: 89 additions & 0 deletions src/counter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use std::fmt::{self, Debug};
use std::pin::pin;
use std::sync::Arc;

use tokio::sync::Notify;

pub(crate) struct Counter {
counter: Arc<()>,
notify: Arc<Notify>,
}

pub(crate) struct CounterGuard {
counter: Option<Arc<()>>,
notify: Arc<Notify>,
}

impl Drop for CounterGuard {
fn drop(&mut self) {
self.counter.take();
self.notify.notify_waiters();
}
}

impl Counter {
pub(crate) fn new() -> Counter {
Counter {
counter: Arc::new(()),
notify: Arc::new(Notify::new()),
}
}

pub(crate) fn guard(&self) -> CounterGuard {
CounterGuard {
counter: Some(self.counter.clone()),
notify: self.notify.clone(),
}
}

/// Wait all guards to drop.
pub(crate) async fn wait_guards(&mut self) {
let mut notified = pin!(self.notify.notified());

while Arc::strong_count(&self.counter) > 1 {
notified.as_mut().await;
notified.set(self.notify.notified());
}
}
}

impl Debug for Counter {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("Counter { .. }")
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::time::{Duration, Instant};
use tokio::spawn;
use tokio::time::{sleep, timeout};

#[tokio::test]
async fn counter_works() {
let mut counter = Counter::new();
counter.wait_guards().await;

let guard1 = counter.guard();
let guard2 = counter.guard();
let now = Instant::now();

spawn(async move {
let _guard = guard1;
sleep(Duration::from_millis(100)).await;
});

spawn(async move {
let _guard = guard2;
sleep(Duration::from_millis(200)).await;
});

timeout(Duration::from_millis(300), counter.wait_guards())
.await
.unwrap();

let elapsed = now.elapsed();
assert!(elapsed >= Duration::from_millis(200) && elapsed < Duration::from_millis(300));
}
}
4 changes: 4 additions & 0 deletions src/in_memory_blockstore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ impl<const MAX_MULTIHASH_SIZE: usize> Blockstore for InMemoryBlockstore<MAX_MULT
let cid = convert_cid(cid)?;
Ok(self.contains_cid(&cid))
}

async fn close(self) -> Result<()> {
Ok(())
}
}

impl<const MAX_MULTIHASH_SIZE: usize> Default for InMemoryBlockstore<MAX_MULTIHASH_SIZE> {
Expand Down
5 changes: 5 additions & 0 deletions src/indexed_db_blockstore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ impl Blockstore for IndexedDbBlockstore {

has_key(&blocks, &cid).await
}

async fn close(self) -> Result<()> {
self.db.close();
Ok(())
}
}

impl From<rexie::Error> for Error {
Expand Down
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ mod redb_blockstore;
#[cfg(all(not(target_arch = "wasm32"), feature = "sled"))]
mod sled_blockstore;

#[cfg(all(not(target_arch = "wasm32"), any(feature = "redb", feature = "sled")))]
mod counter;

pub use crate::in_memory_blockstore::InMemoryBlockstore;
#[cfg(all(target_arch = "wasm32", feature = "indexeddb"))]
#[cfg_attr(docsrs, doc(cfg(all(target_arch = "wasm32", feature = "indexeddb"))))]
Expand Down Expand Up @@ -164,6 +167,8 @@ pub trait Blockstore: CondSync {
Ok(())
}
}

fn close(self) -> impl Future<Output = Result<()>> + CondSend;
}

pub(crate) fn convert_cid<const S: usize, const NEW_S: usize>(
Expand Down
4 changes: 4 additions & 0 deletions src/lru_blockstore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ impl<const MAX_MULTIHASH_SIZE: usize> Blockstore for LruBlockstore<MAX_MULTIHASH
let cache = self.cache.lock().expect("lock failed");
Ok(cache.contains(&cid))
}

async fn close(self) -> Result<()> {
Ok(())
}
}

#[cfg(test)]
Expand Down
49 changes: 36 additions & 13 deletions src/redb_blockstore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use redb::{
};
use tokio::task::spawn_blocking;

use crate::counter::Counter;
use crate::{Blockstore, Error, Result};

const BLOCKS_TABLE: TableDefinition<'static, &[u8], &[u8]> =
Expand All @@ -17,6 +18,7 @@ const BLOCKS_TABLE: TableDefinition<'static, &[u8], &[u8]> =
#[derive(Debug)]
pub struct RedbBlockstore {
db: Arc<Database>,
task_counter: Counter,
}

impl RedbBlockstore {
Expand Down Expand Up @@ -56,7 +58,10 @@ impl RedbBlockstore {
/// # }
/// ```
pub fn new(db: Arc<Database>) -> Self {
RedbBlockstore { db }
RedbBlockstore {
db,
task_counter: Counter::new(),
}
}

/// Returns the raw [`redb::Database`].
Expand All @@ -74,10 +79,16 @@ impl RedbBlockstore {
T: Send + 'static,
{
let db = self.db.clone();
let guard = self.task_counter.guard();

tokio::task::spawn_blocking(move || {
let _guard = guard;

spawn_blocking(move || {
let mut tx = db.begin_read()?;
f(&mut tx)
{
let db = db;
let mut tx = db.begin_read()?;
f(&mut tx)
}
})
.await?
}
Expand All @@ -91,18 +102,24 @@ impl RedbBlockstore {
T: Send + 'static,
{
let db = self.db.clone();
let guard = self.task_counter.guard();

spawn_blocking(move || {
let mut tx = db.begin_write()?;
let res = f(&mut tx);
tokio::task::spawn_blocking(move || {
let _guard = guard;

if res.is_ok() {
tx.commit()?;
} else {
tx.abort()?;
}
{
let db = db;
let mut tx = db.begin_write()?;
let res = f(&mut tx);

res
if res.is_ok() {
tx.commit()?;
} else {
tx.abort()?;
}

res
}
})
.await?
}
Expand Down Expand Up @@ -167,6 +184,12 @@ impl Blockstore for RedbBlockstore {
})
.await
}

async fn close(mut self) -> Result<()> {
// Wait all ongoing `spawn_blocking` tasks to finish.
self.task_counter.wait_guards().await;
Ok(())
}
}

impl From<TransactionError> for Error {
Expand Down
34 changes: 29 additions & 5 deletions src/sled_blockstore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use std::sync::Arc;

use cid::CidGeneric;
use sled::{Db, Error as SledError, Tree};
use tokio::task::spawn_blocking;
use tokio::task::{spawn_blocking, JoinHandle};

use crate::counter::Counter;
use crate::{Blockstore, Error, Result};

const BLOCKS_TREE_ID: &[u8] = b"BLOCKSTORE.BLOCKS";
Expand All @@ -12,6 +13,7 @@ const BLOCKS_TREE_ID: &[u8] = b"BLOCKSTORE.BLOCKS";
#[derive(Debug)]
pub struct SledBlockstore {
inner: Arc<Inner>,
task_counter: Counter,
}

#[derive(Debug)]
Expand All @@ -21,6 +23,19 @@ struct Inner {
}

impl SledBlockstore {
fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let guard = self.task_counter.guard();

spawn_blocking(move || {
let _guard = guard;
f()
})
}

/// Create or open a [`SledBlockstore`] in a given sled [`Db`].
///
/// # Example
Expand All @@ -40,6 +55,7 @@ impl SledBlockstore {

Ok(Self {
inner: Arc::new(Inner { _db: db, blocks }),
task_counter: Counter::new(),
})
})
.await?
Expand All @@ -49,15 +65,16 @@ impl SledBlockstore {
let inner = self.inner.clone();
let cid = cid.to_bytes();

spawn_blocking(move || Ok(inner.blocks.get(cid)?.map(|bytes| bytes.to_vec()))).await?
self.spawn_blocking(move || Ok(inner.blocks.get(cid)?.map(|bytes| bytes.to_vec())))
.await?
}

async fn put<const S: usize>(&self, cid: &CidGeneric<S>, data: &[u8]) -> Result<()> {
let inner = self.inner.clone();
let cid = cid.to_bytes();
let data = data.to_vec();

spawn_blocking(move || {
self.spawn_blocking(move || {
let _ = inner
.blocks
.compare_and_swap(cid, None as Option<&[u8]>, Some(data))?;
Expand All @@ -70,7 +87,7 @@ impl SledBlockstore {
let inner = self.inner.clone();
let cid = cid.to_bytes();

spawn_blocking(move || {
self.spawn_blocking(move || {
inner.blocks.remove(cid)?;
Ok(())
})
Expand All @@ -81,7 +98,8 @@ impl SledBlockstore {
let inner = self.inner.clone();
let cid = cid.to_bytes();

spawn_blocking(move || Ok(inner.blocks.contains_key(cid)?)).await?
self.spawn_blocking(move || Ok(inner.blocks.contains_key(cid)?))
.await?
}
}

Expand All @@ -101,6 +119,12 @@ impl Blockstore for SledBlockstore {
async fn has<const S: usize>(&self, cid: &CidGeneric<S>) -> Result<bool> {
self.has(cid).await
}

async fn close(mut self) -> Result<()> {
// Wait all ongoing `spawn_blocking` tasks to finish.
self.task_counter.wait_guards().await;
Ok(())
}
}

// divide errors into recoverable and not avoiding directly relying on passing sled types
Expand Down

0 comments on commit d5a8c74

Please sign in to comment.