Skip to content

Commit

Permalink
bug fixes and typing
Browse files Browse the repository at this point in the history
  • Loading branch information
danielchalef committed Dec 5, 2024
1 parent 3d4e26f commit cf8b4f7
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 55 deletions.
5 changes: 2 additions & 3 deletions graphiti_core/llm_client/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import typing

import openai
from openai import AsyncOpenAI, NotGiven
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

Expand Down Expand Up @@ -82,7 +82,6 @@ def __init__(
async def _generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None
) -> dict[str, typing.Any]:
response_format = response_model if response_model else NotGiven
openai_messages: list[ChatCompletionMessageParam] = []
for m in messages:
if m.role == 'user':
Expand All @@ -95,7 +94,7 @@ async def _generate_response(
messages=openai_messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
response_format=response_format,
response_format=response_model, # type: ignore
)

response_object = response.choices[0].message
Expand Down
5 changes: 2 additions & 3 deletions graphiti_core/prompts/extract_edge_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
limitations under the License.
"""

from datetime import datetime
from typing import Any, Optional, Protocol, TypedDict

from pydantic import BaseModel, Field
Expand All @@ -23,11 +22,11 @@


class EdgeDates(BaseModel):
valid_at: Optional[datetime] = Field(
valid_at: Optional[str] = Field(
None,
description='The date and time when the relationship described by the edge fact became true or was established',
)
invalid_at: Optional[datetime] = Field(
invalid_at: Optional[str] = Field(
None,
description='The date and time when the relationship described by the edge fact stopped being true or ended',
)
Expand Down
26 changes: 5 additions & 21 deletions graphiti_core/prompts/extract_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class ExtractedNodes(BaseModel):
extracted_node_names: List[str] = Field(..., description='Name of the extracted entity')

Check failure on line 26 in graphiti_core/prompts/extract_nodes.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP006)

graphiti_core/prompts/extract_nodes.py:26:27: UP006 Use `list` instead of `List` for type annotation


class MissedEntities(BaseModel):
missed_entities: List[str] = Field(..., description="Names of entities that weren't extracted")

Check failure on line 30 in graphiti_core/prompts/extract_nodes.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP006)

graphiti_core/prompts/extract_nodes.py:30:22: UP006 Use `list` instead of `List` for type annotation


class Prompt(Protocol):
extract_message: PromptVersion
extract_json: PromptVersion
Expand Down Expand Up @@ -62,11 +66,6 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
4. DO NOT create nodes for temporal information like dates, times or years (these will be added to edges later).
5. Be as explicit as possible in your node names, using full names.
6. DO NOT extract entities mentioned only in PREVIOUS MESSAGES, those messages are only to provide context.
Respond with a JSON object in the following format:
{{
"extracted_node_names": ["Name of the extracted entity", ...],
}}
"""
return [
Message(role='system', content=sys_prompt),
Expand All @@ -93,11 +92,6 @@ def extract_json(context: dict[str, Any]) -> list[Message]:
Guidelines:
1. Always try to extract an entities that the JSON represents. This will often be something like a "name" or "user field
2. Do NOT extract any properties that contain dates
Respond with a JSON object in the following format:
{{
"extracted_node_names": ["Name of the extracted entity", ...],
}}
"""
return [
Message(role='system', content=sys_prompt),
Expand All @@ -122,11 +116,6 @@ def extract_text(context: dict[str, Any]) -> list[Message]:
2. Avoid creating nodes for relationships or actions.
3. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
4. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
Respond with a JSON object in the following format:
{{
"extracted_node_names": ["Name of the extracted entity", ...],
}}
"""
return [
Message(role='system', content=sys_prompt),
Expand All @@ -150,12 +139,7 @@ def reflexion(context: dict[str, Any]) -> list[Message]:
</EXTRACTED ENTITIES>
Given the above previous messages, current message, and list of extracted entities; determine if any entities haven't been
extracted:
Respond with a JSON object in the following format:
{{
"missed_entities": [ "name of entity that wasn't extracted", ...]
}}
extracted.
"""
return [
Message(role='system', content=sys_prompt),
Expand Down
40 changes: 14 additions & 26 deletions graphiti_core/prompts/invalidate_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,24 @@
limitations under the License.
"""

from typing import Any, Protocol, TypedDict
from typing import Any, List, Protocol, TypedDict

Check failure on line 17 in graphiti_core/prompts/invalidate_edges.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP035)

graphiti_core/prompts/invalidate_edges.py:17:1: UP035 `typing.List` is deprecated, use `list` instead

from pydantic import BaseModel, Field

from .models import Message, PromptFunction, PromptVersion


class InvalidatedEdge(BaseModel):
uuid: str = Field(..., description='The UUID of the edge to be invalidated')
fact: str = Field(..., description='Updated fact of the edge')


class InvalidatedEdges(BaseModel):
invalidated_edges: List[InvalidatedEdge] = Field(

Check failure on line 30 in graphiti_core/prompts/invalidate_edges.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP006)

graphiti_core/prompts/invalidate_edges.py:30:24: UP006 Use `list` instead of `List` for type annotation
..., description='List of edges that should be invalidated'
)


class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
Expand Down Expand Up @@ -56,18 +69,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
{context['new_edges']}
Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (fact: EDGE_FACT), START_DATE (END_DATE, optional))"
For each existing edge that should be invalidated, respond with a JSON object in the following format:
{{
"invalidated_edges": [
{{
"edge_uuid": "The UUID of the edge to be invalidated (the part before the | character)",
"fact": "Updated fact of the edge"
}}
]
}}
If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
""",
),
]
Expand All @@ -89,19 +90,6 @@ def v2(context: dict[str, Any]) -> list[Message]:
New Edge:
{context['new_edge']}
For each existing edge that should be invalidated, respond with a JSON object in the following format:
{{
"invalidated_edges": [
{{
"uuid": "The UUID of the edge to be invalidated",
"fact": "Updated fact of the edge"
}}
]
}}
If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
""",
),
]
Expand Down
5 changes: 4 additions & 1 deletion graphiti_core/utils/maintenance/temporal_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from graphiti_core.nodes import EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.extract_edge_dates import EdgeDates
from graphiti_core.prompts.invalidate_edges import InvalidatedEdges

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -78,7 +79,9 @@ async def get_edge_contradictions(

context = {'new_edge': new_edge_context, 'existing_edges': existing_edge_context}

llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v2(context))
llm_response = await llm_client.generate_response(
prompt_library.invalidate_edges.v2(context), response_model=InvalidatedEdges
)

contradicted_edge_data = llm_response.get('invalidated_edges', [])

Expand Down
3 changes: 2 additions & 1 deletion tests/utils/maintenance/test_temporal_operations_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,8 @@ async def test_extract_edge_dates():
setup_llm_client(), new_edge, episode, previous_episodes
)

assert valid_at == episode.valid_at
# the prompt specifies the format of the datetime to be YYYY-MM-DDTHH:MM:SS, without microseconds
assert valid_at.replace(microsecond=0) == episode.valid_at.replace(microsecond=0)
assert invalid_at is None


Expand Down

0 comments on commit cf8b4f7

Please sign in to comment.