Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[ux]: include imported structs in interface output #4362

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 117 additions & 2 deletions tests/functional/codegen/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,122 @@ def foo(s: MyStruct) -> MyStruct:
out = compile_code(code, contract_path="code.vy", output_formats=["interface"])["interface"]

assert "# Structs" in out
assert "struct MyStruct:" in out
assert "MyStruct:" in out
assert "b: uint256" in out
assert "struct Voter:" in out
assert "Voter:" in out
assert "voted: bool" in out


def test_interface_with_imported_structures(make_input_bundle):
a = """
import b

struct Foo:
val:uint256
"""
b = """
import c

struct Bar:
val:uint256
"""
c = """
struct Baz:
val:uint256
"""
input_bundle = make_input_bundle({"a.vy": a, "b.vy": b, "c.vy": c})
out = compile_code(
a, input_bundle=input_bundle, contract_path="a.vy", output_formats=["interface"]
)["interface"]

assert "# Structs" in out
assert "Foo:" in out
assert "b Bar:" in out
assert "b.c Baz:" in out


def test_interface_with_doubly_imported_structure(make_input_bundle):
a = """
import b
import c

struct Foo:
val:uint256
"""
b = """
import d

struct Bar:
val:uint256
"""
c = """
import d
struct Baz:
val:uint256
"""
d = """
struct Boo:
val:uint256
"""

input_bundle = make_input_bundle({"a.vy": a, "b.vy": b, "c.vy": c, "d.vy": d})
out = compile_code(
a, input_bundle=input_bundle, contract_path="a.vy", output_formats=["interface"]
)["interface"]

assert "# Structs" in out
assert "Foo:" in out
assert "b Bar:" in out
assert "c Baz" in out
assert out.count("Boo") == 1


def test_interface_with_imported_struct_via_interface(make_input_bundle):
a = """
import b

struct Foo:
val:uint256
"""
b = """
struct Bar:
val:uint256

"""

input_bundle = make_input_bundle({"a.vy": a, "b.vyi": b})
out = compile_code(
a, input_bundle=input_bundle, contract_path="a.vy", output_formats=["interface"]
)["interface"]
print(out)
assert "# Structs" in out
assert "Foo:" in out
assert "b Bar:" in out


def test_interface_with_imported_structs_via_interface(make_input_bundle):
a = """
import b
import c

struct Foo:
val:uint256
"""
b = """
struct Bar:
val:uint256
"""
c = """
struct Baz:
val:uint256
"""

input_bundle = make_input_bundle({"a.vy": a, "b.vyi": b, "c.vy": c})
out = compile_code(
a, input_bundle=input_bundle, contract_path="a.vy", output_formats=["interface"]
)["interface"]
print(out)
assert "# Structs" in out
assert "Foo:" in out
assert "b Bar:" in out
assert "c Baz:" in out
31 changes: 26 additions & 5 deletions vyper/compiler/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from vyper.ir import compile_ir
from vyper.semantics.analysis.base import ModuleInfo
from vyper.semantics.types.function import FunctionVisibility, StateMutability
from vyper.semantics.types.module import InterfaceT
from vyper.semantics.types.module import InterfaceT, ModuleT, StructT
from vyper.typing import StorageLayout
from vyper.utils import vyper_warn
from vyper.warnings import ContractSizeLimitWarning
Expand Down Expand Up @@ -125,13 +125,15 @@ def build_external_interface_output(compiler_data: CompilerData) -> str:


def build_interface_output(compiler_data: CompilerData) -> str:
interface = compiler_data.annotated_vyper_module._metadata["type"].interface
module_t = compiler_data.annotated_vyper_module._metadata["type"]
interface = module_t.interface
out = ""

if len(interface.structs) > 0:
structs = _get_structs(module_t)
if len(structs) > 0:
out += "# Structs\n\n"
for struct in interface.structs.values():
out += f"struct {struct.name}:\n"
for prefix, struct in structs:
out += f"struct " f"struct {prefix[1:] + ' ' if prefix else ''}{struct.name}:\n"
for member_name, member_type in struct.members.items():
out += f" {member_name}: {member_type}\n"
out += "\n\n"
Expand Down Expand Up @@ -159,6 +161,25 @@ def build_interface_output(compiler_data: CompilerData) -> str:
return out


def _get_structs(m: ModuleT, prefix="", visited: set[ModuleT] = None) -> list:
visited = visited or set()
if m in visited:
return []
visited.add(m)

structs = [(prefix, val) for val in m.interface.structs.values()]

for alias, interface in m.interfaces.items():
structs += [(prefix + "." + alias, val) for val in interface.structs.values()]

for val in m.imported_modules.values():
structs += _get_structs(
val.module_node._metadata["type"], prefix + "." + val.alias, visited
)

return structs


def build_bb_output(compiler_data: CompilerData) -> IRnode:
return compiler_data.venom_functions[0]

Expand Down