Skip to content

Commit

Permalink
fix(katana): search the inner state when sierra class not found in th…
Browse files Browse the repository at this point in the history
…e cache (#1602)

fall back to inner state when sierra class not found in the cache
  • Loading branch information
kariy authored Mar 3, 2024
1 parent 1d3ecb6 commit 5c74ce2
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 41 deletions.
4 changes: 2 additions & 2 deletions crates/katana/core/src/service/block_producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ impl IntervalBlockProducer {
trace!(target: "miner", "created new block: {}", outcome.block_number);

backend.update_block_env(&mut block_env);
pending_state.reset_state(new_state.into(), block_env, cfg_env);
pending_state.reset_state(StateRefDb(new_state), block_env, cfg_env);

Ok(outcome)
}
Expand Down Expand Up @@ -410,7 +410,7 @@ impl InstantBlockProducer {
let block_context = block_context_from_envs(&block_env, &cfg_env);

let latest_state = StateFactoryProvider::latest(backend.blockchain.provider())?;
let state = CachedStateWrapper::new(latest_state.into());
let state = CachedStateWrapper::new(StateRefDb(latest_state));

let txs = transactions.iter().map(TxWithHash::from);

Expand Down
19 changes: 8 additions & 11 deletions crates/katana/executor/src/blockifier/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ pub mod utils;
use std::sync::Arc;

use blockifier::block_context::BlockContext;
use blockifier::state::state_api::StateReader;
use blockifier::transaction::errors::TransactionExecutionError;
use blockifier::transaction::objects::TransactionExecutionInfo;
use blockifier::transaction::transaction_execution::Transaction;
Expand Down Expand Up @@ -34,15 +33,15 @@ type TxExecutionResult = Result<TransactionExecutionInfo, TransactionExecutionEr
/// The transactions will be executed in an iterator fashion, sequentially, in the
/// exact order they are provided to the executor. The execution is done within its
/// implementation of the [`Iterator`] trait.
pub struct TransactionExecutor<'a, S: StateReader, T> {
pub struct TransactionExecutor<'a, T> {
/// A flag to enable/disable fee charging.
charge_fee: bool,
/// The block context the transactions will be executed on.
block_context: &'a BlockContext,
/// The transactions to be executed (in the exact order they are in the iterator).
transactions: T,
/// The state the transactions will be executed on.
state: &'a CachedStateWrapper<S>,
state: &'a CachedStateWrapper,
/// A flag to enable/disable transaction validation.
validate: bool,

Expand All @@ -52,13 +51,12 @@ pub struct TransactionExecutor<'a, S: StateReader, T> {
resources_log: bool,
}

impl<'a, S, T> TransactionExecutor<'a, S, T>
impl<'a, T> TransactionExecutor<'a, T>
where
S: StateReader,
T: Iterator<Item = ExecutableTxWithHash>,
{
pub fn new(
state: &'a CachedStateWrapper<S>,
state: &'a CachedStateWrapper,
block_context: &'a BlockContext,
charge_fee: bool,
validate: bool,
Expand Down Expand Up @@ -94,9 +92,8 @@ where
}
}

impl<'a, S, T> Iterator for TransactionExecutor<'a, S, T>
impl<'a, T> Iterator for TransactionExecutor<'a, T>
where
S: StateReader,
T: Iterator<Item = ExecutableTxWithHash>,
{
type Item = TxExecutionResult;
Expand Down Expand Up @@ -141,9 +138,9 @@ where
}
}

