Skip to content

Commit

Permalink
Implement OpenAI Structured Output (#225)
Browse files Browse the repository at this point in the history
* implement so

* bug fixes and typing

* inject schema for non-openai clients

* correct datetime format

* remove List keyword

* Refactor node_operations.py to use updated prompt_library functions

* update example
  • Loading branch information
danielchalef authored Dec 5, 2024
1 parent 427c73d commit 567a8ab
Show file tree
Hide file tree
Showing 19 changed files with 249 additions and 181 deletions.
59 changes: 28 additions & 31 deletions examples/langgraph-agent/agent.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion graphiti_core/cross_encoder/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]
passages (list[str]): A list of passages to rank.
Returns:
List[tuple[str, float]]: A list of tuples containing the passage and its score,
list[tuple[str, float]]: A list of tuples containing the passage and its score,
sorted in descending order of relevance.
"""
pass
5 changes: 4 additions & 1 deletion graphiti_core/llm_client/anthropic_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import anthropic
from anthropic import AsyncAnthropic
from pydantic import BaseModel

from ..prompts.models import Message
from .client import LLMClient
Expand All @@ -46,7 +47,9 @@ def __init__(self, config: LLMConfig | None = None, cache: bool = False):
max_retries=1,
)

async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
async def _generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None
) -> dict[str, typing.Any]:
system_message = messages[0]
user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
{'role': 'assistant', 'content': '{'}
Expand Down
25 changes: 20 additions & 5 deletions graphiti_core/llm_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import httpx
from diskcache import Cache
from pydantic import BaseModel
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential

from ..prompts.models import Message
Expand Down Expand Up @@ -66,14 +67,18 @@ def __init__(self, config: LLMConfig | None, cache: bool = False):
else None,
reraise=True,
)
async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
async def _generate_response_with_retry(
self, messages: list[Message], response_model: type[BaseModel] | None = None
) -> dict[str, typing.Any]:
try:
return await self._generate_response(messages)
return await self._generate_response(messages, response_model)
except (httpx.HTTPStatusError, RateLimitError) as e:
raise e

@abstractmethod
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
async def _generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None
) -> dict[str, typing.Any]:
pass

def _get_cache_key(self, messages: list[Message]) -> str:
Expand All @@ -82,7 +87,17 @@ def _get_cache_key(self, messages: list[Message]) -> str:
key_str = f'{self.model}:{message_str}'
return hashlib.md5(key_str.encode()).hexdigest()

async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
async def generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None
) -> dict[str, typing.Any]:
if response_model is not None:
serialized_model = json.dumps(response_model.model_json_schema())
messages[
-1
].content += (
f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
)

if self.cache_enabled:
cache_key = self._get_cache_key(messages)

Expand All @@ -91,7 +106,7 @@ async def generate_response(self, messages: list[Message]) -> dict[str, typing.A
logger.debug(f'Cache hit for {cache_key}')
return cached_response

response = await self._generate_response_with_retry(messages)
response = await self._generate_response_with_retry(messages, response_model)

if self.cache_enabled:
self.cache_dir.set(cache_key, response)
Expand Down
8 changes: 8 additions & 0 deletions graphiti_core/llm_client/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,11 @@ class RateLimitError(Exception):
def __init__(self, message='Rate limit exceeded. Please try again later.'):
self.message = message
super().__init__(self.message)


class RefusalError(Exception):
"""Exception raised when the LLM refuses to generate a response."""

def __init__(self, message: str):
self.message = message
super().__init__(self.message)
5 changes: 4 additions & 1 deletion graphiti_core/llm_client/groq_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import groq
from groq import AsyncGroq
from groq.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

from ..prompts.models import Message
from .client import LLMClient
Expand All @@ -43,7 +44,9 @@ def __init__(self, config: LLMConfig | None = None, cache: bool = False):

self.client = AsyncGroq(api_key=config.api_key)

async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
async def _generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None
) -> dict[str, typing.Any]:
msgs: list[ChatCompletionMessageParam] = []
for m in messages:
if m.role == 'user':
Expand Down
36 changes: 29 additions & 7 deletions graphiti_core/llm_client/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@
limitations under the License.
"""

import json
import logging
import typing

import openai
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

