From 7a20cbdad3568b4aa1075c649f9af8c8df396d32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9amus=20=C3=93=20Ceanainn?= Date: Thu, 7 Nov 2024 12:28:48 +0000 Subject: [PATCH] Give user registered types priority when encoding / decoding JSON --- kombu/utils/json.py | 58 +++++++++++++++++++++++++++------------ t/unit/utils/test_json.py | 9 ++++++ 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/kombu/utils/json.py b/kombu/utils/json.py index 46326c109..ad8cf73e2 100644 --- a/kombu/utils/json.py +++ b/kombu/utils/json.py @@ -23,6 +23,12 @@ class JSONEncoder(json.JSONEncoder): """Kombu custom json encoder.""" def default(self, o): + for t, (marker, encoder) in _encoders.items(): + if isinstance(o, t): + return ( + encoder(o) if marker is None else _as(marker, encoder(o)) + ) + reducer = getattr(o, "__json__", None) if reducer is not None: return reducer() @@ -30,7 +36,7 @@ def default(self, o): if isinstance(o, textual_types): return str(o) - for t, (marker, encoder) in _encoders.items(): + for t, (marker, encoder) in _default_encoders.items(): if isinstance(o, t): return ( encoder(o) if marker is None else _as(marker, encoder(o)) @@ -66,7 +72,7 @@ def dumps( def object_hook(o: dict): """Hook function to perform custom deserialization.""" if o.keys() == {"__type__", "__value__"}: - decoder = _decoders.get(o["__type__"]) + decoder = _decoders.get(o["__type__"]) or _default_decoders.get(o["__type__"]) if decoder: return decoder(o["__value__"]) else: @@ -97,6 +103,16 @@ def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook): T = TypeVar("T") EncodedT = TypeVar("EncodedT") +# Separate user registered types from Kombu registered types to allow us to give preference to user types +_encoders: dict[type, tuple[str | None, EncoderT]] = {} +_decoders: dict[str, DecoderT] = {} + +_default_encoders: dict[type, tuple[str | None, EncoderT]] = {} +_default_decoders: dict[str, DecoderT] = { + "bytes": lambda o: o.encode("utf-8"), + "base64": lambda o: base64.b64decode(o.encode("utf-8")), +} + def register_type( t: type[T], @@ -110,32 +126,40 @@ def register_type( is not placed in an envelope, so `decoder` is unnecessary. Decoding must instead be handled outside this library. """ - _encoders[t] = (marker, encoder) - if marker is not None: - _decoders[marker] = decoder + _register_type(t, marker, encoder, decoder, is_default_encoder=False) -_encoders: dict[type, tuple[str | None, EncoderT]] = {} -_decoders: dict[str, DecoderT] = { - "bytes": lambda o: o.encode("utf-8"), - "base64": lambda o: base64.b64decode(o.encode("utf-8")), -} +def _register_type( + t: type[T], + marker: str | None, + encoder: Callable[[T], EncodedT], + decoder: Callable[[EncodedT], T] = lambda d: d, + is_default_encoder: bool = True, +): + if is_default_encoder: + _default_encoders[t] = (marker, encoder) + if marker is not None: + _default_decoders[marker] = decoder + else: + _encoders[t] = (marker, encoder) + if marker is not None: + _decoders[marker] = decoder def _register_default_types(): # NOTE: datetime should be registered before date, # because datetime is also instance of date. - register_type(datetime, "datetime", datetime.isoformat, - datetime.fromisoformat) - register_type( + _register_type(datetime, "datetime", datetime.isoformat, + datetime.fromisoformat) + _register_type( date, "date", lambda o: o.isoformat(), - lambda o: datetime.fromisoformat(o).date(), + lambda o: datetime.fromisoformat(o).date() ) - register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat) - register_type(Decimal, "decimal", str, Decimal) - register_type( + _register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat) + _register_type(Decimal, "decimal", str, Decimal) + _register_type( uuid.UUID, "uuid", lambda o: {"hex": o.hex}, diff --git a/t/unit/utils/test_json.py b/t/unit/utils/test_json.py index 579ab64ab..172ddefd5 100644 --- a/t/unit/utils/test_json.py +++ b/t/unit/utils/test_json.py @@ -95,6 +95,15 @@ def test_register_type_overrides_defaults(self): loaded_value = loads(dumps({'u': value})) assert loaded_value == {'u': "custom"} + def test_register_type_takes_priority(self): + class MyDecimal(Decimal): + pass + + register_type(MyDecimal, "mydecimal", str, MyDecimal) + original = {'md': MyDecimal('3314132.13363235235324234123213213214134')} + loaded_value = loads(dumps(original)) + assert original == loaded_value + def test_register_type_with_new_type(self): # Guaranteed never before seen type @dataclass()