diff --git a/.flake8 b/.flake8 index 3a29203a..ca9036c7 100644 --- a/.flake8 +++ b/.flake8 @@ -1,7 +1,7 @@ ; Copyright (C) 2023 Antmicro ; SPDX-License-Identifier: Apache-2.0 [flake8] -ignore = E203, E266, E501, W503, F403, F401 +ignore = E203, E266, E501, W503, F403, F401, F405 max-line-length = 100 max-complexity = 27 select = B,C,E,F,W,T4,B9 diff --git a/docs/source/elaboratable_wrapper.md b/docs/source/elaboratable_wrapper.md new file mode 100644 index 00000000..4033430b --- /dev/null +++ b/docs/source/elaboratable_wrapper.md @@ -0,0 +1,12 @@ +# ElaboratableWrapper class + +{class}`ElaboratableWrapper` encapsulates an Amaranth's Elaboratable and exposes an interface compatible with other wrappers which allows making connections with them. +Supplied elaboratable must contain a `signature` property and a conforming interface as specified by [Amaranth docs](https://amaranth-lang.org/rfcs/0002-interfaces.html). +Ports' directionality, their names and widths are inferred from it. + +```{eval-rst} +.. autoclass:: fpga_topwrap.elaboratable_wrapper.ElaboratableWrapper + :members: + + .. automethod:: __init__ +``` diff --git a/docs/source/index.md b/docs/source/index.md index fc9ced55..1523a538 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -11,6 +11,7 @@ pipeline_manager wrapper_port ipwrapper ipconnect +elaboratable_wrapper helpers fusesoc ``` diff --git a/fpga_topwrap/elaboratable_wrapper.py b/fpga_topwrap/elaboratable_wrapper.py new file mode 100644 index 00000000..92fedbb7 --- /dev/null +++ b/fpga_topwrap/elaboratable_wrapper.py @@ -0,0 +1,174 @@ +# Copyright (C) 2023 Antmicro +# SPDX-License-Identifier: Apache-2.0 + +from functools import cache +from typing import Iterable, Mapping, Union + +from amaranth import * +from amaranth.build import Platform +from amaranth.hdl.ast import Assign, Shape +from amaranth.lib import wiring + +from .amaranth_helpers import DIR_IN, DIR_OUT, WrapperPort +from .wrapper import Wrapper + +SignalMapping = Mapping[str, Union[Signal, "SignalMapping"]] +InterfaceLike = Union[wiring.PureInterface, Elaboratable] + + +class ElaboratableWrapper(Wrapper): + """Allows connecting an Amaranth's Elaboratable with other + classes derived from Wrapper. + """ + + def __init__(self, name: str, elaboratable: Elaboratable) -> None: + """ + :param name: name of this wrapper + :param elaboratable: Amaranth's Elaboratable object to wrap + """ + super().__init__(name) + self.elaboratable = elaboratable + self.clk = self._cached_wrapper( + port_width=1, port_flow=wiring.In, name="clk", port_name="clk", iface_name="" + ) + self.rst = self._cached_wrapper( + port_width=1, port_flow=wiring.In, name="rst", port_name="rst", iface_name="" + ) + + def get_ports(self) -> list[WrapperPort]: + """Return a list of external ports.""" + return self._flatten_hier(self.get_ports_hier()) + + def get_ports_hier(self) -> SignalMapping: + """Maps elaboratable's Signature to a nested dictionary of WrapperPorts. + See _gather_signature_ports for more details. + """ + return self._gather_signature_ports(self.elaboratable.signature) | { + "clk": self.clk, + "rst": self.rst, + } + + @cache + def _cached_wrapper( + self, port_width: int, port_flow: wiring.Flow, name: str, port_name: str, iface_name: str + ) -> WrapperPort: + """Constructs a WrapperPort, but only one instance per set of parameters in + a module is ever created. Multiple calls to this function with the identical + parameters return the same object. + + :param port_width: width of the port + :param port_flow: directionality of the port, one of: wiring.In, wiring.Out + :param name: name of the port + :param port_name: original port name as it appears in the signature + :param iface_name: name of the interface the ports belongs to + """ + return WrapperPort( + bounds=[port_width - 1, 0, port_width - 1, 0], + name=name, + internal_name=port_name, + interface_name=iface_name, + direction=DIR_IN if port_flow == wiring.In else DIR_OUT, + ) + + def _gather_signature_ports( + self, signature: wiring.Signature, prefix: str = "" + ) -> SignalMapping: + """Maps a signature to a nested dictionary of WrapperPorts. + For example, an elaboratable with this signature: + + Signature({ + "data": Out(Signature({ + "payload": Out(7), + "chksum": Out(1) + })), + "ready": In(1), + "valid": Out(1) + }) + + Translates to this dictionary structure (some details omitted for clarity): + + { + "data": { + "payload": WrapperPort( + bounds=[6, 0, 6, 0], + name="data_payload", + internal_name="payload", + interface_name="data", + direction=DIR_OUT + ), + "chksum": WrapperPort(...) + }, + "ready": WrapperPort( + bounds=[0, 0, 0, 0], + name="ready", + internal_name="ready", + interface_name="", + direction=DIR_IN + ), + "valid": WrapperPort(...) + } + + :param signature: Amaranth's Signature to map to a dictionary + :param prefix: optional interface prefix to prepend to the name of all ports + """ + iface = {} + for port_name, port in signature.members.items(): + name = f"{prefix}_{port_name}" if prefix else port_name + if port.is_signature: + inner_iface = self._gather_signature_ports(port.signature, prefix=name) + iface[port_name] = inner_iface + else: + iface[port_name] = self._cached_wrapper( + Shape.cast(port.shape).width, port.flow, name, port_name, prefix + ) + return iface + + def _flatten_hier(self, hier: SignalMapping) -> Iterable[Signal]: + """Flattens a nested dictionary with WrapperPorts. + + :param hier: a (nested) dictionary of WrapperPorts + """ + ports = [] + try: + for _, port in hier.items(): + ports += self._flatten_hier(port) + except AttributeError: + ports += [hier] + return ports + + def _connect_ports(self, ports: SignalMapping, iface: InterfaceLike) -> list[Assign]: + """Returns a list of amaranth assignments between the wrapped elaboratable and external ports. + + :param ports: nested dictionary of WrapperPorts mirroring that of iface's signature + :param iface: Amaranth Interface to make connections with + """ + conns = [] + for port_name, port in iface.signature.members.items(): + iface_port = getattr(iface, port_name) + if port.is_signature: + conns += self._connect_ports(ports[port_name], iface_port) + else: + if port.flow == wiring.In: + conns.append(iface_port.eq(ports[port_name])) + elif port.flow == wiring.Out: + conns.append(ports[port_name].eq(iface_port)) + else: + raise TypeError(f"Invalid InOut flow direction in signal '{port_name}'") + return conns + + def elaborate(self, platform: Platform) -> Module: + m = Module() + + # create an internal clock domain that doesn't propagate upwards in the submodule + # tree and assign clk and rst specified by the user to the internal domain signals + cd = ClockDomain(self.name, local=True) + m.d.comb += ClockSignal(self.name).eq(self.clk) + m.d.comb += ResetSignal(self.name).eq(self.rst) + m.domains += cd + + # make the elaboratable use the new clock domain internally + m.submodules += DomainRenamer(self.name)(self.elaboratable) + + m.d.comb += self._connect_ports(self.get_ports_hier(), self.elaboratable) + + return m diff --git a/fpga_topwrap/wrapper.py b/fpga_topwrap/wrapper.py new file mode 100644 index 00000000..e9dc00b3 --- /dev/null +++ b/fpga_topwrap/wrapper.py @@ -0,0 +1,48 @@ +# Copyright (C) 2021 Antmicro +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + +from amaranth import * + +from .amaranth_helpers import WrapperPort + + +class Wrapper(Elaboratable): + """Base class for modules that want to connect to each other. + + Derived classes must implement get_ports method that returns + a list of WrapperPort's - external ports of a class that can + be used as endpoints for connections. + """ + + def __init__(self, name: str) -> None: + self.name = name + + @property + def get_ports(self) -> List[WrapperPort]: + """Return a list of external ports.""" + raise NotImplementedError('Derived classes must implement "get_ports" method') + + def get_port_by_name(self, name: str) -> WrapperPort: + """Given port's name, return the port as WrapperPort object. + + :raises ValueError: If such port doesn't exist. + """ + try: + port = {signal.name: signal for signal in self.get_ports()}[name] + except KeyError: + raise ValueError(f"Port named '{name}' couldn't be found in the hierarchy: {self.name}") + return port + + def get_ports_of_interface(self, iface_name: str) -> List[WrapperPort]: + """Return a list of ports of specific interface. + + :raises ValueError: if such interface doesn't exist. + """ + ports = [ + port for port in filter(lambda x: x.interface_name == iface_name, self.get_ports()) + ] + if not ports: + raise ValueError(f"No ports could be found for this interface name: {iface_name}") + return ports diff --git a/tests/tests_build/test_elaboratable_wrapper.py b/tests/tests_build/test_elaboratable_wrapper.py new file mode 100644 index 00000000..f43873c4 --- /dev/null +++ b/tests/tests_build/test_elaboratable_wrapper.py @@ -0,0 +1,336 @@ +# amaranth: UnusedElaboratable=no + +# Copyright (C) 2023 Antmicro +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, Union + +import pytest +from amaranth import * +from amaranth.hdl.ast import Assign +from amaranth.lib.wiring import Component, In, Out, Signature + +from fpga_topwrap.amaranth_helpers import DIR_IN, DIR_OUT, WrapperPort +from fpga_topwrap.elaboratable_wrapper import ElaboratableWrapper, SignalMapping + + +@pytest.fixture +def elaboratable_name() -> str: + return "test_elaboratable" + + +@pytest.fixture +def stream_signature() -> Signature: + return Signature( + { + "data": Out(16), + "valid": Out(1), + "ready": In(1), + } + ) + + +@pytest.fixture +def region_signature() -> Signature: + return Signature( + { + "start": In(32), + "end": In(32), + } + ) + + +@pytest.fixture +def nested_signature(stream_signature: Signature, region_signature: Signature) -> Signature: + return Signature( + { + "o": Out(stream_signature), + "i": In(Signature({"meta": Out(4), "ok": In(1), "region": Out(region_signature)})), + } + ) + + +@pytest.fixture +def nested_signature_mapping() -> SignalMapping: + # Note: if parent Signature is marked as "In" then direction of its children + # is flipped. This is consistent with Amaranth's convention, see: + # https://amaranth-lang.org/rfcs/0002-interfaces.html#guide-level-explanation + return { + "o": { + "data": WrapperPort( + bounds=[15, 0, 15, 0], + name="o_data", + internal_name="data", + interface_name="o", + direction=DIR_OUT, + ), + "valid": WrapperPort( + bounds=[0, 0, 0, 0], + name="o_valid", + internal_name="valid", + interface_name="o", + direction=DIR_OUT, + ), + "ready": WrapperPort( + bounds=[0, 0, 0, 0], + name="o_ready", + internal_name="ready", + interface_name="o", + direction=DIR_IN, + ), + }, + "i": { + "meta": WrapperPort( + bounds=[3, 0, 3, 0], + name="i_meta", + internal_name="meta", + interface_name="i", + direction=DIR_IN, + ), + "ok": WrapperPort( + bounds=[0, 0, 0, 0], + name="i_ok", + internal_name="ok", + interface_name="i", + direction=DIR_OUT, + ), + "region": { + "start": WrapperPort( + bounds=[31, 0, 31, 0], + name="i_region_start", + internal_name="start", + interface_name="i_region", + direction=DIR_OUT, + ), + "end": WrapperPort( + bounds=[31, 0, 31, 0], + name="i_region_end", + internal_name="end", + interface_name="i_region", + direction=DIR_OUT, + ), + }, + }, + } + + +@pytest.fixture +def flattened_nested_signature_mapping() -> list[WrapperPort]: + return [ + WrapperPort( + bounds=[15, 0, 15, 0], + name="o_data", + internal_name="data", + interface_name="o", + direction=DIR_OUT, + ), + WrapperPort( + bounds=[0, 0, 0, 0], + name="o_valid", + internal_name="valid", + interface_name="o", + direction=DIR_OUT, + ), + WrapperPort( + bounds=[0, 0, 0, 0], + name="o_ready", + internal_name="ready", + interface_name="o", + direction=DIR_IN, + ), + WrapperPort( + bounds=[3, 0, 3, 0], + name="i_meta", + internal_name="meta", + interface_name="i", + direction=DIR_IN, + ), + WrapperPort( + bounds=[0, 0, 0, 0], + name="i_ok", + internal_name="ok", + interface_name="i", + direction=DIR_OUT, + ), + WrapperPort( + bounds=[31, 0, 31, 0], + name="i_region_start", + internal_name="start", + interface_name="i_region", + direction=DIR_OUT, + ), + WrapperPort( + bounds=[31, 0, 31, 0], + name="i_region_end", + internal_name="end", + interface_name="i_region", + direction=DIR_OUT, + ), + ] + + +@pytest.fixture +def elaboratable(nested_signature: Signature) -> Elaboratable: + class TestModule(Component): + def __init__(self): + super().__init__(signature=nested_signature) + + def elaborate(self): + m = Module() + return m + + return TestModule() + + +@pytest.fixture +def elaboratable_wrapper(elaboratable_name: str, elaboratable: Elaboratable) -> ElaboratableWrapper: + return ElaboratableWrapper(elaboratable_name, elaboratable) + + +@pytest.fixture +def make_cached_wrapper_port1(elaboratable_wrapper: ElaboratableWrapper) -> WrapperPort: + return lambda: elaboratable_wrapper._cached_wrapper( + port_width=12, + port_flow=In, + name="wrapper_port1", + port_name="port1", + iface_name="sample_iface", + ) + + +@pytest.fixture +def make_cached_wrapper_port2(elaboratable_wrapper: ElaboratableWrapper) -> WrapperPort: + return lambda: elaboratable_wrapper._cached_wrapper( + port_width=13, + port_flow=In, + name="wrapper_port2", + port_name="port2", + iface_name="sample_iface", + ) + + +@pytest.fixture +def clock_domain_signals(elaboratable_wrapper: ElaboratableWrapper) -> SignalMapping: + return {"clk": elaboratable_wrapper.clk, "rst": elaboratable_wrapper.rst} + + +@pytest.fixture +def interface_connections( + elaboratable: Elaboratable, nested_signature_mapping: SignalMapping +) -> list[tuple[Signal, Signal]]: + m = elaboratable + d = nested_signature_mapping + return [ + (d["o"]["data"], m.o.data), + (d["o"]["valid"], m.o.valid), + (m.o.ready, d["o"]["ready"]), + (m.i.meta, d["i"]["meta"]), + (d["i"]["ok"], m.i.ok), + (d["i"]["region"]["start"], m.i.region.start), + (d["i"]["region"]["end"], m.i.region.end), + ] + + +def wrapper_port_eq(p1: WrapperPort, p2: WrapperPort) -> bool: + for attr in ["bounds", "name", "internal_name", "interface_name", "direction"]: + if getattr(p1, attr) != getattr(p2, attr): + return False + return True + + +def signal_mapping_elem_eq( + v1: Union[SignalMapping, WrapperPort], v2: Union[SignalMapping, WrapperPort] +) -> bool: + return (isinstance(v1, dict) and isinstance(v2, dict) and signal_mapping_eq(v1, v2)) or ( + isinstance(v1, WrapperPort) and isinstance(v2, WrapperPort) and wrapper_port_eq(v1, v2) + ) + + +def signal_mapping_eq(d1: SignalMapping, d2: SignalMapping) -> bool: + if d1.keys() != d2.keys(): + return False + + for k, v1 in d1.items(): + v2 = d2[k] + if not signal_mapping_elem_eq(v1, v2): + return False + return True + + +class TestElaboratableWrapper: + def test_gather_signature_ports( + self, elaboratable_wrapper: ElaboratableWrapper, nested_signature_mapping: SignalMapping + ) -> None: + sig = elaboratable_wrapper.elaboratable.signature + sig_dict = elaboratable_wrapper._gather_signature_ports(sig) + assert signal_mapping_eq(sig_dict, nested_signature_mapping) + + def test_flatten_hier( + self, + elaboratable_wrapper: ElaboratableWrapper, + nested_signature_mapping: SignalMapping, + flattened_nested_signature_mapping: list[WrapperPort], + ) -> None: + def ordering(p): + return p.name + + flattened_hier = sorted( + elaboratable_wrapper._flatten_hier(nested_signature_mapping), key=ordering + ) + expected_hier = sorted(flattened_nested_signature_mapping, key=ordering) + for port_test, port_expect in zip(flattened_hier, expected_hier): + assert wrapper_port_eq(port_expect, port_test) + + def test_cached_wrapper( + self, + make_cached_wrapper_port1: Callable[[None], WrapperPort], + make_cached_wrapper_port2: Callable[[None], WrapperPort], + ) -> None: + port1 = make_cached_wrapper_port1() + port1_another = make_cached_wrapper_port1() + port2 = make_cached_wrapper_port2() + assert port1 is port1_another # pointer equality + assert not wrapper_port_eq(port1, port2) # structural inequality + + def test_get_ports_hier( + self, + elaboratable_wrapper: ElaboratableWrapper, + nested_signature_mapping: SignalMapping, + clock_domain_signals: SignalMapping, + ) -> None: + hier_ports = elaboratable_wrapper.get_ports_hier() + assert signal_mapping_eq(hier_ports, nested_signature_mapping | clock_domain_signals) + + def test_get_ports( + self, + elaboratable_wrapper: ElaboratableWrapper, + flattened_nested_signature_mapping: list[WrapperPort], + clock_domain_signals: SignalMapping, + ) -> None: + def ordering(p): + return p.name + + ports = sorted(elaboratable_wrapper.get_ports(), key=ordering) + expected_ports = sorted( + flattened_nested_signature_mapping + list(clock_domain_signals.values()), key=ordering + ) + for port_test, port_expect in zip(ports, expected_ports): + assert wrapper_port_eq(port_expect, port_test) + + def test_connect_ports( + self, + elaboratable_wrapper: ElaboratableWrapper, + nested_signature_mapping: SignalMapping, + elaboratable: Elaboratable, + interface_connections: list[tuple[Signal, Signal]], + ) -> None: + conns_test = sorted( + elaboratable_wrapper._connect_ports(nested_signature_mapping, elaboratable), + key=lambda conn: conn.lhs.name, + ) + conns_expect = sorted(interface_connections, key=lambda conn: conn[0].name) + for conn_test, conn_expect in zip(conns_test, conns_expect): + lhs, rhs = conn_expect + assert isinstance(conn_test, Assign) + assert lhs is conn_test.lhs # pointer equality + assert rhs is conn_test.rhs # pointer equality