Skip to content

Commit

Permalink
Patch for pydantic 1.10.15 (#452)
Browse files Browse the repository at this point in the history
  • Loading branch information
harry-cohere authored and billytrend-cohere committed Apr 3, 2024
1 parent 856a4c3 commit 0f7e675
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/cohere/core/jsonable_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
from types import GeneratorType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

try:
import pydantic.v1 as pydantic # type: ignore
except ImportError:
import pydantic # type: ignore
import pydantic

from .datetime_utils import serialize_datetime

IS_PYDANTIC_V2 = pydantic.VERSION.startswith("2.")

if IS_PYDANTIC_V2:
import pydantic.v1 as pydantic_v1 # type: ignore
else:
import pydantic as pydantic_v1 # type: ignore

SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]

Expand All @@ -36,7 +40,7 @@ def generate_encoders_by_class_tuples(
return encoders_by_class_tuples


encoders_by_class_tuples = generate_encoders_by_class_tuples(pydantic.json.ENCODERS_BY_TYPE)
encoders_by_class_tuples = generate_encoders_by_class_tuples(pydantic_v1.json.ENCODERS_BY_TYPE)


def jsonable_encoder(obj: Any, custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None) -> Any:
Expand All @@ -48,7 +52,7 @@ def jsonable_encoder(obj: Any, custom_encoder: Optional[Dict[Any, Callable[[Any]
for encoder_type, encoder_instance in custom_encoder.items():
if isinstance(obj, encoder_type):
return encoder_instance(obj)
if isinstance(obj, pydantic.BaseModel):
if isinstance(obj, pydantic_v1.BaseModel):
encoder = getattr(obj.__config__, "json_encoders", {})
if custom_encoder:
encoder.update(custom_encoder)
Expand Down Expand Up @@ -84,8 +88,8 @@ def jsonable_encoder(obj: Any, custom_encoder: Optional[Dict[Any, Callable[[Any]
encoded_list.append(jsonable_encoder(item, custom_encoder=custom_encoder))
return encoded_list

if type(obj) in pydantic.json.ENCODERS_BY_TYPE:
return pydantic.json.ENCODERS_BY_TYPE[type(obj)](obj)
if type(obj) in pydantic_v1.json.ENCODERS_BY_TYPE:
return pydantic_v1.json.ENCODERS_BY_TYPE[type(obj)](obj)
for encoder, classes_tuple in encoders_by_class_tuples.items():
if isinstance(obj, classes_tuple):
return encoder(obj)
Expand Down

0 comments on commit 0f7e675

Please sign in to comment.