diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 8887bf07cb..9eb13a44d2 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -770,7 +770,122 @@ def foo(s: MyStruct) -> MyStruct: out = compile_code(code, contract_path="code.vy", output_formats=["interface"])["interface"] assert "# Structs" in out - assert "struct MyStruct:" in out + assert "MyStruct:" in out assert "b: uint256" in out - assert "struct Voter:" in out + assert "Voter:" in out assert "voted: bool" in out + + +def test_interface_with_imported_structures(make_input_bundle): + a = """ +import b + +struct Foo: + val:uint256 + """ + b = """ +import c + +struct Bar: + val:uint256 + """ + c = """ +struct Baz: + val:uint256 + """ + input_bundle = make_input_bundle({"a.vy": a, "b.vy": b, "c.vy": c}) + out = compile_code( + a, input_bundle=input_bundle, contract_path="a.vy", output_formats=["interface"] + )["interface"] + + assert "# Structs" in out + assert "Foo:" in out + assert "b Bar:" in out + assert "b.c Baz:" in out + + +def test_interface_with_doubly_imported_structure(make_input_bundle): + a = """ +import b +import c + +struct Foo: + val:uint256 + """ + b = """ +import d + +struct Bar: + val:uint256 + """ + c = """ +import d +struct Baz: + val:uint256 + """ + d = """ +struct Boo: + val:uint256 + """ + + input_bundle = make_input_bundle({"a.vy": a, "b.vy": b, "c.vy": c, "d.vy": d}) + out = compile_code( + a, input_bundle=input_bundle, contract_path="a.vy", output_formats=["interface"] + )["interface"] + + assert "# Structs" in out + assert "Foo:" in out + assert "b Bar:" in out + assert "c Baz" in out + assert out.count("Boo") == 1 + + +def test_interface_with_imported_struct_via_interface(make_input_bundle): + a = """ +import b + +struct Foo: + val:uint256 + """ + b = """ +struct Bar: + val:uint256 + + """ + + input_bundle = make_input_bundle({"a.vy": a, "b.vyi": b}) + out = compile_code( + a, input_bundle=input_bundle, contract_path="a.vy", output_formats=["interface"] + )["interface"] + print(out) + assert "# Structs" in out + assert "Foo:" in out + assert "b Bar:" in out + + +def test_interface_with_imported_structs_via_interface(make_input_bundle): + a = """ +import b +import c + +struct Foo: + val:uint256 + """ + b = """ +struct Bar: + val:uint256 + """ + c = """ +struct Baz: + val:uint256 + """ + + input_bundle = make_input_bundle({"a.vy": a, "b.vyi": b, "c.vy": c}) + out = compile_code( + a, input_bundle=input_bundle, contract_path="a.vy", output_formats=["interface"] + )["interface"] + print(out) + assert "# Structs" in out + assert "Foo:" in out + assert "b Bar:" in out + assert "c Baz:" in out diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index f5f99a0bc3..b018639fd2 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -14,7 +14,7 @@ from vyper.ir import compile_ir from vyper.semantics.analysis.base import ModuleInfo from vyper.semantics.types.function import FunctionVisibility, StateMutability -from vyper.semantics.types.module import InterfaceT +from vyper.semantics.types.module import InterfaceT, ModuleT, StructT from vyper.typing import StorageLayout from vyper.utils import vyper_warn from vyper.warnings import ContractSizeLimitWarning @@ -125,13 +125,15 @@ def build_external_interface_output(compiler_data: CompilerData) -> str: def build_interface_output(compiler_data: CompilerData) -> str: - interface = compiler_data.annotated_vyper_module._metadata["type"].interface + module_t = compiler_data.annotated_vyper_module._metadata["type"] + interface = module_t.interface out = "" - if len(interface.structs) > 0: + structs = _get_structs(module_t) + if len(structs) > 0: out += "# Structs\n\n" - for struct in interface.structs.values(): - out += f"struct {struct.name}:\n" + for prefix, struct in structs: + out += f"struct " f"struct {prefix[1:] + ' ' if prefix else ''}{struct.name}:\n" for member_name, member_type in struct.members.items(): out += f" {member_name}: {member_type}\n" out += "\n\n" @@ -159,6 +161,25 @@ def build_interface_output(compiler_data: CompilerData) -> str: return out +def _get_structs(m: ModuleT, prefix="", visited: set[ModuleT] = None) -> list: + visited = visited or set() + if m in visited: + return [] + visited.add(m) + + structs = [(prefix, val) for val in m.interface.structs.values()] + + for alias, interface in m.interfaces.items(): + structs += [(prefix + "." + alias, val) for val in interface.structs.values()] + + for val in m.imported_modules.values(): + structs += _get_structs( + val.module_node._metadata["type"], prefix + "." + val.alias, visited + ) + + return structs + + def build_bb_output(compiler_data: CompilerData) -> IRnode: return compiler_data.venom_functions[0]