fn execute_tx<S: StateReader>(
fn execute_tx(
tx: ExecutableTxWithHash,
state: &CachedStateWrapper<S>,
state: &CachedStateWrapper,
block_context: &BlockContext,
charge_fee: bool,
validate: bool,
Expand Down Expand Up @@ -184,7 +181,7 @@ pub struct PendingState {
/// The block context of the pending block.
pub block_envs: RwLock<(BlockEnv, CfgEnv)>,
/// The state of the pending block.
pub state: Arc<CachedStateWrapper<StateRefDb>>,
pub state: Arc<CachedStateWrapper>,
/// The transactions that have been executed.
pub executed_txs: RwLock<Vec<AcceptedTxPair>>,
/// The transactions that have been rejected.
Expand Down
85 changes: 60 additions & 25 deletions crates/katana/executor/src/blockifier/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,58 @@ use starknet_api::state::StorageKey;
/// A state db only provide read access.
///
/// This type implements the [`StateReader`] trait so that it can be used as a with [`CachedState`].
pub struct StateRefDb(Box<dyn StateProvider>);
pub struct StateRefDb(pub Box<dyn StateProvider>);

impl StateRefDb {
pub fn new(provider: impl StateProvider + 'static) -> Self {
Self(Box::new(provider))
}
}

impl<T> From<T> for StateRefDb
where
T: StateProvider + 'static,
{
fn from(provider: T) -> Self {
Self::new(provider)
impl ContractClassProvider for StateRefDb {
fn class(
&self,
hash: katana_primitives::contract::ClassHash,
) -> ProviderResult<Option<katana_primitives::contract::CompiledContractClass>> {
self.0.class(hash)
}

fn compiled_class_hash_of_class_hash(
&self,
hash: katana_primitives::contract::ClassHash,
) -> ProviderResult<Option<katana_primitives::contract::CompiledClassHash>> {
self.0.compiled_class_hash_of_class_hash(hash)
}

fn sierra_class(
&self,
hash: katana_primitives::contract::ClassHash,
) -> ProviderResult<Option<FlattenedSierraClass>> {
self.0.sierra_class(hash)
}
}

impl StateProvider for StateRefDb {
fn nonce(
&self,
address: katana_primitives::contract::ContractAddress,
) -> ProviderResult<Option<katana_primitives::contract::Nonce>> {
self.0.nonce(address)
}

fn class_hash_of_contract(
&self,
address: katana_primitives::contract::ContractAddress,
) -> ProviderResult<Option<katana_primitives::contract::ClassHash>> {
self.0.class_hash_of_contract(address)
}

fn storage(
&self,
address: katana_primitives::contract::ContractAddress,
storage_key: katana_primitives::contract::StorageKey,
) -> ProviderResult<Option<katana_primitives::contract::StorageValue>> {
self.0.storage(address, storage_key)
}
}

Expand Down Expand Up @@ -93,25 +131,27 @@ impl StateReader for StateRefDb {
}
}

pub struct CachedStateWrapper<S: StateReader> {
inner: Mutex<CachedState<S>>,
pub struct CachedStateWrapper {
inner: Mutex<CachedState<StateRefDb>>,
sierra_class: RwLock<HashMap<katana_primitives::contract::ClassHash, FlattenedSierraClass>>,
}

impl<S: StateReader> CachedStateWrapper<S> {
pub fn new(db: S) -> Self {
impl CachedStateWrapper {
pub fn new(db: StateRefDb) -> Self {
Self {
sierra_class: Default::default(),
inner: Mutex::new(CachedState::new(db, GlobalContractCache::default())),
}
}

pub(super) fn reset_with_new_state(&self, db: S) {
pub(super) fn reset_with_new_state(&self, db: StateRefDb) {
*self.inner() = CachedState::new(db, GlobalContractCache::default());
self.sierra_class_mut().clear();
}

pub fn inner(&self) -> parking_lot::lock_api::MutexGuard<'_, RawMutex, CachedState<S>> {
pub fn inner(
&self,
) -> parking_lot::lock_api::MutexGuard<'_, RawMutex, CachedState<StateRefDb>> {
self.inner.lock()
}

Expand All @@ -134,10 +174,7 @@ impl<S: StateReader> CachedStateWrapper<S> {
}
}

impl<Db> ContractClassProvider for CachedStateWrapper<Db>
where
Db: StateReader + Sync + Send,
{
impl ContractClassProvider for CachedStateWrapper {
fn class(
&self,
hash: katana_primitives::contract::ClassHash,
Expand All @@ -162,17 +199,15 @@ where
&self,
hash: katana_primitives::contract::ClassHash,
) -> ProviderResult<Option<FlattenedSierraClass>> {
let class @ Some(_) = self.sierra_class().get(&hash).cloned() else {
return Ok(None);
};
Ok(class)
if let Some(class) = self.sierra_class().get(&hash) {
Ok(Some(class.clone()))
} else {
self.inner.lock().state.0.sierra_class(hash)
}
}
}

impl<Db> StateProvider for CachedStateWrapper<Db>
where
Db: StateReader + Sync + Send,
{
impl StateProvider for CachedStateWrapper {
fn storage(
&self,
address: katana_primitives::contract::ContractAddress,
Expand Down
6 changes: 3 additions & 3 deletions crates/katana/executor/src/blockifier/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pub fn estimate_fee(
state: Box<dyn StateProvider>,
validate: bool,
) -> Result<Vec<FeeEstimate>, TransactionExecutionError> {
let state = CachedStateWrapper::new(StateRefDb::from(state));
let state = CachedStateWrapper::new(StateRefDb(state));
let results = TransactionExecutor::new(&state, &block_context, true, validate, transactions)
.with_error_log()
.execute();
Expand Down Expand Up @@ -94,7 +94,7 @@ pub fn raw_call(
state: Box<dyn StateProvider>,
initial_gas: u64,
) -> Result<CallInfo, TransactionExecutionError> {
let mut state = CachedState::new(StateRefDb::from(state), GlobalContractCache::default());
let mut state = CachedState::new(StateRefDb(state), GlobalContractCache::default());
let mut state = CachedState::new(MutRefState::new(&mut state), GlobalContractCache::default());

let call = CallEntryPoint {
Expand Down Expand Up @@ -235,7 +235,7 @@ pub(crate) fn pretty_print_resources(resources: &ResourcesMapping) -> String {
}

pub fn get_state_update_from_cached_state(
state: &CachedStateWrapper<StateRefDb>,
state: &CachedStateWrapper,
) -> StateUpdatesWithDeclaredClasses {
let state_diff = state.inner().to_state_diff();

Expand Down

0 comments on commit 5c74ce2

Please sign in to comment.