diff --git a/setup.cfg b/setup.cfg index db6e196..2540c19 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,6 +18,7 @@ packages=find: python_requires = >=3.10 install_requires = ethereum-types>=0.2.1,<0.3 + typing-extensions>=4.12.2 [options.packages.find] where=src diff --git a/src/ethereum_rlp/__init__.py b/src/ethereum_rlp/__init__.py index 02bb945..4bcf630 100644 --- a/src/ethereum_rlp/__init__.py +++ b/src/ethereum_rlp/__init__.py @@ -4,4 +4,4 @@ from .rlp import RLP, Extended, Simple, decode, decode_to, encode # noqa: F401 -__version__ = "0.1.0" +__version__ = "0.1.1" diff --git a/src/ethereum_rlp/exceptions.py b/src/ethereum_rlp/exceptions.py index ba3f61a..5349ddf 100644 --- a/src/ethereum_rlp/exceptions.py +++ b/src/ethereum_rlp/exceptions.py @@ -2,6 +2,8 @@ Exceptions that can be thrown while serializing/deserializing RLP. """ +from typing_extensions import override + class RLPException(Exception): """ @@ -14,6 +16,19 @@ class DecodingError(RLPException): Indicates that RLP decoding failed. """ + @override + def __str__(self) -> str: + message = [super().__str__()] + current: BaseException = self + while isinstance(current, DecodingError) and current.__cause__: + current = current.__cause__ + if isinstance(current, DecodingError): + as_str = super(DecodingError, current).__str__() + else: + as_str = str(current) + message.append(f"\tbecause {as_str}") + return "\n".join(message) + class EncodingError(RLPException): """ diff --git a/src/ethereum_rlp/rlp.py b/src/ethereum_rlp/rlp.py index e711605..0b7c64e 100644 --- a/src/ethereum_rlp/rlp.py +++ b/src/ethereum_rlp/rlp.py @@ -152,7 +152,10 @@ def decode_to(cls: Type[U], encoded_data: Bytes) -> U: a `Bytes` subclass, a dataclass, `Uint`, `U256` or `Tuple[cls]`. """ decoded = decode(encoded_data) - return _deserialize_to(cls, decoded) + try: + return _deserialize_to(cls, decoded) + except Exception as e: + raise DecodingError(f"cannot decode into `{cls.__name__}`") from e @overload @@ -200,7 +203,11 @@ def _deserialize_to_dataclass(cls: Type[U], decoded: Simple) -> U: for value, target_field in zip(decoded, target_fields): resolved_type = hints[target_field.name] - values[target_field.name] = _deserialize_to(resolved_type, value) + try: + values[target_field.name] = _deserialize_to(resolved_type, value) + except Exception as e: + msg = f"cannot decode field `{cls.__name__}.{target_field.name}`" + raise DecodingError(msg) from e result = cls(**values) assert isinstance(result, cls) @@ -286,8 +293,13 @@ def _deserialize_to_tuple( arguments = list(arguments) + [arguments[-1]] * fill_count decoded = [] - for argument, value in zip(arguments, values): - decoded.append(_deserialize_to(argument, value)) + for index, (argument, value) in enumerate(zip(arguments, values)): + try: + deserialized = _deserialize_to(argument, value) + except Exception as e: + msg = f"cannot decode tuple element {index} of type `{argument}`" + raise DecodingError(msg) from e + decoded.append(deserialized) return tuple(decoded) @@ -298,7 +310,15 @@ def _deserialize_to_list( if isinstance(values, bytes): raise DecodingError("invalid list") argument = get_args(annotation)[0] - return [_deserialize_to(argument, v) for v in values] + results = [] + for index, value in enumerate(values): + try: + deserialized = _deserialize_to(argument, value) + except Exception as e: + msg = f"cannot decode list item {index} of type `{annotation}`" + raise DecodingError(msg) from e + results.append(deserialized) + return results def decode_to_bytes(encoded_bytes: Bytes) -> Bytes: diff --git a/tests/test_rlp.py b/tests/test_rlp.py index 7e9a262..4712934 100644 --- a/tests/test_rlp.py +++ b/tests/test_rlp.py @@ -300,7 +300,7 @@ class WithInt: def test_decode_to__int() -> None: - with pytest.raises(NotImplementedError): + with pytest.raises(DecodingError): rlp.decode_to(WithInt, b"\xc1\x00") @@ -414,13 +414,13 @@ class WithNonRlp: def test_decode_to__annotation_non_rlp() -> None: - with pytest.raises(NotImplementedError, match="RLP non-type"): + with pytest.raises(DecodingError, match="RLP non-type"): rlp.decode_to(WithNonRlp, b"\xc2\xc1\x01") @dataclass class WithList: - items: List[Uint] + items: List[Union[Bytes1, Bytes4]] def test_decode_to__list_bytes() -> None: @@ -428,6 +428,11 @@ def test_decode_to__list_bytes() -> None: rlp.decode_to(WithList, b"\xc1\x80") +def test_decode_to__list_invalid_union() -> None: + with pytest.raises(DecodingError, match="list item 0"): + rlp.decode_to(WithList, b"\xc2\xc1\xc0") + + # # Testing uint decoding #