Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Give user registered types priority when encoding / decoding JSON #2188

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions kombu/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,20 @@ class JSONEncoder(json.JSONEncoder):
"""Kombu custom json encoder."""

def default(self, o):
for t, (marker, encoder) in _encoders.items():
if isinstance(o, t):
return (
encoder(o) if marker is None else _as(marker, encoder(o))
)

reducer = getattr(o, "__json__", None)
if reducer is not None:
return reducer()

if isinstance(o, textual_types):
return str(o)

for t, (marker, encoder) in _encoders.items():
for t, (marker, encoder) in _default_encoders.items():
if isinstance(o, t):
return (
encoder(o) if marker is None else _as(marker, encoder(o))
Expand Down Expand Up @@ -66,7 +72,7 @@ def dumps(
def object_hook(o: dict):
"""Hook function to perform custom deserialization."""
if o.keys() == {"__type__", "__value__"}:
decoder = _decoders.get(o["__type__"])
decoder = _decoders.get(o["__type__"]) or _default_decoders.get(o["__type__"])
if decoder:
return decoder(o["__value__"])
else:
Expand Down Expand Up @@ -97,6 +103,16 @@ def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook):
T = TypeVar("T")
EncodedT = TypeVar("EncodedT")

# Separate user registered types from Kombu registered types to allow us to give preference to user types
_encoders: dict[type, tuple[str | None, EncoderT]] = {}
_decoders: dict[str, DecoderT] = {}

_default_encoders: dict[type, tuple[str | None, EncoderT]] = {}
_default_decoders: dict[str, DecoderT] = {
"bytes": lambda o: o.encode("utf-8"),
"base64": lambda o: base64.b64decode(o.encode("utf-8")),
}


def register_type(
t: type[T],
Expand All @@ -110,32 +126,40 @@ def register_type(
is not placed in an envelope, so `decoder` is unnecessary. Decoding must
instead be handled outside this library.
"""
_encoders[t] = (marker, encoder)
if marker is not None:
_decoders[marker] = decoder
_register_type(t, marker, encoder, decoder, is_default_encoder=False)


_encoders: dict[type, tuple[str | None, EncoderT]] = {}
_decoders: dict[str, DecoderT] = {
"bytes": lambda o: o.encode("utf-8"),
"base64": lambda o: base64.b64decode(o.encode("utf-8")),
}
def _register_type(
t: type[T],
marker: str | None,
encoder: Callable[[T], EncodedT],
decoder: Callable[[EncodedT], T] = lambda d: d,
is_default_encoder: bool = True,
):
if is_default_encoder:
_default_encoders[t] = (marker, encoder)
if marker is not None:
_default_decoders[marker] = decoder
else:
_encoders[t] = (marker, encoder)
if marker is not None:
_decoders[marker] = decoder


def _register_default_types():
# NOTE: datetime should be registered before date,
# because datetime is also instance of date.
register_type(datetime, "datetime", datetime.isoformat,
datetime.fromisoformat)
register_type(
_register_type(datetime, "datetime", datetime.isoformat,
datetime.fromisoformat)
_register_type(
date,
"date",
lambda o: o.isoformat(),
lambda o: datetime.fromisoformat(o).date(),
lambda o: datetime.fromisoformat(o).date()
)
register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat)
register_type(Decimal, "decimal", str, Decimal)
register_type(
_register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat)
_register_type(Decimal, "decimal", str, Decimal)
_register_type(
uuid.UUID,
"uuid",
lambda o: {"hex": o.hex},
Expand Down
9 changes: 9 additions & 0 deletions t/unit/utils/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ def test_register_type_overrides_defaults(self):
loaded_value = loads(dumps({'u': value}))
assert loaded_value == {'u': "custom"}

def test_register_type_takes_priority(self):
class MyDecimal(Decimal):
pass

register_type(MyDecimal, "mydecimal", str, MyDecimal)
original = {'md': MyDecimal('3314132.13363235235324234123213213214134')}
loaded_value = loads(dumps(original))
assert original == loaded_value

def test_register_type_with_new_type(self):
# Guaranteed never before seen type
@dataclass()
Expand Down