From ff0e710ebbd9e30622a98c73c89cda7635613551 Mon Sep 17 00:00:00 2001 From: Simone <79767264+smonicas@users.noreply.github.com> Date: Thu, 25 Jan 2024 13:03:16 +0100 Subject: [PATCH] Add safe external calls option (#59) --- src/cli/commands/detect/mod.rs | 5 +++++ src/cli/commands/print/mod.rs | 5 +++++ src/core/core_unit.rs | 12 +++++++++++- src/detectors/read_only_reentrancy.rs | 19 +++++++++++++++---- src/detectors/reentrancy.rs | 19 +++++++++++++++---- src/detectors/reentrancy_benign.rs | 19 +++++++++++++++---- src/detectors/reentrancy_events.rs | 19 +++++++++++++++---- tests/detectors/read_only_reentrancy.cairo | 9 ++++++++- tests/detectors/reentrancy.cairo | 8 ++++++++ tests/detectors/reentrancy_benign.cairo | 7 +++++++ tests/detectors/reentrancy_events.cairo | 7 +++++++ tests/integration_tests.rs | 1 + 12 files changed, 112 insertions(+), 18 deletions(-) diff --git a/src/cli/commands/detect/mod.rs b/src/cli/commands/detect/mod.rs index c85cae5..c2064c2 100644 --- a/src/cli/commands/detect/mod.rs +++ b/src/cli/commands/detect/mod.rs @@ -22,6 +22,10 @@ pub struct DetectArgs { #[arg(long, num_args(0..))] contract_path: Option>, + /// Functions name that are safe when called (e.g. they don't cause a reentrancy) + #[arg(long, num_args(0..))] + safe_external_calls: Option>, + /// Detectors to run #[arg(long, num_args(0..), conflicts_with_all(["exclude", "exclude_informational", "exclude_low", "exclude_medium", "exclude_high"]))] detect: Option>, @@ -53,6 +57,7 @@ impl From<&DetectArgs> for CoreOpts { target: args.target.clone(), corelib: args.corelib.clone(), contract_path: args.contract_path.clone(), + safe_external_calls: args.safe_external_calls.clone(), } } } diff --git a/src/cli/commands/print/mod.rs b/src/cli/commands/print/mod.rs index e517b09..54d3709 100644 --- a/src/cli/commands/print/mod.rs +++ b/src/cli/commands/print/mod.rs @@ -18,6 +18,10 @@ pub struct PrintArgs { #[arg(long, num_args(0..))] contract_path: Option>, + /// Functions name that are safe when called (e.g. they don't cause a reentrancy) + #[arg(long, num_args(0..))] + safe_external_calls: Option>, + /// Which functions to run the printer (all, user-functions) #[arg(short, long, default_value_t = Filter::UserFunctions)] filter: Filter, @@ -33,6 +37,7 @@ impl From<&PrintArgs> for CoreOpts { target: args.target.clone(), corelib: args.corelib.clone(), contract_path: args.contract_path.clone(), + safe_external_calls: args.safe_external_calls.clone(), } } } diff --git a/src/core/core_unit.rs b/src/core/core_unit.rs index e895f48..241a3a2 100644 --- a/src/core/core_unit.rs +++ b/src/core/core_unit.rs @@ -10,14 +10,17 @@ pub struct CoreOpts { pub target: PathBuf, pub corelib: Option, pub contract_path: Option>, + pub safe_external_calls: Option>, } pub struct CoreUnit { compilation_units: Vec, + safe_external_calls: Option>, } impl CoreUnit { pub fn new(opts: CoreOpts) -> Result { + let safe_external_calls = opts.safe_external_calls.clone(); let program_compiled = compile(opts)?; let compilation_units = program_compiled .par_iter() @@ -31,10 +34,17 @@ impl CoreUnit { compilation_unit }) .collect(); - Ok(CoreUnit { compilation_units }) + Ok(CoreUnit { + compilation_units, + safe_external_calls, + }) } pub fn get_compilation_units(&self) -> &Vec { &self.compilation_units } + + pub fn get_safe_external_calls(&self) -> &Option> { + &self.safe_external_calls + } } diff --git a/src/detectors/read_only_reentrancy.rs b/src/detectors/read_only_reentrancy.rs index 1d717d1..7864673 100644 --- a/src/detectors/read_only_reentrancy.rs +++ b/src/detectors/read_only_reentrancy.rs @@ -58,6 +58,20 @@ impl Detector for ReadOnlyReentrancy { } = bb_info.1 { for call in reentrancy_info.external_calls.iter() { + let external_function_call = format!( + "{}", + call.get_external_call().as_ref().unwrap().get_statement() + ); + + if let Some(safe_external_calls) = core.get_safe_external_calls() { + if safe_external_calls + .iter() + .any(|f_name| external_function_call.contains(f_name)) + { + continue; + } + } + for written_variable in reentrancy_info.storage_variables_written.iter() { let written_variable_name = written_variable @@ -82,10 +96,7 @@ impl Detector for ReadOnlyReentrancy { message: format!( "Read only reentrancy in {}\n\tExternal call {} done in {}\n\tVariable written after {} in {}", view_function, - call.get_external_call() - .as_ref() - .unwrap() - .get_statement(), + external_function_call, call.get_function(), written_variable .get_storage_variable_written() diff --git a/src/detectors/reentrancy.rs b/src/detectors/reentrancy.rs index 0290388..b240022 100644 --- a/src/detectors/reentrancy.rs +++ b/src/detectors/reentrancy.rs @@ -38,6 +38,20 @@ impl Detector for Reentrancy { } = bb_info.1 { for call in reentrancy_info.external_calls.iter() { + let external_function_call = format!( + "{}", + call.get_external_call().as_ref().unwrap().get_statement() + ); + + if let Some(safe_external_calls) = core.get_safe_external_calls() { + if safe_external_calls + .iter() + .any(|f_name| external_function_call.contains(f_name)) + { + continue; + } + } + if let Some(current_vars_read_before_call) = reentrancy_info .variables_read_before_calls .iter() @@ -79,10 +93,7 @@ impl Detector for Reentrancy { message: format!( "Reentrancy in {}\n\tExternal call {} done in {}\n\tVariable written after {} in {}.", f.name(), - call.get_external_call() - .as_ref() - .unwrap() - .get_statement(), + external_function_call, call.get_function(), written_variable .get_storage_variable_written() diff --git a/src/detectors/reentrancy_benign.rs b/src/detectors/reentrancy_benign.rs index 491fe1d..e85ebb6 100644 --- a/src/detectors/reentrancy_benign.rs +++ b/src/detectors/reentrancy_benign.rs @@ -38,6 +38,20 @@ impl Detector for ReentrancyBenign { } = bb_info.1 { for call in reentrancy_info.external_calls.iter() { + let external_function_call = format!( + "{}", + call.get_external_call().as_ref().unwrap().get_statement() + ); + + if let Some(safe_external_calls) = core.get_safe_external_calls() { + if safe_external_calls + .iter() + .any(|f_name| external_function_call.contains(f_name)) + { + continue; + } + } + if let Some(current_vars_read_before_call) = reentrancy_info .variables_read_before_calls .iter() @@ -79,10 +93,7 @@ impl Detector for ReentrancyBenign { message: format!( "Reentrancy in {}\n\tExternal call {} done in {}\n\tVariable written after {} in {}.", f.name(), - call.get_external_call() - .as_ref() - .unwrap() - .get_statement(), + external_function_call, call.get_function(), written_variable .get_storage_variable_written() diff --git a/src/detectors/reentrancy_events.rs b/src/detectors/reentrancy_events.rs index 414ea0c..5ef7215 100644 --- a/src/detectors/reentrancy_events.rs +++ b/src/detectors/reentrancy_events.rs @@ -39,6 +39,20 @@ impl Detector for ReentrancyEvents { { for event in reentrancy_info.events.iter() { for call in reentrancy_info.external_calls.iter() { + let external_function_call = format!( + "{}", + call.get_external_call().as_ref().unwrap().get_statement() + ); + + if let Some(safe_external_calls) = core.get_safe_external_calls() { + if safe_external_calls + .iter() + .any(|f_name| external_function_call.contains(f_name)) + { + continue; + } + } + results.insert(Result { name: self.name().to_string(), impact: self.impact(), @@ -46,10 +60,7 @@ impl Detector for ReentrancyEvents { message: format!( "Reentrancy in {}\n\tExternal call {} done in {}\n\tEvent emitted after {} in {}.", f.name(), - call.get_external_call() - .as_ref() - .unwrap() - .get_statement(), + external_function_call, call.get_function(), event.get_event_emitted().as_ref().unwrap().get_statement(), event.get_function() diff --git a/tests/detectors/read_only_reentrancy.cairo b/tests/detectors/read_only_reentrancy.cairo index 5059723..dee4876 100644 --- a/tests/detectors/read_only_reentrancy.cairo +++ b/tests/detectors/read_only_reentrancy.cairo @@ -1,6 +1,7 @@ #[starknet::interface] trait IAnotherContract { fn foo(self: @T, a: felt252); + fn safe_foo(self: @T, a: felt252); } #[starknet::contract] @@ -32,8 +33,14 @@ mod TestContract { } #[external(v0)] - fn ok(ref self: ContractState, address: ContractAddress) { + fn good1(ref self: ContractState, address: ContractAddress) { IAnotherContractDispatcher { contract_address: address }.foo(4); } + #[external(v0)] + fn good2(ref self: ContractState, address: ContractAddress) { + IAnotherContractDispatcher { contract_address: address }.safe_foo(4); + self.a.write(4); + } + } diff --git a/tests/detectors/reentrancy.cairo b/tests/detectors/reentrancy.cairo index c16d44f..7e7f7c6 100644 --- a/tests/detectors/reentrancy.cairo +++ b/tests/detectors/reentrancy.cairo @@ -1,6 +1,7 @@ #[starknet::interface] trait IAnotherContract { fn foo(self: @T, a: felt252); + fn safe_foo(self: @T, a: felt252); } #[starknet::contract] @@ -22,6 +23,13 @@ mod TestContract { IAnotherContractDispatcher { contract_address: address }.foo(a); } + #[external(v0)] + fn good2(ref self: ContractState, address: ContractAddress) { + let a = self.a.read(); + IAnotherContractDispatcher { contract_address: address }.safe_foo(a); + self.a.write(4); + } + #[external(v0)] fn bad1(ref self: ContractState, address: ContractAddress) { let a = self.a.read(); diff --git a/tests/detectors/reentrancy_benign.cairo b/tests/detectors/reentrancy_benign.cairo index c56b9f0..845c3d3 100644 --- a/tests/detectors/reentrancy_benign.cairo +++ b/tests/detectors/reentrancy_benign.cairo @@ -1,6 +1,7 @@ #[starknet::interface] trait IAnotherContract { fn foo(self: @T, a: felt252); + fn safe_foo(self: @T, a: felt252); } #[starknet::contract] @@ -22,6 +23,12 @@ mod TestContract { IAnotherContractDispatcher { contract_address: address }.foo(a); } + #[external(v0)] + fn good2(ref self: ContractState, address: ContractAddress) { + IAnotherContractDispatcher { contract_address: address }.safe_foo(4); + self.a.write(4); + } + #[external(v0)] fn bad1(ref self: ContractState, address: ContractAddress) { IAnotherContractDispatcher { contract_address: address }.foo(4); diff --git a/tests/detectors/reentrancy_events.cairo b/tests/detectors/reentrancy_events.cairo index 4469965..41b2218 100644 --- a/tests/detectors/reentrancy_events.cairo +++ b/tests/detectors/reentrancy_events.cairo @@ -1,6 +1,7 @@ #[starknet::interface] trait IAnotherContract { fn foo(self: @T, a: felt252); + fn safe_foo(self: @T, a: felt252); } #[starknet::contract] @@ -27,6 +28,12 @@ mod TestContract { IAnotherContractDispatcher { contract_address: address }.foo(4); } + #[external(v0)] + fn good2(ref self: ContractState, address: ContractAddress) { + IAnotherContractDispatcher { contract_address: address }.safe_foo(4); + self.emit(MyEvent { }); + } + #[external(v0)] fn bad1(ref self: ContractState, address: ContractAddress) { IAnotherContractDispatcher { contract_address: address }.foo(4); diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index c1d570e..2b7c504 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -12,6 +12,7 @@ fn test_detectors() { env::var("CARGO_MANIFEST_DIR").unwrap() + "/corelib/src", )), contract_path: None, + safe_external_calls: Some(vec!["::safe_foo".to_string()]), }; let core = CoreUnit::new(opts).unwrap(); let mut results = get_detectors()