From a74ab4db8a0de0e5020bbcd334be02409d4cdc2a Mon Sep 17 00:00:00 2001 From: Sean Young Date: Thu, 21 Sep 2023 21:05:30 +0100 Subject: [PATCH] Allow using {func} for type to use library functions (#1528) Fixes https://github.com/hyperledger/solang/issues/1525 Signed-off-by: Sean Young --- src/sema/contracts.rs | 5 +- src/sema/expression/function_call.rs | 7 +- src/sema/namespace.rs | 14 +- src/sema/using.rs | 32 ++++- .../solana/using_functions.sol | 31 +++++ tests/evm.rs | 2 +- tests/polkadot_tests/libraries.rs | 2 +- tests/solana_tests/using.rs | 127 ++++++++++++++++++ 8 files changed, 207 insertions(+), 13 deletions(-) create mode 100644 tests/contract_testcases/solana/using_functions.sol diff --git a/src/sema/contracts.rs b/src/sema/contracts.rs index 61fdb5e0c..d7ae42f63 100644 --- a/src/sema/contracts.rs +++ b/src/sema/contracts.rs @@ -63,8 +63,6 @@ impl ast::Contract { /// Resolve the following contract pub fn resolve(contracts: &[ContractDefinition], file_no: usize, ns: &mut ast::Namespace) { - resolve_using(contracts, file_no, ns); - // we need to resolve declarations first, so we call functions/constructors of // contracts before they are declared let mut delayed: ResolveLater = Default::default(); @@ -73,6 +71,9 @@ pub fn resolve(contracts: &[ContractDefinition], file_no: usize, ns: &mut ast::N resolve_declarations(def, file_no, ns, &mut delayed); } + // using may use functions declared in contracts + resolve_using(contracts, file_no, ns); + // Resolve base contract constructor arguments on contract definition (not constructor definitions) resolve_base_args(contracts, file_no, ns); diff --git a/src/sema/expression/function_call.rs b/src/sema/expression/function_call.rs index 2cf923595..3eea2bae7 100644 --- a/src/sema/expression/function_call.rs +++ b/src/sema/expression/function_call.rs @@ -1324,7 +1324,6 @@ pub(super) fn method_call_pos_args( resolve_to, )? { return Ok(resolved_call); - } else { } if let Some(resolved_call) = try_user_type( @@ -1352,8 +1351,9 @@ pub(super) fn method_call_pos_args( if let Some(mut path) = ns.expr_to_identifier_path(var) { path.identifiers.push(func.clone()); - if let Ok(list) = ns.resolve_free_function_with_namespace( + if let Ok(list) = ns.resolve_function_with_namespace( context.file_no, + None, &path, &mut Diagnostics::default(), ) { @@ -1640,8 +1640,9 @@ pub(super) fn method_call_named_args( if let Some(mut path) = ns.expr_to_identifier_path(var) { path.identifiers.push(func_name.clone()); - if let Ok(list) = ns.resolve_free_function_with_namespace( + if let Ok(list) = ns.resolve_function_with_namespace( context.file_no, + None, &path, &mut Diagnostics::default(), ) { diff --git a/src/sema/namespace.rs b/src/sema/namespace.rs index 9a5f4b5c7..aa9d360ac 100644 --- a/src/sema/namespace.rs +++ b/src/sema/namespace.rs @@ -364,9 +364,10 @@ impl Namespace { } /// Resolve a free function name with namespace - pub(super) fn resolve_free_function_with_namespace( + pub(super) fn resolve_function_with_namespace( &mut self, file_no: usize, + contract_no: Option, name: &pt::IdentifierPath, diagnostics: &mut Diagnostics, ) -> Result, ()> { @@ -376,12 +377,12 @@ impl Namespace { .map(|(id, namespace)| (id, namespace.iter().collect())) .unwrap(); - let s = self.resolve_namespace(namespace, file_no, None, id, diagnostics)?; + let symbol = self.resolve_namespace(namespace, file_no, contract_no, id, diagnostics)?; - if let Some(Symbol::Function(list)) = s { + if let Some(Symbol::Function(list)) = symbol { Ok(list.clone()) } else { - let error = Namespace::wrong_symbol(s, id); + let error = Namespace::wrong_symbol(symbol, id); diagnostics.push(error); @@ -1335,6 +1336,7 @@ impl Namespace { )); return Err(()); }; + namespace.clear(); Some(*n) } Some(Symbol::Function(_)) => { @@ -1390,6 +1392,10 @@ impl Namespace { }; } + if !namespace.is_empty() { + return Ok(None); + } + let mut s = self .variable_symbols .get(&(import_file_no, contract_no, id.name.to_owned())) diff --git a/src/sema/using.rs b/src/sema/using.rs index 231f618ef..c63dd77cd 100644 --- a/src/sema/using.rs +++ b/src/sema/using.rs @@ -94,8 +94,9 @@ pub(crate) fn using_decl( for using_function in functions { let function_name = &using_function.path; - if let Ok(list) = ns.resolve_free_function_with_namespace( + if let Ok(list) = ns.resolve_function_with_namespace( file_no, + contract_no, &using_function.path, &mut diagnostics, ) { @@ -120,6 +121,18 @@ pub(crate) fn using_decl( let func = &ns.functions[func_no]; + if let Some(contract_no) = func.contract_no { + if !ns.contracts[contract_no].is_library() { + diagnostics.push(Diagnostic::error_with_note( + function_name.loc, + format!("'{function_name}' is not a library function"), + func.loc, + format!("definition of {}", using_function.path), + )); + continue; + } + } + if func.params.is_empty() { diagnostics.push(Diagnostic::error_with_note( function_name.loc, @@ -251,7 +264,22 @@ pub(crate) fn using_decl( Some(oper) } else { if let Some(ty) = &ty { - if *ty != func.params[0].ty { + let dummy = Expression::Variable { + loc, + ty: ty.clone(), + var_no: 0, + }; + + if dummy + .cast( + &loc, + &func.params[0].ty, + true, + ns, + &mut Diagnostics::default(), + ) + .is_err() + { diagnostics.push(Diagnostic::error_with_note( function_name.loc, format!("function cannot be used since first argument is '{}' rather than the required '{}'", func.params[0].ty.to_string(ns), ty.to_string(ns)), diff --git a/tests/contract_testcases/solana/using_functions.sol b/tests/contract_testcases/solana/using_functions.sol new file mode 100644 index 000000000..814c48fe5 --- /dev/null +++ b/tests/contract_testcases/solana/using_functions.sol @@ -0,0 +1,31 @@ +contract C { + function foo(int256 a) internal pure returns (int256) { + return a; + } +} + +library L { + function bar(int256 a) internal pure returns (int256) { + return a; + } +} + +library Lib { + function baz(int256 a, bool b) internal pure returns (int256) { + if (b) { + return 1; + } else { + return a; + } + } + using {L.bar, baz} for int256; +} + +library Lib2 { + using {L.foo.bar, C.foo} for int256; +} + +// ---- Expect: diagnostics ---- +// error: 25:15-18: 'foo' not found +// error: 25:20-25: 'C.foo' is not a library function +// note 2:2-55: definition of C.foo diff --git a/tests/evm.rs b/tests/evm.rs index c3e114c95..6027ac301 100644 --- a/tests/evm.rs +++ b/tests/evm.rs @@ -249,7 +249,7 @@ fn ethereum_solidity_tests() { }) .sum(); - assert_eq!(errors, 1024); + assert_eq!(errors, 1018); } fn set_file_contents(source: &str, path: &Path) -> (FileResolver, Vec) { diff --git a/tests/polkadot_tests/libraries.rs b/tests/polkadot_tests/libraries.rs index fabff28f2..282283084 100644 --- a/tests/polkadot_tests/libraries.rs +++ b/tests/polkadot_tests/libraries.rs @@ -74,7 +74,7 @@ fn using() { let mut runtime = build_solidity( r##" contract test { - using ints for uint32; + using {ints.max} for uint32; function foo(uint32 x) public pure returns (uint64) { // x is 32 bit but the max function takes 64 bit uint return x.max(65536); diff --git a/tests/solana_tests/using.rs b/tests/solana_tests/using.rs index cf93da594..00308d73a 100644 --- a/tests/solana_tests/using.rs +++ b/tests/solana_tests/using.rs @@ -255,3 +255,130 @@ contract C { assert_eq!(res, BorshToken::Bool(true)); } + +#[test] +fn using_function_for_struct() { + let mut vm = build_solidity( + r#" +struct Pet { + string name; + uint8 age; +} + +library Info { + function isCat(Pet memory myPet) public pure returns (bool) { + return myPet.name == "cat"; + } + + function setAge(Pet memory myPet, uint8 age) pure public { + myPet.age = age; + } +} + +contract C { + using {Info.isCat, Info.setAge} for Pet; + + function testPet(string memory name, uint8 age) pure public returns (bool) { + Pet memory my_pet = Pet(name, age); + return my_pet.isCat(); + } + + function changeAge(Pet memory myPet) public pure returns (Pet memory) { + myPet.setAge(5); + return myPet; + } + +} + "#, + ); + + let data_account = vm.initialize_data_account(); + + vm.function("new") + .accounts(vec![("dataAccount", data_account)]) + .call(); + + let res = vm + .function("testPet") + .arguments(&[ + BorshToken::String("cat".to_string()), + BorshToken::Uint { + width: 8, + value: BigInt::from(2u8), + }, + ]) + .call() + .unwrap(); + + assert_eq!(res, BorshToken::Bool(true)); + + let res = vm + .function("changeAge") + .arguments(&[BorshToken::Tuple(vec![ + BorshToken::String("cat".to_string()), + BorshToken::Uint { + width: 8, + value: BigInt::from(2u8), + }, + ])]) + .call() + .unwrap(); + + assert_eq!( + res, + BorshToken::Tuple(vec![ + BorshToken::String("cat".to_string()), + BorshToken::Uint { + width: 8, + value: BigInt::from(5u8), + } + ]) + ); +} + +#[test] +fn using_function_overload() { + let mut vm = build_solidity( + r#" + library LibInLib { + function get0(bytes x) public pure returns (bytes1) { + return x[0]; + } + + function get1(bytes x) public pure returns (bytes1) { + return x[1]; + } + } + + library MyBytes { + using {LibInLib.get0, LibInLib.get1} for bytes; + + function push(bytes memory b, uint8[] memory a) pure public returns (bool) { + return b.get0() == a[0] && b.get1()== a[1]; + } + } + + contract C { + using {MyBytes.push} for bytes; + + function check() public pure returns (bool) { + bytes memory b; + b.push(1); + b.push(2); + uint8[] memory vec = new uint8[](2); + vec[0] = 1; + vec[1] = 2; + return b.push(vec); + } + }"#, + ); + + let data_account = vm.initialize_data_account(); + vm.function("new") + .accounts(vec![("dataAccount", data_account)]) + .call(); + + let res = vm.function("check").call().unwrap(); + + assert_eq!(res, BorshToken::Bool(true)); +}