from ..prompts.models import Message
from .client import LLMClient
from .config import LLMConfig
from .errors import RateLimitError
from .errors import RateLimitError, RefusalError

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -65,6 +65,10 @@ def __init__(
client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
"""
# removed caching to simplify the `generate_response` override
if cache:
raise NotImplementedError('Caching is not implemented for OpenAI')

if config is None:
config = LLMConfig()

Expand All @@ -75,25 +79,43 @@ def __init__(
else:
self.client = client

async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
async def _generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None
) -> dict[str, typing.Any]:
openai_messages: list[ChatCompletionMessageParam] = []
for m in messages:
if m.role == 'user':
openai_messages.append({'role': 'user', 'content': m.content})
elif m.role == 'system':
openai_messages.append({'role': 'system', 'content': m.content})
try:
response = await self.client.chat.completions.create(
response = await self.client.beta.chat.completions.parse(
model=self.model or DEFAULT_MODEL,
messages=openai_messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
response_format={'type': 'json_object'},
response_format=response_model, # type: ignore
)
result = response.choices[0].message.content or ''
return json.loads(result)

response_object = response.choices[0].message

if response_object.parsed:
return response_object.parsed.model_dump()
elif response_object.refusal:
raise RefusalError(response_object.refusal)
else:
raise Exception('No response from LLM')
except openai.LengthFinishReasonError as e:
raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
except openai.RateLimitError as e:
raise RateLimitError from e
except Exception as e:
logger.error(f'Error in generating LLM response: {e}')
raise

async def generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None
) -> dict[str, typing.Any]:
response = await self._generate_response(messages, response_model)

return response
37 changes: 20 additions & 17 deletions graphiti_core/prompts/dedupe_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,30 @@
"""

import json
from typing import Any, Protocol, TypedDict
from typing import Any, Optional, Protocol, TypedDict

from pydantic import BaseModel, Field

from .models import Message, PromptFunction, PromptVersion


class EdgeDuplicate(BaseModel):
is_duplicate: bool = Field(..., description='true or false')
uuid: Optional[str] = Field(
None,
description="uuid of the existing edge like '5d643020624c42fa9de13f97b1b3fa39' or null",
)


class UniqueFact(BaseModel):
uuid: str = Field(..., description='unique identifier of the fact')
fact: str = Field(..., description='fact of a unique edge')


class UniqueFacts(BaseModel):
unique_facts: list[UniqueFact]


class Prompt(Protocol):
edge: PromptVersion
edge_list: PromptVersion
Expand Down Expand Up @@ -56,12 +75,6 @@ def edge(context: dict[str, Any]) -> list[Message]:
Guidelines:
1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
Respond with a JSON object in the following format:
{{
"is_duplicate": true or false,
"uuid": uuid of the existing edge like "5d643020624c42fa9de13f97b1b3fa39" or null,
}}
""",
),
]
Expand Down Expand Up @@ -90,16 +103,6 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
3. Facts will often discuss the same or similar relation between identical entities
4. The final list should have only unique facts. If 3 facts are all duplicates of each other, only one of their
facts should be in the response
Respond with a JSON object in the following format:
{{
"unique_facts": [
{{
"uuid": "unique identifier of the fact",
"fact": "fact of a unique edge"
}}
]
}}
""",
),
]
Expand Down
16 changes: 15 additions & 1 deletion graphiti_core/prompts/dedupe_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,25 @@
"""

import json
from typing import Any, Protocol, TypedDict
from typing import Any, Optional, Protocol, TypedDict

from pydantic import BaseModel, Field

from .models import Message, PromptFunction, PromptVersion


class NodeDuplicate(BaseModel):
is_duplicate: bool = Field(..., description='true or false')
uuid: Optional[str] = Field(
None,
description="uuid of the existing node like '5d643020624c42fa9de13f97b1b3fa39' or null",
)
name: str = Field(
...,
description="Updated name of the new node (use the best name between the new node's name, an existing duplicate name, or a combination of both)",
)


class Prompt(Protocol):
node: PromptVersion
node_list: PromptVersion
Expand Down
31 changes: 17 additions & 14 deletions graphiti_core/prompts/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,26 @@
import json
from typing import Any, Protocol, TypedDict

from pydantic import BaseModel, Field

from .models import Message, PromptFunction, PromptVersion


class QueryExpansion(BaseModel):
query: str = Field(..., description='query optimized for database search')


class QAResponse(BaseModel):
ANSWER: str = Field(..., description='how Alice would answer the question')


class EvalResponse(BaseModel):
is_correct: bool = Field(..., description='boolean if the answer is correct or incorrect')
reasoning: str = Field(
..., description='why you determined the response was correct or incorrect'
)


class Prompt(Protocol):
qa_prompt: PromptVersion
eval_prompt: PromptVersion
Expand All @@ -41,10 +58,6 @@ def query_expansion(context: dict[str, Any]) -> list[Message]:
<QUESTION>
{json.dumps(context['query'])}
</QUESTION>
respond with a JSON object in the following format:
{{
"query": "query optimized for database search"
}}
"""
return [
Message(role='system', content=sys_prompt),
Expand All @@ -67,10 +80,6 @@ def qa_prompt(context: dict[str, Any]) -> list[Message]:
<QUESTION>
{context['query']}
</QUESTION>
respond with a JSON object in the following format:
{{
"ANSWER": "how Alice would answer the question"
}}
"""
return [
Message(role='system', content=sys_prompt),
Expand All @@ -96,12 +105,6 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
<RESPONSE>
{context['response']}
</RESPONSE>
respond with a JSON object in the following format:
{{
"is_correct": "boolean if the answer is correct or incorrect"
"reasoning": "why you determined the response was correct or incorrect"
}}
"""
return [
Message(role='system', content=sys_prompt),
Expand Down
22 changes: 15 additions & 7 deletions graphiti_core/prompts/extract_edge_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,24 @@
limitations under the License.
"""

from typing import Any, Protocol, TypedDict
from typing import Any, Optional, Protocol, TypedDict

from pydantic import BaseModel, Field

from .models import Message, PromptFunction, PromptVersion


class EdgeDates(BaseModel):
valid_at: Optional[str] = Field(
None,
description='The date and time when the relationship described by the edge fact became true or was established. YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null.',
)
invalid_at: Optional[str] = Field(
None,
description='The date and time when the relationship described by the edge fact stopped being true or ended. YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null.',
)


class Prompt(Protocol):
v1: PromptVersion

Expand Down Expand Up @@ -60,7 +73,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
Analyze the conversation and determine if there are dates that are part of the edge fact. Only set dates if they explicitly relate to the formation or alteration of the relationship itself.
Guidelines:
1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes.
1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ) for datetimes.
2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates.
3. If the fact is written in the present tense, use the Reference Timestamp for the valid_at date
4. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
Expand All @@ -69,11 +82,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
7. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
8. If only year is mentioned, use January 1st of that year at 00:00:00.
9. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned).
Respond with a JSON object:
{{
"valid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
"invalid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
}}
""",
),
]
Expand Down
Loading

0 comments on commit 567a8ab

Please sign in to comment.