Skip to content

Commit

Permalink
Allow using {func} for type to use library functions (#1528)
Browse files Browse the repository at this point in the history
Fixes #1525

Signed-off-by: Sean Young <[email protected]>
  • Loading branch information
seanyoung authored Sep 21, 2023
1 parent c2d0fd8 commit a74ab4d
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 13 deletions.
5 changes: 3 additions & 2 deletions src/sema/contracts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ impl ast::Contract {

/// Resolve the following contract
pub fn resolve(contracts: &[ContractDefinition], file_no: usize, ns: &mut ast::Namespace) {
resolve_using(contracts, file_no, ns);

// we need to resolve declarations first, so we call functions/constructors of
// contracts before they are declared
let mut delayed: ResolveLater = Default::default();
Expand All @@ -73,6 +71,9 @@ pub fn resolve(contracts: &[ContractDefinition], file_no: usize, ns: &mut ast::N
resolve_declarations(def, file_no, ns, &mut delayed);
}

// using may use functions declared in contracts
resolve_using(contracts, file_no, ns);

// Resolve base contract constructor arguments on contract definition (not constructor definitions)
resolve_base_args(contracts, file_no, ns);

Expand Down
7 changes: 4 additions & 3 deletions src/sema/expression/function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1324,7 +1324,6 @@ pub(super) fn method_call_pos_args(
resolve_to,
)? {
return Ok(resolved_call);
} else {
}

if let Some(resolved_call) = try_user_type(
Expand Down Expand Up @@ -1352,8 +1351,9 @@ pub(super) fn method_call_pos_args(
if let Some(mut path) = ns.expr_to_identifier_path(var) {
path.identifiers.push(func.clone());

if let Ok(list) = ns.resolve_free_function_with_namespace(
if let Ok(list) = ns.resolve_function_with_namespace(
context.file_no,
None,
&path,
&mut Diagnostics::default(),
) {
Expand Down Expand Up @@ -1640,8 +1640,9 @@ pub(super) fn method_call_named_args(
if let Some(mut path) = ns.expr_to_identifier_path(var) {
path.identifiers.push(func_name.clone());

if let Ok(list) = ns.resolve_free_function_with_namespace(
if let Ok(list) = ns.resolve_function_with_namespace(
context.file_no,
None,
&path,
&mut Diagnostics::default(),
) {
Expand Down
14 changes: 10 additions & 4 deletions src/sema/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,10 @@ impl Namespace {
}

/// Resolve a free function name with namespace
pub(super) fn resolve_free_function_with_namespace(
pub(super) fn resolve_function_with_namespace(
&mut self,
file_no: usize,
contract_no: Option<usize>,
name: &pt::IdentifierPath,
diagnostics: &mut Diagnostics,
) -> Result<Vec<(pt::Loc, usize)>, ()> {
Expand All @@ -376,12 +377,12 @@ impl Namespace {
.map(|(id, namespace)| (id, namespace.iter().collect()))
.unwrap();

let s = self.resolve_namespace(namespace, file_no, None, id, diagnostics)?;
let symbol = self.resolve_namespace(namespace, file_no, contract_no, id, diagnostics)?;

if let Some(Symbol::Function(list)) = s {
if let Some(Symbol::Function(list)) = symbol {
Ok(list.clone())
} else {
let error = Namespace::wrong_symbol(s, id);
let error = Namespace::wrong_symbol(symbol, id);

diagnostics.push(error);

Expand Down Expand Up @@ -1335,6 +1336,7 @@ impl Namespace {
));
return Err(());
};
namespace.clear();
Some(*n)
}
Some(Symbol::Function(_)) => {
Expand Down Expand Up @@ -1390,6 +1392,10 @@ impl Namespace {
};
}

if !namespace.is_empty() {
return Ok(None);
}

let mut s = self
.variable_symbols
.get(&(import_file_no, contract_no, id.name.to_owned()))
Expand Down
32 changes: 30 additions & 2 deletions src/sema/using.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ pub(crate) fn using_decl(

for using_function in functions {
let function_name = &using_function.path;
if let Ok(list) = ns.resolve_free_function_with_namespace(
if let Ok(list) = ns.resolve_function_with_namespace(
file_no,
contract_no,
&using_function.path,
&mut diagnostics,
) {
Expand All @@ -120,6 +121,18 @@ pub(crate) fn using_decl(

let func = &ns.functions[func_no];

if let Some(contract_no) = func.contract_no {
if !ns.contracts[contract_no].is_library() {
diagnostics.push(Diagnostic::error_with_note(
function_name.loc,
format!("'{function_name}' is not a library function"),
func.loc,
format!("definition of {}", using_function.path),
));
continue;
}
}

if func.params.is_empty() {
diagnostics.push(Diagnostic::error_with_note(
function_name.loc,
Expand Down Expand Up @@ -251,7 +264,22 @@ pub(crate) fn using_decl(
Some(oper)
} else {
if let Some(ty) = &ty {
if *ty != func.params[0].ty {
let dummy = Expression::Variable {
loc,
ty: ty.clone(),
var_no: 0,
};

if dummy
.cast(
&loc,
&func.params[0].ty,
true,
ns,
&mut Diagnostics::default(),
)
.is_err()
{
diagnostics.push(Diagnostic::error_with_note(
function_name.loc,
format!("function cannot be used since first argument is '{}' rather than the required '{}'", func.params[0].ty.to_string(ns), ty.to_string(ns)),
Expand Down
31 changes: 31 additions & 0 deletions tests/contract_testcases/solana/using_functions.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
contract C {
function foo(int256 a) internal pure returns (int256) {
return a;
}
}

library L {
function bar(int256 a) internal pure returns (int256) {
return a;
}
}

library Lib {
function baz(int256 a, bool b) internal pure returns (int256) {
if (b) {
return 1;
} else {
return a;
}
}
using {L.bar, baz} for int256;
}

library Lib2 {
using {L.foo.bar, C.foo} for int256;
}

// ---- Expect: diagnostics ----
// error: 25:15-18: 'foo' not found
// error: 25:20-25: 'C.foo' is not a library function
// note 2:2-55: definition of C.foo
2 changes: 1 addition & 1 deletion tests/evm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ fn ethereum_solidity_tests() {
})
.sum();

assert_eq!(errors, 1024);
assert_eq!(errors, 1018);
}

fn set_file_contents(source: &str, path: &Path) -> (FileResolver, Vec<String>) {
Expand Down
2 changes: 1 addition & 1 deletion tests/polkadot_tests/libraries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ fn using() {
let mut runtime = build_solidity(
r##"
contract test {
using ints for uint32;
using {ints.max} for uint32;
function foo(uint32 x) public pure returns (uint64) {
// x is 32 bit but the max function takes 64 bit uint
return x.max(65536);
Expand Down
127 changes: 127 additions & 0 deletions tests/solana_tests/using.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,130 @@ contract C {

assert_eq!(res, BorshToken::Bool(true));
}

#[test]
fn using_function_for_struct() {
let mut vm = build_solidity(
r#"
struct Pet {
string name;
uint8 age;
}
library Info {
function isCat(Pet memory myPet) public pure returns (bool) {
return myPet.name == "cat";
}
function setAge(Pet memory myPet, uint8 age) pure public {
myPet.age = age;
}
}
contract C {
using {Info.isCat, Info.setAge} for Pet;
function testPet(string memory name, uint8 age) pure public returns (bool) {
Pet memory my_pet = Pet(name, age);
return my_pet.isCat();
}
function changeAge(Pet memory myPet) public pure returns (Pet memory) {
myPet.setAge(5);
return myPet;
}
}
"#,
);

let data_account = vm.initialize_data_account();

vm.function("new")
.accounts(vec![("dataAccount", data_account)])
.call();

let res = vm
.function("testPet")
.arguments(&[
BorshToken::String("cat".to_string()),
BorshToken::Uint {
width: 8,
value: BigInt::from(2u8),
},
])
.call()
.unwrap();

assert_eq!(res, BorshToken::Bool(true));

let res = vm
.function("changeAge")
.arguments(&[BorshToken::Tuple(vec![
BorshToken::String("cat".to_string()),
BorshToken::Uint {
width: 8,
value: BigInt::from(2u8),
},
])])
.call()
.unwrap();

assert_eq!(
res,
BorshToken::Tuple(vec![
BorshToken::String("cat".to_string()),
BorshToken::Uint {
width: 8,
value: BigInt::from(5u8),
}
])
);
}

#[test]
fn using_function_overload() {
let mut vm = build_solidity(
r#"
library LibInLib {
function get0(bytes x) public pure returns (bytes1) {
return x[0];
}
function get1(bytes x) public pure returns (bytes1) {
return x[1];
}
}
library MyBytes {
using {LibInLib.get0, LibInLib.get1} for bytes;
function push(bytes memory b, uint8[] memory a) pure public returns (bool) {
return b.get0() == a[0] && b.get1()== a[1];
}
}
contract C {
using {MyBytes.push} for bytes;
function check() public pure returns (bool) {
bytes memory b;
b.push(1);
b.push(2);
uint8[] memory vec = new uint8[](2);
vec[0] = 1;
vec[1] = 2;
return b.push(vec);
}
}"#,
);

let data_account = vm.initialize_data_account();
vm.function("new")
.accounts(vec![("dataAccount", data_account)])
.call();

let res = vm.function("check").call().unwrap();

assert_eq!(res, BorshToken::Bool(true));
}

0 comments on commit a74ab4d

Please sign in to comment.