diff --git a/thrift/lib/python/types.pyx b/thrift/lib/python/types.pyx index e49189a52d3..4c25f7102ad 100644 --- a/thrift/lib/python/types.pyx +++ b/thrift/lib/python/types.pyx @@ -1504,22 +1504,22 @@ cdef class Union(StructOrUnion): Returns the value of the field with the given `field_id` if it is indeed the field that is (currently) set for this union. Otherwise, raises AttributeError. """ - if self.type.value != field_id: + if _fbthrift_get_Union_type_int(self) != field_id: # TODO in python 3.10 update this to use name and obj fields raise AttributeError( f'Union contains a value of type {self.type.name}, not ' f'{type(self).Type(field_id).name}') return self.value - def get_type(self): + def get_type(Union self not None): return self.type @property - def fbthrift_current_field(self): + def fbthrift_current_field(Union self not None): return self.type @property - def fbthrift_current_value(self): + def fbthrift_current_value(Union self not None): return self.value @classmethod @@ -1555,29 +1555,39 @@ cdef class Union(StructOrUnion): def __deepcopy__(Union self, _memo): return self - def __eq__(Union self, other): + def __eq__(Union self not None, other): if type(other) != type(self): return False - return self.type == other.type and self.value == other.value + cdef Union other_u = other + cdef int self_type_int = _fbthrift_get_Union_type_int(self) + cdef int other_type_int = _fbthrift_get_Union_type_int(other_u) + return self_type_int == other_type_int and self.value == other_u.value - def __lt__(self, other): + def __lt__(Union self not None, other): if type(self) != type(other): return NotImplemented - return (self.type.value, self.value) < (other.type.value, other.value) + cdef Union other_u = other + cdef int self_type_int = _fbthrift_get_Union_type_int(self) + cdef int other_type_int = _fbthrift_get_Union_type_int(other_u) + return (self_type_int, self.value) < (other_type_int, other_u.value) - def __le__(self, other): + def __le__(Union self not None, other): if type(self) != type(other): return NotImplemented - return (self.type.value, self.value) <= (other.type.value, other.value) + cdef Union other_u = other + cdef int self_type_int = _fbthrift_get_Union_type_int(self) + cdef int other_type_int = _fbthrift_get_Union_type_int(other_u) + return (self_type_int, self.value) <= (other_type_int, other_u.value) - def __hash__(self): - return hash((self.type, self.value)) + def __hash__(Union self not None): + cdef int self_type_int = _fbthrift_get_Union_type_int(self) + return hash((self_type_int, self.value)) def __repr__(self): return f"{type(self).__name__}({self.type.name}={self.value!r})" - def __bool__(self): - return self.type.value != 0 + def __bool__(self not None): + return _fbthrift_get_Union_type_int(self) != 0 def __dir__(self): return dir(type(self)) @@ -1585,6 +1595,10 @@ cdef class Union(StructOrUnion): def __reduce__(self): return (_unpickle_union, (type(self), b''.join(self._serialize(Protocol.COMPACT)))) + +cdef inline _fbthrift_get_Union_type_int(Union u): + return u._fbthrift_data[0] + cdef _make_fget_struct(i): """ Returns a function that takes a `Struct` instance and returns the value of diff --git a/thrift/test/thrift-python/union_test.py b/thrift/test/thrift-python/union_test.py index 77785a10a31..563e3061e68 100644 --- a/thrift/test/thrift-python/union_test.py +++ b/thrift/test/thrift-python/union_test.py @@ -45,7 +45,7 @@ ) -def _thrift_serialization_round_trip( +def _assert_serialization_round_trip( test: unittest.TestCase, serializer_module: types.ModuleType, thrift_object: typing.Union[MutableStructOrUnion, ImmutableStructOrUnion], @@ -80,7 +80,7 @@ def test_creation(self) -> None: AttributeError, "Union contains a value of type EMPTY, not string_field" ): u.string_field - _thrift_serialization_round_trip(self, immutable_serializer, u) + _assert_serialization_round_trip(self, immutable_serializer, u) # Specifying exactly one keyword argument whose name corresponds to that of a # field for this Union, and a non-None value whose type is valid for that field, @@ -98,7 +98,7 @@ def test_creation(self) -> None: AttributeError, "Union contains a value of type string_field, not int_field" ): u2.int_field - _thrift_serialization_round_trip(self, immutable_serializer, u2) + _assert_serialization_round_trip(self, immutable_serializer, u2) # Attempts to initialize an instance with a keyword argument whose name does # not match that of a field should raise an error. @@ -250,7 +250,7 @@ def test_from_value_ambiguous_int_bool(self) -> None: ) self.assertEqual(union_int_bool_1.value, 1) self.assertEqual(union_int_bool_1.int_field, 1) - _thrift_serialization_round_trip(self, immutable_serializer, union_int_bool_1) + _assert_serialization_round_trip(self, immutable_serializer, union_int_bool_1) # BAD: fromValue(bool) populates an int field if it comes before bool. union_int_bool_2 = TestUnionAmbiguousFromValueIntBoolImmutable.fromValue(True) @@ -264,7 +264,7 @@ def test_from_value_ambiguous_int_bool(self) -> None: ) self.assertEqual(union_int_bool_2.value, 1) self.assertEqual(union_int_bool_2.int_field, 1) - _thrift_serialization_round_trip(self, immutable_serializer, union_int_bool_2) + _assert_serialization_round_trip(self, immutable_serializer, union_int_bool_2) def test_from_value_ambiguous_bool_int(self) -> None: # BAD: Unlike the previous test case, fromValue(int) does not populate @@ -281,7 +281,7 @@ def test_from_value_ambiguous_bool_int(self) -> None: self.assertEqual(union_bool_int_1.value, 1) self.assertEqual(union_bool_int_1.int_field, 1) self.assertEqual(union_bool_int_1.int_field, True) - _thrift_serialization_round_trip(self, immutable_serializer, union_bool_int_1) + _assert_serialization_round_trip(self, immutable_serializer, union_bool_int_1) union_bool_int_2 = TestUnionAmbiguousFromValueBoolIntImmutable.fromValue(True) self.assertIs( @@ -295,7 +295,7 @@ def test_from_value_ambiguous_bool_int(self) -> None: self.assertEqual(union_bool_int_2.value, True) self.assertEqual(union_bool_int_2.value, 1) self.assertEqual(union_bool_int_2.bool_field, 1) - _thrift_serialization_round_trip(self, immutable_serializer, union_bool_int_2) + _assert_serialization_round_trip(self, immutable_serializer, union_bool_int_2) def test_from_value_ambiguous_float_int(self) -> None: # BAD: fromValue(int) populated a float field if it comes before int. @@ -310,7 +310,7 @@ def test_from_value_ambiguous_float_int(self) -> None: ) self.assertEqual(union_float_int_1.value, 1.0) self.assertEqual(union_float_int_1.float_field, 1) - _thrift_serialization_round_trip(self, immutable_serializer, union_float_int_1) + _assert_serialization_round_trip(self, immutable_serializer, union_float_int_1) union_float_int_2 = TestUnionAmbiguousFromValueFloatIntImmutable.fromValue(1.0) self.assertIs( @@ -323,7 +323,7 @@ def test_from_value_ambiguous_float_int(self) -> None: ) self.assertEqual(union_float_int_2.value, 1.0) self.assertEqual(union_float_int_2.float_field, 1) - _thrift_serialization_round_trip(self, immutable_serializer, union_float_int_2) + _assert_serialization_round_trip(self, immutable_serializer, union_float_int_2) def test_field_name_conflict(self) -> None: # By setting class type `Type` attr after field attrs, we get the desired behavior @@ -355,20 +355,18 @@ def test_field_name_conflict(self) -> None: ): # pyre-ignore[41]: Intentional for test type_union.Type = 1 - _thrift_serialization_round_trip(self, immutable_serializer, type_union) + _assert_serialization_round_trip(self, immutable_serializer, type_union) u = TestUnionAmbiguousValueFieldNameImmutable(value=42) self.assertEqual(u.value, 42) with self.assertRaises(AttributeError): u.type - with self.assertRaises(AttributeError): - _thrift_serialization_round_trip(self, immutable_serializer, u) + _assert_serialization_round_trip(self, immutable_serializer, u) u2 = TestUnionAmbiguousValueFieldNameImmutable(type=123) with self.assertRaises(AttributeError): u2.value - with self.assertRaises(AssertionError): - _thrift_serialization_round_trip(self, immutable_serializer, u2) + _assert_serialization_round_trip(self, immutable_serializer, u2) def test_hash(self) -> None: hash(TestUnionImmutable()) @@ -378,13 +376,13 @@ def test_equality(self) -> None: u2 = TestUnionImmutable(string_field="hello") self.assertIsNot(u1, u2) self.assertEqual(u1, u2) - _thrift_serialization_round_trip(self, immutable_serializer, u1) - _thrift_serialization_round_trip(self, immutable_serializer, u2) + _assert_serialization_round_trip(self, immutable_serializer, u1) + _assert_serialization_round_trip(self, immutable_serializer, u2) u3 = TestUnionImmutable(string_field="world") self.assertIsNot(u1, u3) self.assertNotEqual(u1, u3) - _thrift_serialization_round_trip(self, immutable_serializer, u3) + _assert_serialization_round_trip(self, immutable_serializer, u3) def test_ordering(self) -> None: self.assertLess( @@ -408,7 +406,7 @@ def test_adapted_types(self) -> None: ) self.assertIs(u1.type, TestUnionAdaptedTypesImmutable.Type.EMPTY) self.assertIsNone(u1.value) - _thrift_serialization_round_trip(self, immutable_serializer, u1) + _assert_serialization_round_trip(self, immutable_serializer, u1) with self.assertRaisesRegex( AttributeError, @@ -465,8 +463,8 @@ def test_adapted_types(self) -> None: with self.assertRaisesRegex( AttributeError, "'str' object has no attribute 'timestamp'" ): - (TestUnionAdaptedTypesImmutable.fromValue("1718728839"),) - _thrift_serialization_round_trip(self, immutable_serializer, u2) + TestUnionAdaptedTypesImmutable.fromValue("1718728839") + _assert_serialization_round_trip(self, immutable_serializer, u2) u3 = TestUnionAdaptedTypesImmutable(non_adapted_i32=1718728839) self.assertIs( @@ -476,7 +474,7 @@ def test_adapted_types(self) -> None: self.assertIs(u3.type, TestUnionAdaptedTypesImmutable.Type.non_adapted_i32) self.assertIs(u3.value, u3.non_adapted_i32) self.assertEqual(u3.non_adapted_i32, 1718728839) - _thrift_serialization_round_trip(self, immutable_serializer, u3) + _assert_serialization_round_trip(self, immutable_serializer, u3) def test_to_immutable_python(self) -> None: union_immutable = TestUnionImmutable(string_field="hello")