diff --git a/tests/functional/codegen/types/numbers/test_signed_ints.py b/tests/functional/codegen/types/numbers/test_signed_ints.py index 608783000a..0028a8ca56 100644 --- a/tests/functional/codegen/types/numbers/test_signed_ints.py +++ b/tests/functional/codegen/types/numbers/test_signed_ints.py @@ -13,7 +13,7 @@ ZeroDivisionException, ) from vyper.semantics.types import IntegerT -from vyper.utils import evm_div, evm_mod +from vyper.utils import evm_div, evm_mod, signed_to_unsigned types = sorted(IntegerT.signeds()) @@ -253,8 +253,9 @@ def num_sub() -> {typ}: @pytest.mark.parametrize("op", sorted(ARITHMETIC_OPS.keys())) @pytest.mark.parametrize("typ", types) +@pytest.mark.parametrize("is_hex_int", [True, False]) @pytest.mark.fuzzing -def test_arithmetic_thorough(get_contract, tx_failed, op, typ): +def test_arithmetic_thorough(get_contract, tx_failed, op, typ, is_hex_int): # both variables code_1 = f""" @external @@ -329,9 +330,17 @@ def foo() -> {typ}: ok = in_bounds and not div_by_zero - code_2 = code_2_template.format(typ=typ, op=op, y=y) - code_3 = code_3_template.format(typ=typ, op=op, x=x) - code_4 = code_4_template.format(typ=typ, op=op, x=x, y=y) + formatted_x = x + formatted_y = y + + if is_hex_int: + n_nibbles = typ.bits // 4 + formatted_x = "0x" + hex(signed_to_unsigned(x, typ.bits))[2:].rjust(n_nibbles, "0") + formatted_y = "0x" + hex(signed_to_unsigned(y, typ.bits))[2:].rjust(n_nibbles, "0") + + code_2 = code_2_template.format(typ=typ, op=op, y=formatted_y) + code_3 = code_3_template.format(typ=typ, op=op, x=formatted_x) + code_4 = code_4_template.format(typ=typ, op=op, x=formatted_x, y=formatted_y) if ok: assert c.foo(x, y) == expected diff --git a/tests/functional/codegen/types/numbers/test_unsigned_ints.py b/tests/functional/codegen/types/numbers/test_unsigned_ints.py index 2bd3184ec0..d49256d46b 100644 --- a/tests/functional/codegen/types/numbers/test_unsigned_ints.py +++ b/tests/functional/codegen/types/numbers/test_unsigned_ints.py @@ -210,7 +210,8 @@ def foo(x: {typ}, y: {typ}) -> bool: @pytest.mark.parametrize("typ", types) -def test_uint_literal(get_contract, typ): +@pytest.mark.parametrize("is_hex_int", [True, False]) +def test_uint_literal(get_contract, typ, is_hex_int): lo, hi = typ.ast_bounds good_cases = [0, 1, 2, 3, hi // 2 - 1, hi // 2, hi // 2 + 1, hi - 1, hi] @@ -223,10 +224,16 @@ def test() -> {typ}: """ for val in good_cases: - c = get_contract(code_template.format(typ=typ, val=val)) + input_val = val + if is_hex_int: + n_nibbles = typ.bits // 4 + input_val = "0x" + hex(val)[2:].rjust(n_nibbles, "0") + c = get_contract(code_template.format(typ=typ, val=input_val)) assert c.test() == val for val in bad_cases: + if is_hex_int: + return exc = ( TypeMismatch if SizeLimits.MIN_INT256 <= val <= SizeLimits.MAX_UINT256 diff --git a/tests/functional/syntax/test_list.py b/tests/functional/syntax/test_list.py index e55b060542..8bdecc527d 100644 --- a/tests/functional/syntax/test_list.py +++ b/tests/functional/syntax/test_list.py @@ -85,7 +85,7 @@ def foo(x: int128[2][2]): def foo(): self.bar = [1, 2, 0x1234567890123456789012345678901234567890] """, - InvalidLiteral, + TypeMismatch, ), ( """ @@ -309,6 +309,12 @@ def foo(): for i: DynArray[uint256, 3] in [[], []]: x = i """, + """ +bar: uint160[3] +@external +def foo(): + self.bar = [1, 2, 0x1234567890123456789012345678901234567890] + """, ] diff --git a/tests/functional/syntax/test_send.py b/tests/functional/syntax/test_send.py index ffad1b3792..c9ed69937f 100644 --- a/tests/functional/syntax/test_send.py +++ b/tests/functional/syntax/test_send.py @@ -22,14 +22,6 @@ def foo(): ), ( """ -@external -def foo(): - send(0x1234567890123456789012345678901234567890, 0x1234567890123456789012345678901234567890) - """, - TypeMismatch, - ), - ( - """ x: int128 @external @@ -161,6 +153,11 @@ def foo(): def foo(): send(0xde0B295669a9FD93d5F28D9Ec85E40f4cb697BAe, 5, gas=self.x) """, + """ +@external +def foo(): + send(0x1234567890123456789012345678901234567890, 0x1234567890123456789012345678901234567890) + """, ] diff --git a/tests/unit/semantics/analysis/test_array_index.py b/tests/unit/semantics/analysis/test_array_index.py index aa9a702be3..3ad31d66f8 100644 --- a/tests/unit/semantics/analysis/test_array_index.py +++ b/tests/unit/semantics/analysis/test_array_index.py @@ -25,7 +25,7 @@ def foo(b: {value}): analyze_module(vyper_module) -@pytest.mark.parametrize("value", ["1.0", "0.0", "'foo'", "0x00", "b'\x01'", "False"]) +@pytest.mark.parametrize("value", ["1.0", "0.0", "'foo'", "b'\x01'", "False"]) def test_invalid_literal(namespace, value): code = f""" diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 974685f403..daae397d4c 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -31,6 +31,7 @@ SizeLimits, annotate_source_code, evm_div, + hex_to_int, quantize, sha256sum, ) @@ -883,6 +884,13 @@ def bytes_value(self): """ return bytes.fromhex(self.value.removeprefix("0x")) + @property + def uint_value(self): + """ + This value as unsigned integer + """ + return hex_to_int(self.value) + class Str(Constant): __slots__ = () diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 783764271d..90497d1cc9 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -120,6 +120,8 @@ class Decimal(Num): ... class Hex(Num): @property def n_bytes(self): ... + @property + def uint_value(self): ... class Str(Constant): ... class Bytes(Constant): ... diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 62539872bc..228fbc8808 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -2363,13 +2363,23 @@ def infer_kwarg_types(self, node): for kwarg in node.keywords: kwarg_name = kwarg.arg validate_expected_type(kwarg.value, self._kwargs[kwarg_name].typ) + if kwarg_name == "method_id": + # Filter out unsigned integer types in the case of hex integers + p_types = [ + t + for t in get_possible_types_from_node(kwarg.value) + if not isinstance(t, IntegerT) + ] + typ = p_types.pop() - typ = get_exact_type_from_node(kwarg.value) - if kwarg_name == "method_id" and isinstance(typ, BytesT): - if typ.length != 4: + if isinstance(typ, BytesT) and typ.length != 4: raise InvalidLiteral("method_id must be exactly 4 bytes!", kwarg.value) + else: + typ = get_exact_type_from_node(kwarg.value) + ret[kwarg_name] = typ + return ret def fetch_call_return(self, node): diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 3a09bbe6c0..f7e1cac9a2 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -16,6 +16,7 @@ is_array_like, is_bytes_m_type, is_flag_type, + is_integer_type, is_numeric_type, is_tuple_like, make_setter, @@ -66,6 +67,7 @@ bytes_to_int, is_checksum_encoded, string_to_bytes, + unsigned_to_signed, vyper_warn, ) @@ -133,6 +135,12 @@ def parse_Hex(self): return IRnode.from_list(val, typ=t) + elif is_integer_type(t): + val = self.expr.uint_value + if t.is_signed: + val = unsigned_to_signed(val, t.bits, strict=True) + return IRnode.from_list(val, typ=t) + # String literals def parse_Str(self): bytez, bytez_length = string_to_bytes(self.expr.value) diff --git a/vyper/semantics/analysis/constant_folding.py b/vyper/semantics/analysis/constant_folding.py index 98cab0f8cb..d6b99432eb 100644 --- a/vyper/semantics/analysis/constant_folding.py +++ b/vyper/semantics/analysis/constant_folding.py @@ -131,19 +131,37 @@ def visit_UnaryOp(self, node): def visit_BinOp(self, node): left, right = [i.get_folded_value() for i in (node.left, node.right)] - if type(left) is not type(right): + valid_integer_nodes = (vy_ast.Hex, vy_ast.Int) + dissimilar_integer_nodes = isinstance(left, valid_integer_nodes) and not isinstance(right, valid_integer_nodes) + dissimilar_decimal_nodes = isinstance(left, vy_ast.Decimal) and type(left) is not type(right) + if dissimilar_decimal_nodes and dissimilar_integer_nodes: raise UnfoldableNode("invalid operation", node) - if not isinstance(left, vy_ast.Num): + if not isinstance(left, (vy_ast.Hex, vy_ast.Num)): raise UnfoldableNode("not a number!", node.left) + l_val = left.value + r_val = right.value + + # hex literals default to unsigned values during constant folding + if isinstance(left, vy_ast.Hex): + l_val = left.uint_value + if isinstance(right, vy_ast.Hex): + r_val = right.uint_value + # this validation is performed to prevent the compiler from hanging # on very large shifts and improve the error message for negative # values. - if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)) and not (0 <= right.value <= 256): + if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)) and not (0 <= r_val <= 256): raise InvalidLiteral("Shift bits must be between 0 and 256", node.right) - value = node.op._op(left.value, right.value) - return type(left).from_node(node, value=value) + value = node.op._op(l_val, r_val) + + new_node_type = type(left) + # fold hex integers into Int nodes + if isinstance(left, vy_ast.Hex): + new_node_type = vy_ast.Int + + return new_node_type.from_node(node, value=value) def visit_BoolOp(self, node): values = [v.get_folded_value() for v in node.values] diff --git a/vyper/semantics/types/primitives.py b/vyper/semantics/types/primitives.py index 5c0362e662..55b0ed0ab0 100644 --- a/vyper/semantics/types/primitives.py +++ b/vyper/semantics/types/primitives.py @@ -13,7 +13,7 @@ OverflowException, VyperException, ) -from vyper.utils import checksum_encode, int_bounds, is_checksum_encoded +from vyper.utils import checksum_encode, int_bounds, is_checksum_encoded, unsigned_to_signed from .base import VyperType from .bytestrings import BytesT @@ -138,9 +138,19 @@ def is_signed(self) -> bool: def validate_literal(self, node: vy_ast.Constant) -> None: super().validate_literal(node) lower, upper = self.ast_bounds - if node.value < lower: + + value = node.value + if isinstance(node, vy_ast.Hex): + if node.value not in (node.value.lower(), node.value.upper()): + raise InvalidLiteral("Cannot mix uppercase and lowercase for hex integers", node) + + value = node.uint_value + if self.is_signed: + value = unsigned_to_signed(value, self.bits) + + if value < lower: raise OverflowException(f"Value is below lower bound for given type ({lower})", node) - if node.value > upper: + if value > upper: raise OverflowException(f"Value exceeds upper bound for given type ({upper})", node) def validate_numeric_op( @@ -242,7 +252,7 @@ class IntegerT(NumericT): typeclass = "integer" - _valid_literal = (vy_ast.Int,) + _valid_literal = (vy_ast.Hex, vy_ast.Int) _equality_attrs = ("is_signed", "bits") ast_type = int