Skip to content

Commit

Permalink
handle pydantic v1 gracefully (#195)
Browse files Browse the repository at this point in the history
* check pydantic version

* forward ref the ZepModel type

* mod init

* lint
  • Loading branch information
danielchalef authored Jun 24, 2024
1 parent e81bf98 commit 7d9e7bd
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 61 deletions.
117 changes: 68 additions & 49 deletions src/zep_cloud/external_clients/memory.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
import datetime
import json
import typing
from packaging import version

import pydantic

from zep_cloud.core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from zep_cloud.extractor.models import ZepModel
from zep_cloud.memory.client import (
AsyncMemoryClient as AsyncBaseMemoryClient,
)
from zep_cloud.memory.client import (
MemoryClient as BaseMemoryClient,
)

if typing.TYPE_CHECKING:
from zep_cloud.extractor.models import ZepModel

MIN_PYDANTIC_VERSION = "2.0"


class MemoryClient(BaseMemoryClient):
def __init__(self, *, client_wrapper: SyncClientWrapper):
Expand All @@ -19,7 +26,7 @@ def __init__(self, *, client_wrapper: SyncClientWrapper):
def extract(
self,
session_id: str,
model: ZepModel,
model: "ZepModel",
current_date_time: typing.Optional[datetime.datetime] = None,
last_n: int = 4,
validate: bool = False,
Expand Down Expand Up @@ -65,16 +72,22 @@ class CustomerInfo(ZepModel):
print(customer_data.name) # Access extracted and validated customer name
"""

if version.parse(pydantic.VERSION) < version.parse(MIN_PYDANTIC_VERSION):
raise RuntimeError(
f"Pydantic version {MIN_PYDANTIC_VERSION} or greater is required."
)

model_schema = json.dumps(model.model_json_schema())

result = self.extract_data(
session_id=session_id,
model_schema=model_schema,
validate=validate,
last_n=last_n,
current_date_time=current_date_time.isoformat()
if current_date_time
else None,
current_date_time=(
current_date_time.isoformat() if current_date_time else None
),
)

return model.model_validate(result)
Expand All @@ -87,62 +100,68 @@ def __init__(self, *, client_wrapper: AsyncClientWrapper):
async def extract(
self,
session_id: str,
model: ZepModel,
model: "ZepModel",
current_date_time: typing.Optional[datetime.datetime] = None,
last_n: int = 4,
validate: bool = False,
):
"""Extracts structured data from a session based on a ZepModel schema.
This method retrieves data based on a given model and session details.
It then returns the extracted and validated data as an instance of the given ZepModel.
Parameters
----------
session_id: str
Session ID.
model: ZepModel
An instance of a ZepModel subclass defining the expected data structure and field types.
current_date_time: typing.Optional[datetime.datetime]
Your current date and time in ISO 8601 format including timezone.
This is used for determining relative dates.
last_n: typing.Optional[int]
The number of messages in the chat history from which to extract data.
validate: typing.Optional[bool]
Validate that the extracted data is present in the dialog and correct per the field description.
Mitigates hallucination, but is slower and may result in false negatives.
Returns
-------
ZepModel: An instance of the provided ZepModel subclass populated with the
extracted and validated data.
Examples
--------
class CustomerInfo(ZepModel):
name: Optional[ZepText] = Field(description="Customer name", default=None)
name: Optional[ZepEmail] = Field(description="Customer email", default=None)
signup_date: Optional[ZepDate] = Field(description="Customer Sign up date", default=None)
client = AsyncMemoryClient(...)
customer_data = await client.memory.extract(
session_id="session123",
model=CustomerInfo(),
current_date_time=datetime.datetime.now(), # Filter data up to now
)
print(customer_data.name) # Access extracted and validated customer name
"""
This method retrieves data based on a given model and session details.
It then returns the extracted and validated data as an instance of the given ZepModel.
Parameters
----------
session_id: str
Session ID.
model: ZepModel
An instance of a ZepModel subclass defining the expected data structure and field types.
current_date_time: typing.Optional[datetime.datetime]
Your current date and time in ISO 8601 format including timezone.
This is used for determining relative dates.
last_n: typing.Optional[int]
The number of messages in the chat history from which to extract data.
validate: typing.Optional[bool]
Validate that the extracted data is present in the dialog and correct per the field description.
Mitigates hallucination, but is slower and may result in false negatives.
Returns
-------
ZepModel: An instance of the provided ZepModel subclass populated with the
extracted and validated data.
Examples
--------
class CustomerInfo(ZepModel):
name: Optional[ZepText] = Field(description="Customer name", default=None)
name: Optional[ZepEmail] = Field(description="Customer email", default=None)
signup_date: Optional[ZepDate] = Field(description="Customer Sign up date", default=None)
client = AsyncMemoryClient(...)
customer_data = await client.memory.extract(
session_id="session123",
model=CustomerInfo(),
current_date_time=datetime.datetime.now(), # Filter data up to now
)
print(customer_data.name) # Access extracted and validated customer name
"""

if version.parse(pydantic.VERSION) < version.parse(MIN_PYDANTIC_VERSION):
raise RuntimeError(
f"Pydantic version {MIN_PYDANTIC_VERSION} or greater is required."
)

model_schema = json.dumps(model.model_json_schema())

result = await self.extract_data(
session_id=session_id,
model_schema=model_schema,
validate=validate,
last_n=last_n,
current_date_time=current_date_time.isoformat()
if current_date_time
else None,
current_date_time=(
current_date_time.isoformat() if current_date_time else None
),
)

return model.model_validate(result)
49 changes: 37 additions & 12 deletions src/zep_cloud/extractor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,40 @@
from zep_cloud.extractor.models import (
ZepModel,
ZepText,
ZepNumber,
ZepFloat,
ZepRegex,
ZepZipCode,
ZepDate,
ZepDateTime,
ZepEmail,
ZepPhoneNumber,
)
# mypy: disable-error-code=no-redef

from typing import Type

# Zep Extraction requires Pydantic v2. If v2 is not installed, catch the error
# and set the variables to PydanticV2Required


class PydanticV2Required:
def __init__(self, *args, **kwargs):
raise RuntimeError("Pydantic v2 is required to use this class.")


try:
from zep_cloud.extractor.models import (
ZepModel,
ZepText,
ZepNumber,
ZepFloat,
ZepRegex,
ZepZipCode,
ZepDate,
ZepDateTime,
ZepEmail,
ZepPhoneNumber,
)
except ImportError:
ZepModel: Type = PydanticV2Required
ZepText: Type = PydanticV2Required
ZepNumber: Type = PydanticV2Required
ZepFloat: Type = PydanticV2Required
ZepRegex: Type = PydanticV2Required
ZepZipCode: Type = PydanticV2Required
ZepDate: Type = PydanticV2Required
ZepDateTime: Type = PydanticV2Required
ZepEmail: Type = PydanticV2Required
ZepPhoneNumber: Type = PydanticV2Required

__all__ = [
"ZepModel",
Expand Down

0 comments on commit 7d9e7bd

Please sign in to comment.