diff --git a/thrift/lib/py3/test/auto_migrate/serializer.py b/thrift/lib/py3/test/auto_migrate/serializer.py index 932b412b8f9..cf5f961c0c7 100644 --- a/thrift/lib/py3/test/auto_migrate/serializer.py +++ b/thrift/lib/py3/test/auto_migrate/serializer.py @@ -167,13 +167,31 @@ def test_bad_deserialize(self) -> None: def pickle_round_robin( self, - control: Union[Struct, Hashable], + control: Struct | Hashable, ) -> None: encoded = pickle.dumps(control, protocol=pickle.HIGHEST_PROTOCOL) decoded = pickle.loads(encoded) self.assertIsInstance(decoded, type(control)) self.assertEqual(control, decoded) + # tests py3 auto-migrate backwards compatibility + # bytes produced by pickle.dumps(control, protocol=pickle.HIGHEST_PROTOCOL) + # when run in normal mode (auto-migrate off) + def assert_unpickle_compat(self, stored: bytes, control: Struct | Hashable) -> None: + decoded = None + try: + decoded = pickle.loads(stored) + except pickle.UnpicklingError: + self.fail( + f"failed to unpickle {stored=}" + f"encoded control={pickle.dumps(control, protocol=pickle.HIGHEST_PROTOCOL)}" + ) + self.assertIsInstance(decoded, type(control)) + self.assertEqual( + control, + decoded, + ) + def test_serialize_easy_struct(self) -> None: control = easy(val=5, val_list=[1, 2, 3, 4]) fixtures: Mapping[Protocol, bytes] = { @@ -191,6 +209,18 @@ def test_pickle_easy_struct(self) -> None: control = easy(val=0, val_list=[5, 6, 7]) self.pickle_round_robin(control) + def test_unpickle_stored_easy_struct(self) -> None: + control = easy(val=0, val_list=[5, 6, 7]) + # string produced with pickle.dumps(control, protocol=pickle.HIGHEST_PROTOCOL) + # when run in normal mode (auto-migrate off) + stored = ( + b"\x80\x05\x95U\x00\x00\x00\x00\x00\x00\x00\x8c\x15thrift.py3." + b"serializer\x94\x8c\x0bdeserialize\x94\x93\x94\x8c\rtesting.types" + b"\x94\x8c\x04easy\x94\x93\x94C\x0c\x15\x00\x195\n\x0c\x0e,\x00\x16" + b"\x00\x00\x94\x86\x94R\x94." + ) + self.assert_unpickle_compat(stored, control) + def test_serialize_hard_struct(self) -> None: control = hard( val=0, val_list=[1, 2, 3, 4], name="foo", an_int=Integers(tiny=1) @@ -231,6 +261,29 @@ def test_pickle_Integers_union(self) -> None: control = Integers(large=2**32) self.pickle_round_robin(control) + def test_unpickle_stored_Integers_union(self) -> None: + control = Integers(large=2**32) + stored = ( + b"\x80\x05\x95T\x00\x00\x00\x00\x00\x00\x00\x8c\x15thrift.py3." + b"serializer\x94\x8c\x0bdeserialize\x94\x93\x94\x8c\rtesting.types" + b"\x94\x8c\x08Integers\x94\x93\x94C\x07F\x80\x80\x80\x80 \x00\x94" + b"\x86\x94R\x94." + ) + self.assert_unpickle_compat(stored, control) + + def test_pickle_enum(self) -> None: + control = MyEnum.ME1 + self.pickle_round_robin(control) + + def test_unpickle_stored_enum(self) -> None: + control = MyEnum.ME1 + stored = ( + b"\x80\x05\x95E\x00\x00\x00\x00\x00\x00\x00\x8c0apache.thrift.test." + b"terse_write.terse_write.types\x94\x8c\x06MyEnum\x94\x93\x94K\x01" + b"\x85\x94R\x94." + ) + self.assert_unpickle_compat(stored, control) + def test_pickle_sequence(self) -> None: control = I32List([1, 2, 3, 4]) self.pickle_round_robin(control) @@ -240,14 +293,41 @@ def test_pickle_sequence(self) -> None: assert data self.pickle_round_robin(data) + def test_unpickle_stored_sequence(self) -> None: + control = I32List([1, 2, 3, 4]) + stored = ( + b"\x80\x05\x95/\x00\x00\x00\x00\x00\x00\x00\x8c\rtesting.types" + b"\x94\x8c\tList__i32\x94\x93\x94]\x94(K\x01K\x02K\x03K\x04e" + b"\x85\x94R\x94." + ) + self.assert_unpickle_compat(stored, control) + def test_pickle_set(self) -> None: control = SetI32({1, 2, 3, 4}) self.pickle_round_robin(control) + def test_unpickle_stored_set(self) -> None: + control = SetI32({1, 2, 3, 4}) + stored = ( + b"\x80\x05\x95.\x00\x00\x00\x00\x00\x00\x00\x8c\rtesting.types" + b"\x94\x8c\x08Set__i32\x94\x93\x94\x8f\x94(K\x01K\x02K\x03K\x04" + b"\x90\x85\x94R\x94." + ) + self.assert_unpickle_compat(stored, control) + def test_pickle_mapping(self) -> None: control = StrStrMap({"test": "test", "foo": "bar"}) self.pickle_round_robin(control) + def test_unpickle_stored_mapping(self) -> None: + control = StrStrMap({"test": "test", "foo": "bar"}) + stored = ( + b"\x80\x05\x95E\x00\x00\x00\x00\x00\x00\x00\x8c\rtesting.types" + b"\x94\x8c\x12Map__string_string\x94\x93\x94}\x94(\x8c\x04test" + b"\x94h\x04\x8c\x03foo\x94\x8c\x03bar\x94u\x85\x94R\x94." + ) + self.assert_unpickle_compat(stored, control) + def test_deserialize_with_length(self) -> None: control = easy(val=5, val_list=[1, 2, 3, 4, 5]) self.with_length_round_robin(control)