From 886c33b9ccea206c15b3c55a57463c9f7d6bb521 Mon Sep 17 00:00:00 2001 From: paulpaliychuk Date: Wed, 21 Aug 2024 17:43:04 -0400 Subject: [PATCH 01/10] wip --- core/graphiti.py | 30 +++++++++++++++++++---- core/utils/maintenance/node_operations.py | 24 ++++++++++++++---- runner.py | 1 + 3 files changed, 45 insertions(+), 10 deletions(-) diff --git a/core/graphiti.py b/core/graphiti.py index 14b20641..daec7c7c 100644 --- a/core/graphiti.py +++ b/core/graphiti.py @@ -95,6 +95,9 @@ async def add_episode( extracted_nodes = await extract_nodes( self.llm_client, episode, previous_episodes ) + logger.info( + f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}" + ) # Calculate Embeddings @@ -105,16 +108,16 @@ async def add_episode( logger.info( f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}" ) - new_nodes = await dedupe_extracted_nodes( + touched_nodes, (_, brand_new_nodes) = await dedupe_extracted_nodes( self.llm_client, extracted_nodes, existing_nodes ) logger.info( - f"Deduped touched nodes: {[(n.name, n.uuid) for n in new_nodes]}" + f"Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}" ) - nodes.extend(new_nodes) + nodes.extend(touched_nodes) extracted_edges = await extract_edges( - self.llm_client, episode, new_nodes, previous_episodes + self.llm_client, episode, touched_nodes, previous_episodes ) await asyncio.gather( @@ -131,6 +134,11 @@ async def add_episode( self.llm_client, extracted_edges, existing_edges ) + edge_touched_node_uuids = [n.uuid for n in brand_new_nodes] + for edge in deduped_edges: + edge_touched_node_uuids.append(edge.source_node_uuid) + edge_touched_node_uuids.append(edge.target_node_uuid) + ( old_edges_with_nodes_pending_invalidation, new_edges_with_nodes, @@ -144,8 +152,17 @@ async def add_episode( new_edges_with_nodes, ) + for edge in invalidated_edges: + edge_touched_node_uuids.append(edge.source_node_uuid) + edge_touched_node_uuids.append(edge.target_node_uuid) + entity_edges.extend(invalidated_edges) + edge_touched_node_uuids = list(set(edge_touched_node_uuids)) + edge_touched_nodes = [ + node for node in nodes if node.uuid in edge_touched_node_uuids + ] + logger.info( f"Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}" ) @@ -153,10 +170,13 @@ async def add_episode( logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}") entity_edges.extend(deduped_edges) + logger.info( + f"building episodic edges to nodes {[(n.name, n.uuid) for n in edge_touched_nodes]}" + ) episodic_edges.extend( build_episodic_edges( # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them - nodes, + edge_touched_nodes, episode, now, ) diff --git a/core/utils/maintenance/node_operations.py b/core/utils/maintenance/node_operations.py index 1b67bf56..8318f9a3 100644 --- a/core/utils/maintenance/node_operations.py +++ b/core/utils/maintenance/node_operations.py @@ -109,6 +109,7 @@ async def dedupe_extracted_nodes( existing_nodes: list[EntityNode], ) -> list[EntityNode]: # build node map + brand_new_nodes_map = {} node_map = {} for node in existing_nodes: node_map[node.name] = node @@ -116,7 +117,7 @@ async def dedupe_extracted_nodes( if node.name in node_map.keys(): continue node_map[node.name] = node - + brand_new_nodes_map[node.name] = node # Prepare context for LLM existing_nodes_context = [ {"name": node.name, "summary": node.summary} for node in existing_nodes @@ -139,9 +140,22 @@ async def dedupe_extracted_nodes( logger.info(f"Deduplicated nodes: {new_nodes_data}") # Get 
full node data - nodes = [] + adjusted_nodes = [] for node_data in new_nodes_data: node = node_map[node_data["name"]] - nodes.append(node) - - return nodes + adjusted_nodes.append(node) + + brand_new_nodes = [] + for node in new_nodes_data: + if node["name"] in brand_new_nodes_map.keys(): + brand_new_nodes.append(brand_new_nodes_map[node["name"]]) + logger.info(f"Brand new nodes: {[(n.name, n.uuid) for n in brand_new_nodes]}") + + adjusted_existing_nodes = [] + for node in new_nodes_data: + if node["name"] not in brand_new_nodes_map.keys(): + adjusted_existing_nodes.append(node_map[node["name"]]) + logger.info( + f"Adjusted existing nodes: {[(n.name, n.uuid) for n in adjusted_existing_nodes]}" + ) + return adjusted_nodes, (adjusted_existing_nodes, brand_new_nodes) diff --git a/runner.py b/runner.py index 4c3baef2..a9f19bc6 100644 --- a/runner.py +++ b/runner.py @@ -62,6 +62,7 @@ async def main(): episode_body="Paul: I have divorced Jane", source_description="WhatsApp Message", ) + # await client.add_episode( # name="Message 3", # episode_body="Assistant: The best type of apples available are Fuji apples", From 0e46490bfdceda889beaff6365fbca8f9fbee26a Mon Sep 17 00:00:00 2001 From: paulpaliychuk Date: Thu, 22 Aug 2024 00:28:28 -0400 Subject: [PATCH 02/10] wip --- core/graphiti.py | 17 +++----- core/prompts/dedupe_edges.py | 1 + core/prompts/invalidate_edges.py | 18 +++++--- core/utils/bulk_utils.py | 2 +- core/utils/maintenance/edge_operations.py | 2 +- core/utils/maintenance/node_operations.py | 43 ++++++++----------- core/utils/maintenance/temporal_operations.py | 2 + core/utils/search/search_utils.py | 32 +++++++------- runner.py | 19 +++++--- 9 files changed, 71 insertions(+), 65 deletions(-) diff --git a/core/graphiti.py b/core/graphiti.py index 748b46ca..86c64329 100644 --- a/core/graphiti.py +++ b/core/graphiti.py @@ -132,7 +132,7 @@ async def add_episode( logger.info( f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}" ) - touched_nodes, _, (_, brand_new_nodes) = await dedupe_extracted_nodes( + touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes( self.llm_client, extracted_nodes, existing_nodes ) logger.info( @@ -183,10 +183,14 @@ async def add_episode( entity_edges.extend(invalidated_edges) edge_touched_node_uuids = list(set(edge_touched_node_uuids)) - edge_touched_nodes = [ + involved_nodes = [ node for node in nodes if node.uuid in edge_touched_node_uuids ] + logger.info( + f"Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}" + ) + logger.info( f"Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}" ) @@ -194,17 +198,10 @@ async def add_episode( logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}") entity_edges.extend(deduped_edges) - new_edges = await dedupe_extracted_edges( - self.llm_client, extracted_edges, existing_edges - ) - - logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in new_edges]}") - - entity_edges.extend(new_edges) episodic_edges.extend( build_episodic_edges( # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them - edge_touched_nodes, + involved_nodes, episode, now, ) diff --git a/core/prompts/dedupe_edges.py b/core/prompts/dedupe_edges.py index adae8bda..a6d92a5a 100644 --- a/core/prompts/dedupe_edges.py +++ b/core/prompts/dedupe_edges.py @@ -40,6 +40,7 @@ def v1(context: dict[str, any]) -> list[Message]: Guidelines: 1. Use both the name and fact of edges to determine if they are duplicates, duplicate edges may have different names + 2. 
If you encounter facts that are semantically equivalent or very similar, keep the original edge Respond with a JSON object in the following format: {{ diff --git a/core/prompts/invalidate_edges.py b/core/prompts/invalidate_edges.py index 5bad913e..7a048969 100644 --- a/core/prompts/invalidate_edges.py +++ b/core/prompts/invalidate_edges.py @@ -14,14 +14,20 @@ def v1(context: dict[str, any]) -> list[Message]: return [ Message( role="system", - content="You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based on newer information.", + content="You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based solely on explicit contradictions in newer information.", ), Message( role="user", content=f""" - Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges. - Only mark a relationship as invalid if there is clear evidence from new edges that the relationship is no longer true. - Do not invalidate relationships merely because they weren't mentioned in new edges. + Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to explicit contradictions in the new edges. + + Important guidelines: + 1. Only mark a relationship as invalid if there is an explicit, direct contradiction in the new edges. + 2. Do not make any assumptions or inferences about relationships. + 3. Do not invalidate edges based on implied changes or personal interpretations. + 4. A new edge does not automatically invalidate an existing edge unless it directly states the opposite. + 5. Different types of relationships can coexist and do not automatically invalidate each other. + 6. Do not invalidate relationships merely because they weren't mentioned in new edges. Existing Edges (sorted by timestamp, newest first): {context['existing_edges']} @@ -36,12 +42,12 @@ def v1(context: dict[str, any]) -> list[Message]: "invalidated_edges": [ {{ "edge_uuid": "The UUID of the edge to be invalidated (the part before the | character)", - "reason": "Brief explanation of why this edge is being invalidated" + "reason": "Brief explanation citing the specific new edge that directly contradicts this edge" }} ] }} - If no relationships need to be invalidated, return an empty list for "invalidated_edges". + If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges". 
""", ), ] diff --git a/core/utils/bulk_utils.py b/core/utils/bulk_utils.py index a5b361ef..baa188c9 100644 --- a/core/utils/bulk_utils.py +++ b/core/utils/bulk_utils.py @@ -102,7 +102,7 @@ async def dedupe_nodes_bulk( existing_nodes = await get_relevant_nodes(compressed_nodes, driver) - nodes, partial_uuid_map = await dedupe_extracted_nodes( + nodes, partial_uuid_map, _ = await dedupe_extracted_nodes( llm_client, compressed_nodes, existing_nodes ) diff --git a/core/utils/maintenance/edge_operations.py b/core/utils/maintenance/edge_operations.py index ec505cb5..dd6cd452 100644 --- a/core/utils/maintenance/edge_operations.py +++ b/core/utils/maintenance/edge_operations.py @@ -219,7 +219,7 @@ async def dedupe_extracted_edges( {"name": edge.name, "fact": edge.fact} for edge in extracted_edges ], } - + logger.info(prompt_library.dedupe_edges.v1(context)) llm_response = await llm_client.generate_response( prompt_library.dedupe_edges.v1(context) ) diff --git a/core/utils/maintenance/node_operations.py b/core/utils/maintenance/node_operations.py index 609155e4..a12f18ee 100644 --- a/core/utils/maintenance/node_operations.py +++ b/core/utils/maintenance/node_operations.py @@ -116,15 +116,15 @@ async def dedupe_extracted_nodes( start = time() # build existing node map - brand_new_nodes_map = {} node_map = {} for node in existing_nodes: node_map[node.name] = node + + # Temp hack + new_nodes_map = {} for node in extracted_nodes: - if node.name in node_map.keys(): - continue - node_map[node.name] = node - brand_new_nodes_map[node.name] = node + new_nodes_map[node.name] = node + # Prepare context for LLM existing_nodes_context = [ {"name": node.name, "summary": node.summary} for node in existing_nodes @@ -150,34 +150,27 @@ async def dedupe_extracted_nodes( uuid_map = {} for duplicate in duplicate_data: - uuid = node_map[duplicate["name"]].uuid + uuid = new_nodes_map[duplicate["name"]].uuid uuid_value = node_map[duplicate["duplicate_of"]].uuid uuid_map[uuid] = uuid_value - adjusted_nodes = [] + nodes = [] + brand_new_nodes = [] for node in extracted_nodes: if node.uuid in uuid_map: - existing_name = uuid_map[node.name] - existing_node = node_map[existing_name] + existing_uuid = uuid_map[node.uuid] + # TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes, + # can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please? 
+ # find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value) + existing_node = next( + (v for k, v in node_map.items() if v.uuid == existing_uuid), None + ) nodes.append(existing_node) continue - adjusted_nodes.append(node) + brand_new_nodes.append(node) + nodes.append(node) - brand_new_nodes = [] - for node in new_nodes_data: - if node["name"] in brand_new_nodes_map.keys(): - brand_new_nodes.append(brand_new_nodes_map[node["name"]]) - logger.info(f"Brand new nodes: {[(n.name, n.uuid) for n in brand_new_nodes]}") - - adjusted_existing_nodes = [] - for node in new_nodes_data: - if node["name"] not in brand_new_nodes_map.keys(): - adjusted_existing_nodes.append(node_map[node["name"]]) - logger.info( - f"Adjusted existing nodes: {[(n.name, n.uuid) for n in adjusted_existing_nodes]}" - ) - return adjusted_nodes, (adjusted_existing_nodes, brand_new_nodes) - return nodes, uuid_map + return nodes, uuid_map, brand_new_nodes async def dedupe_node_list( diff --git a/core/utils/maintenance/temporal_operations.py b/core/utils/maintenance/temporal_operations.py index 8f555b0a..8353991c 100644 --- a/core/utils/maintenance/temporal_operations.py +++ b/core/utils/maintenance/temporal_operations.py @@ -50,9 +50,11 @@ async def invalidate_edges( context = prepare_invalidation_context( existing_edges_pending_invalidation, new_edges ) + logger.info(prompt_library.invalidate_edges.v1(context)) llm_response = await llm_client.generate_response( prompt_library.invalidate_edges.v1(context) ) + logger.info(f"invalidate_edges LLM response: {llm_response}") edges_to_invalidate = llm_response.get("invalidated_edges", []) invalidated_edges = process_edge_invalidation_llm_response( diff --git a/core/utils/search/search_utils.py b/core/utils/search/search_utils.py index 110b7a21..15cc5239 100644 --- a/core/utils/search/search_utils.py +++ b/core/utils/search/search_utils.py @@ -3,7 +3,7 @@ from datetime import datetime from time import time -from neo4j import AsyncDriver +from neo4j import AsyncDriver, time as neo4j_time from core.edges import EntityEdge from core.nodes import EntityNode @@ -91,8 +91,6 @@ async def edge_similarity_search( edges: list[EntityEdge] = [] - now = datetime.now() - for record in records: edge = EntityEdge( uuid=record["uuid"], @@ -102,10 +100,10 @@ async def edge_similarity_search( name=record["name"], episodes=record["episodes"], fact_embedding=record["fact_embedding"], - created_at=now, - expired_at=now, - valid_at=now, - invalid_At=now, + created_at=safely_parse_db_date(record["created_at"]), + expired_at=safely_parse_db_date(record["expired_at"]), + valid_at=safely_parse_db_date(record["valid_at"]), + invalid_At=safely_parse_db_date(record["invalid_at"]), ) edges.append(edge) @@ -139,7 +137,7 @@ async def entity_similarity_search( uuid=record["uuid"], name=record["name"], labels=[], - created_at=datetime.now(), + created_at=safely_parse_db_date(record["created_at"]), summary=record["summary"], ) ) @@ -174,7 +172,7 @@ async def entity_fulltext_search( uuid=record["uuid"], name=record["name"], labels=[], - created_at=datetime.now(), + created_at=safely_parse_db_date(record["created_at"]), summary=record["summary"], ) ) @@ -213,8 +211,6 @@ async def edge_fulltext_search( edges: list[EntityEdge] = [] - now = datetime.now() - for record in records: edge = EntityEdge( uuid=record["uuid"], @@ -224,10 +220,10 @@ async def edge_fulltext_search( name=record["name"], episodes=record["episodes"], fact_embedding=record["fact_embedding"], - 
created_at=now, - expired_at=now, - valid_at=now, - invalid_At=now, + created_at=safely_parse_db_date(record["created_at"]), + expired_at=safely_parse_db_date(record["expired_at"]), + valid_at=safely_parse_db_date(record["valid_at"]), + invalid_At=safely_parse_db_date(record["invalid_at"]), ) edges.append(edge) @@ -235,6 +231,12 @@ async def edge_fulltext_search( return edges +def safely_parse_db_date(date_str: neo4j_time.Date) -> datetime: + if date_str: + return datetime.fromisoformat(date_str.iso_format()) + return None + + async def get_relevant_nodes( nodes: list[EntityNode], driver: AsyncDriver, diff --git a/runner.py b/runner.py index a9f19bc6..ee4ea305 100644 --- a/runner.py +++ b/runner.py @@ -5,6 +5,7 @@ import asyncio import logging import sys +from datetime import datetime load_dotenv() @@ -43,24 +44,28 @@ async def main(): # await client.build_indices() await client.add_episode( - name="Message 1", - episode_body="Paul: I love apples", + name="Message 3", + episode_body="Jane: I am married to Paul", source_description="WhatsApp Message", + reference_time=datetime.now(), ) await client.add_episode( - name="Message 2", - episode_body="Paul: I hate apples now", + name="Message 4", + episode_body="Paul: I have divorced Jane", source_description="WhatsApp Message", + reference_time=datetime.now(), ) await client.add_episode( - name="Message 3", - episode_body="Jane: I am married to Paul", + name="Message 5", + episode_body="Jane: I still love Paul", source_description="WhatsApp Message", + reference_time=datetime.now(), ) await client.add_episode( - name="Message 4", + name="Message 6", episode_body="Paul: I have divorced Jane", source_description="WhatsApp Message", + reference_time=datetime.now(), ) # await client.add_episode( From 7eb836eb9a9a3f06c30492fe325129027c4b9e26 Mon Sep 17 00:00:00 2001 From: paulpaliychuk Date: Thu, 22 Aug 2024 16:22:37 -0400 Subject: [PATCH 03/10] wip --- core/graphiti.py | 19 ++++++- core/prompts/dedupe_edges.py | 45 +++++++++++++++- core/prompts/invalidate_edges.py | 22 ++++---- core/utils/maintenance/edge_operations.py | 53 +++++++++++++++++-- core/utils/maintenance/temporal_operations.py | 47 ++++++++++++---- runner.py | 4 +- 6 files changed, 162 insertions(+), 28 deletions(-) diff --git a/core/graphiti.py b/core/graphiti.py index 86c64329..2cb252df 100644 --- a/core/graphiti.py +++ b/core/graphiti.py @@ -24,12 +24,17 @@ resolve_edge_pointers, dedupe_edges_bulk, ) -from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges +from core.utils.maintenance.edge_operations import ( + extract_edges, + dedupe_extracted_edges_v2, + dedupe_extracted_edges, +) from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes from core.utils.maintenance.temporal_operations import ( invalidate_edges, prepare_edges_for_invalidation, + extract_node_and_edge_triplets, ) from core.utils.search.search_utils import ( edge_similarity_search, @@ -154,8 +159,16 @@ async def add_episode( f"Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}" ) + # deduped_edges = await dedupe_extracted_edges_v2( + # self.llm_client, + # extract_node_and_edge_triplets(extracted_edges, nodes), + # extract_node_and_edge_triplets(existing_edges, nodes), + # ) + deduped_edges = await dedupe_extracted_edges( - self.llm_client, extracted_edges, existing_edges + self.llm_client, + extracted_edges, + existing_edges, ) edge_touched_node_uuids = [n.uuid for n 
in brand_new_nodes] @@ -174,6 +187,8 @@ async def add_episode( self.llm_client, old_edges_with_nodes_pending_invalidation, new_edges_with_nodes, + episode, + previous_episodes, ) for edge in invalidated_edges: diff --git a/core/prompts/dedupe_edges.py b/core/prompts/dedupe_edges.py index a6d92a5a..d70bb677 100644 --- a/core/prompts/dedupe_edges.py +++ b/core/prompts/dedupe_edges.py @@ -6,11 +6,13 @@ class Prompt(Protocol): v1: PromptVersion + v2: PromptVersion edge_list: PromptVersion class Versions(TypedDict): v1: PromptFunction + v2: PromptFunction edge_list: PromptFunction @@ -40,12 +42,53 @@ def v1(context: dict[str, any]) -> list[Message]: Guidelines: 1. Use both the name and fact of edges to determine if they are duplicates, duplicate edges may have different names + + Respond with a JSON object in the following format: + {{ + "new_edges": [ + {{ + "fact": "one sentence description of the fact" + }} + ] + }} + """, + ), + ] + + +def v2(context: dict[str, any]) -> list[Message]: + return [ + Message( + role="system", + content="You are a helpful assistant that de-duplicates relationship from edge lists.", + ), + Message( + role="user", + content=f""" + Given the following context, deduplicate edges from a list of new edges given a list of existing edges: + + Existing Edges: + {json.dumps(context['existing_edges'], indent=2)} + + New Edges: + {json.dumps(context['extracted_edges'], indent=2)} + + Task: + 1. start with the list of edges from New Edges + 2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing + edge in the list + 3. Respond with the resulting list of edges + + Guidelines: + 1. Use both the triplet name and fact of edges to determine if they are duplicates, + duplicate edges may have different names meaning the same thing and slight variations in the facts. 2. If you encounter facts that are semantically equivalent or very similar, keep the original edge Respond with a JSON object in the following format: {{ "new_edges": [ {{ + "triplet": "source_node_name-edge_name-target_node_name", "fact": "one sentence description of the fact" }} ] @@ -91,4 +134,4 @@ def edge_list(context: dict[str, any]) -> list[Message]: ] -versions: Versions = {"v1": v1, "edge_list": edge_list} +versions: Versions = {"v1": v1, "v2": v2, "edge_list": edge_list} diff --git a/core/prompts/invalidate_edges.py b/core/prompts/invalidate_edges.py index 7a048969..1339a306 100644 --- a/core/prompts/invalidate_edges.py +++ b/core/prompts/invalidate_edges.py @@ -19,15 +19,15 @@ def v1(context: dict[str, any]) -> list[Message]: Message( role="user", content=f""" - Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to explicit contradictions in the new edges. - - Important guidelines: - 1. Only mark a relationship as invalid if there is an explicit, direct contradiction in the new edges. - 2. Do not make any assumptions or inferences about relationships. - 3. Do not invalidate edges based on implied changes or personal interpretations. - 4. A new edge does not automatically invalidate an existing edge unless it directly states the opposite. - 5. Different types of relationships can coexist and do not automatically invalidate each other. - 6. Do not invalidate relationships merely because they weren't mentioned in new edges. 
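# Hedged illustration, not part of this patch: the reworked prompt below also
# receives episode context. The keys match this diff, but the literal values
# are invented; prepare_invalidation_context() assembles the real dict.
context = {
    "previous_episodes": ["Jane: I am married to Paul"],
    "current_episode": "Paul: I have divorced Jane",
    "existing_edges": [
        "uuid-1 | Jane - IS_MARRIED_TO - Paul (Fact: Jane is married to Paul) (2024-08-21T17:43:04)"
    ],
    "new_edges": [
        "uuid-2 | Paul - DIVORCED - Jane (Fact: Paul divorced Jane) (2024-08-22T00:28:28)"
    ],
}
messages = prompt_library.invalidate_edges.v1(context)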
+ Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges. + Only mark a relationship as invalid if there is clear evidence from new edges that the relationship is no longer true. + Do not invalidate relationships merely because they weren't mentioned in new edges. You may use the current episode and previous episodes as well as the facts of each edge to understand the context of the relationships. + + Previous Episodes: + {context['previous_episodes']} + + Current Episode: + {context['current_episode']} Existing Edges (sorted by timestamp, newest first): {context['existing_edges']} @@ -35,14 +35,14 @@ def v1(context: dict[str, any]) -> list[Message]: New Edges: {context['new_edges']} - Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (TIMESTAMP)" + Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (fact: EDGE_FACT), TIMESTAMP)" For each existing edge that should be invalidated, respond with a JSON object in the following format: {{ "invalidated_edges": [ {{ "edge_uuid": "The UUID of the edge to be invalidated (the part before the | character)", - "reason": "Brief explanation citing the specific new edge that directly contradicts this edge" + "fact": "Updated fact of the edge" }} ] }} diff --git a/core/utils/maintenance/edge_operations.py b/core/utils/maintenance/edge_operations.py index dd6cd452..eca56582 100644 --- a/core/utils/maintenance/edge_operations.py +++ b/core/utils/maintenance/edge_operations.py @@ -8,7 +8,7 @@ from core.nodes import EntityNode, EpisodicNode from core.edges import EpisodicEdge, EntityEdge import logging - +from core.utils.maintenance.temporal_operations import NodeEdgeNodeTriplet from core.prompts import prompt_library from core.llm_client import LLMClient @@ -196,6 +196,53 @@ async def extract_edges( return edges +def create_edge_identifier( + source_node: EntityNode, edge: EntityEdge, target_node: EntityNode +) -> str: + return f"{source_node.name}-{edge.name}-{target_node.name}" + + +async def dedupe_extracted_edges_v2( + llm_client: LLMClient, + extracted_edges: list[NodeEdgeNodeTriplet], + existing_edges: list[NodeEdgeNodeTriplet], +) -> list[NodeEdgeNodeTriplet]: + # Create edge map + edge_map = {} + for n1, edge, n2 in existing_edges: + edge_map[create_edge_identifier(n1, edge, n2)] = edge + for n1, edge, n2 in extracted_edges: + if create_edge_identifier(n1, edge, n2) in edge_map.keys(): + continue + edge_map[create_edge_identifier(n1, edge, n2)] = edge + + # Prepare context for LLM + context = { + "extracted_edges": [ + {"triplet": create_edge_identifier(n1, edge, n2), "fact": edge.fact} + for n1, edge, n2 in extracted_edges + ], + "existing_edges": [ + {"triplet": create_edge_identifier(n1, edge, n2), "fact": edge.fact} + for n1, edge, n2 in extracted_edges + ], + } + logger.info(prompt_library.dedupe_edges.v2(context)) + llm_response = await llm_client.generate_response( + prompt_library.dedupe_edges.v2(context) + ) + new_edges_data = llm_response.get("new_edges", []) + logger.info(f"Extracted new edges: {new_edges_data}") + + # Get full edge data + edges = [] + for edge_data in new_edges_data: + edge = edge_map[edge_data["triplet"]] + edges.append(edge) + + return edges + + async def dedupe_extracted_edges( llm_client: LLMClient, extracted_edges: list[EntityEdge], @@ -206,7 +253,7 @@ async def dedupe_extracted_edges( for edge in existing_edges: edge_map[edge.fact] = 
edge for edge in extracted_edges: - if edge.fact in edge_map.keys(): + if edge.fact in edge_map: continue edge_map[edge.fact] = edge @@ -219,7 +266,7 @@ async def dedupe_extracted_edges( {"name": edge.name, "fact": edge.fact} for edge in extracted_edges ], } - logger.info(prompt_library.dedupe_edges.v1(context)) + llm_response = await llm_client.generate_response( prompt_library.dedupe_edges.v1(context) ) diff --git a/core/utils/maintenance/temporal_operations.py b/core/utils/maintenance/temporal_operations.py index 8353991c..acc6c46d 100644 --- a/core/utils/maintenance/temporal_operations.py +++ b/core/utils/maintenance/temporal_operations.py @@ -2,7 +2,7 @@ from typing import List from core.llm_client import LLMClient from core.edges import EntityEdge -from core.nodes import EntityNode +from core.nodes import EntityNode, EpisodicNode from core.prompts import prompt_library import logging @@ -11,6 +11,24 @@ NodeEdgeNodeTriplet = tuple[EntityNode, EntityEdge, EntityNode] +def extract_node_and_edge_triplets( + edges: list[EntityEdge], nodes: list[EntityNode] +) -> list[NodeEdgeNodeTriplet]: + return [extract_node_edge_node_triplet(edge, nodes) for edge in edges] + + +def extract_node_edge_node_triplet( + edge: EntityEdge, nodes: list[EntityNode] +) -> NodeEdgeNodeTriplet: + source_node = next( + (node for node in nodes if node.uuid == edge.source_node_uuid), None + ) + target_node = next( + (node for node in nodes if node.uuid == edge.target_node_uuid), None + ) + return (source_node, edge, target_node) + + def prepare_edges_for_invalidation( existing_edges: list[EntityEdge], new_edges: list[EntityEdge], @@ -42,13 +60,18 @@ def prepare_edges_for_invalidation( async def invalidate_edges( llm_client: LLMClient, - existing_edges_pending_invalidation: List[NodeEdgeNodeTriplet], - new_edges: List[NodeEdgeNodeTriplet], -) -> List[EntityEdge]: + existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet], + new_edges: list[NodeEdgeNodeTriplet], + current_episode: EpisodicNode, + previous_episodes: list[EpisodicNode], +) -> list[EntityEdge]: invalidated_edges = [] context = prepare_invalidation_context( - existing_edges_pending_invalidation, new_edges + existing_edges_pending_invalidation, + new_edges, + current_episode, + previous_episodes, ) logger.info(prompt_library.invalidate_edges.v1(context)) llm_response = await llm_client.generate_response( @@ -65,21 +88,26 @@ async def invalidate_edges( def prepare_invalidation_context( - existing_edges: List[NodeEdgeNodeTriplet], new_edges: List[NodeEdgeNodeTriplet] + existing_edges: list[NodeEdgeNodeTriplet], + new_edges: list[NodeEdgeNodeTriplet], + current_episode: EpisodicNode, + previous_episodes: list[EpisodicNode], ) -> dict: return { "existing_edges": [ - f"{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} ({edge.created_at.isoformat()})" + f"{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})" for source_node, edge, target_node in sorted( existing_edges, key=lambda x: x[1].created_at, reverse=True ) ], "new_edges": [ - f"{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} ({edge.created_at.isoformat()})" + f"{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})" for source_node, edge, target_node in sorted( new_edges, key=lambda x: x[1].created_at, reverse=True ) ], + "current_episode": current_episode.content, + "previous_episodes": [episode.content for episode in 
previous_episodes], } @@ -95,8 +123,9 @@ def process_edge_invalidation_llm_response( ) if edge_to_update: edge_to_update.expired_at = datetime.now() + edge_to_update.fact = edge_to_invalidate["fact"] invalidated_edges.append(edge_to_update) logger.info( - f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Reason: {edge_to_invalidate['reason']}" + f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {edge_to_invalidate['fact']}" ) return invalidated_edges diff --git a/runner.py b/runner.py index ee4ea305..17aa775a 100644 --- a/runner.py +++ b/runner.py @@ -57,13 +57,13 @@ async def main(): ) await client.add_episode( name="Message 5", - episode_body="Jane: I still love Paul", + episode_body="Jane: I miss Paul", source_description="WhatsApp Message", reference_time=datetime.now(), ) await client.add_episode( name="Message 6", - episode_body="Paul: I have divorced Jane", + episode_body="Jane: I dont miss Paul anymore, I hate him", source_description="WhatsApp Message", reference_time=datetime.now(), ) From 80612c4e1395885ac55f9bc37516e2520960bdae Mon Sep 17 00:00:00 2001 From: paulpaliychuk Date: Thu, 22 Aug 2024 17:48:08 -0400 Subject: [PATCH 04/10] fix: Linter errors --- core/graphiti.py | 465 +++++++++--------- core/utils/maintenance/edge_operations.py | 5 +- core/utils/maintenance/node_operations.py | 295 +++++------ core/utils/maintenance/temporal_operations.py | 106 ++-- core/utils/search/search_utils.py | 5 +- runner.py | 40 +- 6 files changed, 474 insertions(+), 442 deletions(-) diff --git a/core/graphiti.py b/core/graphiti.py index dde47918..28cb92cf 100644 --- a/core/graphiti.py +++ b/core/graphiti.py @@ -13,37 +13,33 @@ from core.nodes import EntityNode, EpisodicNode from core.search.search import SearchConfig, hybrid_search from core.search.search_utils import ( - get_relevant_edges, - get_relevant_nodes, + get_relevant_edges, + get_relevant_nodes, ) from core.utils import ( - build_episodic_edges, - retrieve_episodes, + build_episodic_edges, + retrieve_episodes, ) from core.utils.bulk_utils import ( - BulkEpisode, - dedupe_edges_bulk, - dedupe_nodes_bulk, - extract_nodes_and_edges_bulk, - resolve_edge_pointers, - retrieve_previous_episodes_bulk, -) -from core.utils.maintenance.edge_operations import dedupe_extracted_edges, extract_edges -from core.utils.maintenance.graph_data_operations import ( - EPISODE_WINDOW_LEN, - build_indices_and_constraints, + BulkEpisode, + dedupe_edges_bulk, + dedupe_nodes_bulk, + extract_nodes_and_edges_bulk, + resolve_edge_pointers, + retrieve_previous_episodes_bulk, ) from core.utils.maintenance.edge_operations import ( - extract_edges, - dedupe_extracted_edges_v2, dedupe_extracted_edges, + extract_edges, +) +from core.utils.maintenance.graph_data_operations import ( + EPISODE_WINDOW_LEN, + build_indices_and_constraints, ) -from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes from core.utils.maintenance.temporal_operations import ( - invalidate_edges, - prepare_edges_for_invalidation, - extract_node_and_edge_triplets, + invalidate_edges, + prepare_edges_for_invalidation, ) logger = logging.getLogger(__name__) @@ -52,82 +48,86 @@ class Graphiti: - def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None): - self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password)) - self.database = 'neo4j' - if llm_client: - self.llm_client = llm_client - 
else: - self.llm_client = OpenAIClient( - LLMConfig( - api_key=os.getenv('OPENAI_API_KEY'), - model='gpt-4o-mini', - base_url='https://api.openai.com/v1', - ) - ) - - def close(self): - self.driver.close() - - async def build_indices_and_constraints(self): - await build_indices_and_constraints(self.driver) - - async def retrieve_episodes( - self, - reference_time: datetime, - last_n: int = EPISODE_WINDOW_LEN, - sources: list[str] | None = 'messages', - ) -> list[EpisodicNode]: - """Retrieve the last n episodic nodes from the graph""" - return await retrieve_episodes(self.driver, reference_time, last_n, sources) - - # Invalidate edges that are no longer valid - async def invalidate_edges( - self, - episode: EpisodicNode, - new_nodes: list[EntityNode], - new_edges: list[EntityEdge], - relevant_schema: dict[str, any], - previous_episodes: list[EpisodicNode], - ): ... - - async def add_episode( - self, - name: str, - episode_body: str, - source_description: str, - reference_time: datetime | None = None, - episode_type: str | None = 'string', # TODO: this field isn't used yet? - success_callback: Callable | None = None, - error_callback: Callable | None = None, - ): - """Process an episode and update the graph""" - try: - start = time() - - nodes: list[EntityNode] = [] - entity_edges: list[EntityEdge] = [] - episodic_edges: list[EpisodicEdge] = [] - embedder = self.llm_client.client.embeddings - now = datetime.now() - - previous_episodes = await self.retrieve_episodes(reference_time) - episode = EpisodicNode( - name=name, - labels=[], - source='messages', - content=episode_body, - source_description=source_description, - created_at=now, - valid_at=reference_time, - ) - - extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes) + def __init__( + self, uri: str, user: str, password: str, llm_client: LLMClient | None = None + ): + self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password)) + self.database = "neo4j" + if llm_client: + self.llm_client = llm_client + else: + self.llm_client = OpenAIClient( + LLMConfig( + api_key=os.getenv("OPENAI_API_KEY"), + model="gpt-4o-mini", + base_url="https://api.openai.com/v1", + ) + ) + + def close(self): + self.driver.close() + + async def build_indices_and_constraints(self): + await build_indices_and_constraints(self.driver) + + async def retrieve_episodes( + self, + reference_time: datetime, + last_n: int = EPISODE_WINDOW_LEN, + sources: list[str] | None = "messages", + ) -> list[EpisodicNode]: + """Retrieve the last n episodic nodes from the graph""" + return await retrieve_episodes(self.driver, reference_time, last_n, sources) + + # Invalidate edges that are no longer valid + async def invalidate_edges( + self, + episode: EpisodicNode, + new_nodes: list[EntityNode], + new_edges: list[EntityEdge], + relevant_schema: dict[str, any], + previous_episodes: list[EpisodicNode], + ): ... + + async def add_episode( + self, + name: str, + episode_body: str, + source_description: str, + reference_time: datetime | None = None, + episode_type: str | None = "string", # TODO: this field isn't used yet? 
+ success_callback: Callable | None = None, + error_callback: Callable | None = None, + ): + """Process an episode and update the graph""" + try: + start = time() + + nodes: list[EntityNode] = [] + entity_edges: list[EntityEdge] = [] + episodic_edges: list[EpisodicEdge] = [] + embedder = self.llm_client.client.embeddings + now = datetime.now() + + previous_episodes = await self.retrieve_episodes(reference_time) + episode = EpisodicNode( + name=name, + labels=[], + source="messages", + content=episode_body, + source_description=source_description, + created_at=now, + valid_at=reference_time, + ) + + extracted_nodes = await extract_nodes( + self.llm_client, episode, previous_episodes + ) logger.info( f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}" ) - # Calculate Embeddings + # Calculate Embeddings await asyncio.gather( *[node.generate_name_embedding(embedder) for node in extracted_nodes] @@ -144,17 +144,21 @@ async def add_episode( ) nodes.extend(touched_nodes) - extracted_edges = await extract_edges( - self.llm_client, episode, touched_nodes, previous_episodes - ) + extracted_edges = await extract_edges( + self.llm_client, episode, touched_nodes, previous_episodes + ) - await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges]) + await asyncio.gather( + *[edge.generate_embedding(embedder) for edge in extracted_edges] + ) - existing_edges = await get_relevant_edges(extracted_edges, self.driver) - logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}') - logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}') + existing_edges = await get_relevant_edges(extracted_edges, self.driver) + logger.info(f"Existing edges: {[(e.name, e.uuid) for e in existing_edges]}") + logger.info( + f"Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}" + ) - # deduped_edges = await dedupe_extracted_edges_v2( + # deduped_edges = await dedupe_extracted_edges_v2( # self.llm_client, # extract_node_and_edge_triplets(extracted_edges, nodes), # extract_node_and_edge_triplets(existing_edges, nodes), @@ -171,12 +175,12 @@ async def add_episode( edge_touched_node_uuids.append(edge.source_node_uuid) edge_touched_node_uuids.append(edge.target_node_uuid) - ( - old_edges_with_nodes_pending_invalidation, - new_edges_with_nodes, - ) = prepare_edges_for_invalidation( - existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes - ) + ( + old_edges_with_nodes_pending_invalidation, + new_edges_with_nodes, + ) = prepare_edges_for_invalidation( + existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes + ) invalidated_edges = await invalidate_edges( self.llm_client, @@ -190,9 +194,9 @@ async def add_episode( edge_touched_node_uuids.append(edge.source_node_uuid) edge_touched_node_uuids.append(edge.target_node_uuid) - entity_edges.extend(invalidated_edges) + entity_edges.extend(invalidated_edges) - edge_touched_node_uuids = list(set(edge_touched_node_uuids)) + edge_touched_node_uuids = list(set(edge_touched_node_uuids)) involved_nodes = [ node for node in nodes if node.uuid in edge_touched_node_uuids ] @@ -202,10 +206,11 @@ async def add_episode( ) logger.info( - f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}') + f"Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}" + ) - logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}') - entity_edges.extend(deduped_edges) + logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}") + 
entity_edges.extend(deduped_edges) episodic_edges.extend( build_episodic_edges( @@ -218,117 +223,125 @@ async def add_episode( # Important to append the episode to the nodes at the end so that self referencing episodic edges are not built logger.info(f"Built episodic edges: {episodic_edges}") - # invalidated_edges = await self.invalidate_edges( - # episode, new_nodes, new_edges, relevant_schema, previous_episodes - # ) - - # edges.extend(invalidated_edges) - - # Future optimization would be using batch operations to save nodes and edges - await episode.save(self.driver) - await asyncio.gather(*[node.save(self.driver) for node in nodes]) - await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges]) - await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges]) - - end = time() - logger.info(f'Completed add_episode in {(end-start) * 1000} ms') - # for node in nodes: - # if isinstance(node, EntityNode): - # await node.update_summary(self.driver) - if success_callback: - await success_callback(episode) - except Exception as e: - if error_callback: - await error_callback(episode, e) - else: - raise e - - async def add_episode_bulk( - self, - bulk_episodes: list[BulkEpisode], - ): - try: - start = time() - embedder = self.llm_client.client.embeddings - now = datetime.now() - - episodes = [ - EpisodicNode( - name=episode.name, - labels=[], - source='messages', - content=episode.content, - source_description=episode.source_description, - created_at=now, - valid_at=episode.reference_time, - ) - for episode in bulk_episodes - ] - - # Save all the episodes - await asyncio.gather(*[episode.save(self.driver) for episode in episodes]) - - # Get previous episode context for each episode - episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes) - - # Extract all nodes and edges - ( - extracted_nodes, - extracted_edges, - episodic_edges, - ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs) - - # Generate embeddings - await asyncio.gather( - *[node.generate_name_embedding(embedder) for node in extracted_nodes], - *[edge.generate_embedding(embedder) for edge in extracted_edges], - ) - - # Dedupe extracted nodes - nodes, uuid_map = await dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes) - - # save nodes to KG - await asyncio.gather(*[node.save(self.driver) for node in nodes]) - - # re-map edge pointers so that they don't point to discard dupe nodes - extracted_edges: list[EntityEdge] = resolve_edge_pointers(extracted_edges, uuid_map) - episodic_edges: list[EpisodicEdge] = resolve_edge_pointers(episodic_edges, uuid_map) - - # save episodic edges to KG - await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges]) - - # Dedupe extracted edges - edges = await dedupe_edges_bulk(self.driver, self.llm_client, extracted_edges) - logger.info(f'extracted edge length: {len(edges)}') - - # invalidate edges - - # save edges to KG - await asyncio.gather(*[edge.save(self.driver) for edge in edges]) - - end = time() - logger.info(f'Completed add_episode_bulk in {(end-start) * 1000} ms') - - except Exception as e: - raise e - - async def search(self, query: str, num_results=10): - search_config = SearchConfig(num_episodes=0, num_results=num_results) - edges = ( - await hybrid_search( - self.driver, - self.llm_client.client.embeddings, - query, - datetime.now(), - search_config, - ) - )['edges'] - - facts = [edge.fact for edge in edges] - - return facts - - async def _search(self, query: str, timestamp: datetime, config: 
SearchConfig): - return await hybrid_search( - self.driver, self.llm_client.client.embeddings, query, timestamp, config - ) + # invalidated_edges = await self.invalidate_edges( + # episode, new_nodes, new_edges, relevant_schema, previous_episodes + # ) + + # edges.extend(invalidated_edges) + + # Future optimization would be using batch operations to save nodes and edges + await episode.save(self.driver) + await asyncio.gather(*[node.save(self.driver) for node in nodes]) + await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges]) + await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges]) + + end = time() + logger.info(f"Completed add_episode in {(end-start) * 1000} ms") + # for node in nodes: + # if isinstance(node, EntityNode): + # await node.update_summary(self.driver) + if success_callback: + await success_callback(episode) + except Exception as e: + if error_callback: + await error_callback(episode, e) + else: + raise e + + async def add_episode_bulk( + self, + bulk_episodes: list[BulkEpisode], + ): + try: + start = time() + embedder = self.llm_client.client.embeddings + now = datetime.now() + + episodes = [ + EpisodicNode( + name=episode.name, + labels=[], + source="messages", + content=episode.content, + source_description=episode.source_description, + created_at=now, + valid_at=episode.reference_time, + ) + for episode in bulk_episodes + ] + + # Save all the episodes + await asyncio.gather(*[episode.save(self.driver) for episode in episodes]) + + # Get previous episode context for each episode + episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes) + + # Extract all nodes and edges + ( + extracted_nodes, + extracted_edges, + episodic_edges, + ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs) + + # Generate embeddings + await asyncio.gather( + *[node.generate_name_embedding(embedder) for node in extracted_nodes], + *[edge.generate_embedding(embedder) for edge in extracted_edges], + ) + + # Dedupe extracted nodes + nodes, uuid_map = await dedupe_nodes_bulk( + self.driver, self.llm_client, extracted_nodes + ) + + # save nodes to KG + await asyncio.gather(*[node.save(self.driver) for node in nodes]) + + # re-map edge pointers so that they don't point to discard dupe nodes + extracted_edges: list[EntityEdge] = resolve_edge_pointers( + extracted_edges, uuid_map + ) + episodic_edges: list[EpisodicEdge] = resolve_edge_pointers( + episodic_edges, uuid_map + ) + + # save episodic edges to KG + await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges]) + + # Dedupe extracted edges + edges = await dedupe_edges_bulk( + self.driver, self.llm_client, extracted_edges + ) + logger.info(f"extracted edge length: {len(edges)}") + + # invalidate edges + + # save edges to KG + await asyncio.gather(*[edge.save(self.driver) for edge in edges]) + + end = time() + logger.info(f"Completed add_episode_bulk in {(end-start) * 1000} ms") + + except Exception as e: + raise e + + async def search(self, query: str, num_results=10): + search_config = SearchConfig(num_episodes=0, num_results=num_results) + edges = ( + await hybrid_search( + self.driver, + self.llm_client.client.embeddings, + query, + datetime.now(), + search_config, + ) + )["edges"] + + facts = [edge.fact for edge in edges] + + return facts + + async def _search(self, query: str, timestamp: datetime, config: SearchConfig): + return await hybrid_search( + self.driver, self.llm_client.client.embeddings, query, timestamp, config + ) diff --git 
a/core/utils/maintenance/edge_operations.py b/core/utils/maintenance/edge_operations.py index 614e9f38..dede92d4 100644 --- a/core/utils/maintenance/edge_operations.py +++ b/core/utils/maintenance/edge_operations.py @@ -7,9 +7,8 @@ from core.edges import EntityEdge, EpisodicEdge from core.llm_client import LLMClient from core.nodes import EntityNode, EpisodicNode -from core.edges import EpisodicEdge, EntityEdge -from core.utils.maintenance.temporal_operations import NodeEdgeNodeTriplet from core.prompts import prompt_library +from core.utils.maintenance.temporal_operations import NodeEdgeNodeTriplet logger = logging.getLogger(__name__) @@ -197,7 +196,7 @@ async def dedupe_extracted_edges_v2( for n1, edge, n2 in existing_edges: edge_map[create_edge_identifier(n1, edge, n2)] = edge for n1, edge, n2 in extracted_edges: - if create_edge_identifier(n1, edge, n2) in edge_map.keys(): + if create_edge_identifier(n1, edge, n2) in edge_map: continue edge_map[create_edge_identifier(n1, edge, n2)] = edge diff --git a/core/utils/maintenance/node_operations.py b/core/utils/maintenance/node_operations.py index 6c267701..ab8c34f0 100644 --- a/core/utils/maintenance/node_operations.py +++ b/core/utils/maintenance/node_operations.py @@ -10,192 +10,205 @@ async def extract_new_nodes( - llm_client: LLMClient, - episode: EpisodicNode, - relevant_schema: dict[str, any], - previous_episodes: list[EpisodicNode], + llm_client: LLMClient, + episode: EpisodicNode, + relevant_schema: dict[str, any], + previous_episodes: list[EpisodicNode], ) -> list[EntityNode]: - # Prepare context for LLM - existing_nodes = [ - {'name': node_name, 'label': node_info['label'], 'uuid': node_info['uuid']} - for node_name, node_info in relevant_schema['nodes'].items() - ] - - context = { - 'episode_content': episode.content, - 'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None), - 'existing_nodes': existing_nodes, - 'previous_episodes': [ - { - 'content': ep.content, - 'timestamp': ep.valid_at.isoformat() if ep.valid_at else None, - } - for ep in previous_episodes - ], - } - - llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v1(context)) - new_nodes_data = llm_response.get('new_nodes', []) - logger.info(f'Extracted new nodes: {new_nodes_data}') - # Convert the extracted data into EntityNode objects - new_nodes = [] - for node_data in new_nodes_data: - # Check if the node already exists - if not any(existing_node['name'] == node_data['name'] for existing_node in existing_nodes): - new_node = EntityNode( - name=node_data['name'], - labels=node_data['labels'], - summary=node_data['summary'], - created_at=datetime.now(), - ) - new_nodes.append(new_node) - logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})') - else: - logger.info(f"Node {node_data['name']} already exists, skipping creation.") - - return new_nodes + # Prepare context for LLM + existing_nodes = [ + {"name": node_name, "label": node_info["label"], "uuid": node_info["uuid"]} + for node_name, node_info in relevant_schema["nodes"].items() + ] + + context = { + "episode_content": episode.content, + "episode_timestamp": ( + episode.valid_at.isoformat() if episode.valid_at else None + ), + "existing_nodes": existing_nodes, + "previous_episodes": [ + { + "content": ep.content, + "timestamp": ep.valid_at.isoformat() if ep.valid_at else None, + } + for ep in previous_episodes + ], + } + + llm_response = await llm_client.generate_response( + prompt_library.extract_nodes.v1(context) + ) + new_nodes_data = 
llm_response.get("new_nodes", []) + logger.info(f"Extracted new nodes: {new_nodes_data}") + # Convert the extracted data into EntityNode objects + new_nodes = [] + for node_data in new_nodes_data: + # Check if the node already exists + if not any( + existing_node["name"] == node_data["name"] + for existing_node in existing_nodes + ): + new_node = EntityNode( + name=node_data["name"], + labels=node_data["labels"], + summary=node_data["summary"], + created_at=datetime.now(), + ) + new_nodes.append(new_node) + logger.info(f"Created new node: {new_node.name} (UUID: {new_node.uuid})") + else: + logger.info(f"Node {node_data['name']} already exists, skipping creation.") + + return new_nodes async def extract_nodes( - llm_client: LLMClient, - episode: EpisodicNode, - previous_episodes: list[EpisodicNode], + llm_client: LLMClient, + episode: EpisodicNode, + previous_episodes: list[EpisodicNode], ) -> list[EntityNode]: - start = time() - - # Prepare context for LLM - context = { - 'episode_content': episode.content, - 'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None), - 'previous_episodes': [ - { - 'content': ep.content, - 'timestamp': ep.valid_at.isoformat() if ep.valid_at else None, - } - for ep in previous_episodes - ], - } - - llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v3(context)) - new_nodes_data = llm_response.get('new_nodes', []) - - end = time() - logger.info(f'Extracted new nodes: {new_nodes_data} in {(end - start) * 1000} ms') - # Convert the extracted data into EntityNode objects - new_nodes = [] - for node_data in new_nodes_data: - new_node = EntityNode( - name=node_data['name'], - labels=node_data['labels'], - summary=node_data['summary'], - created_at=datetime.now(), - ) - new_nodes.append(new_node) - logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})') - - return new_nodes + start = time() + + # Prepare context for LLM + context = { + "episode_content": episode.content, + "episode_timestamp": ( + episode.valid_at.isoformat() if episode.valid_at else None + ), + "previous_episodes": [ + { + "content": ep.content, + "timestamp": ep.valid_at.isoformat() if ep.valid_at else None, + } + for ep in previous_episodes + ], + } + + llm_response = await llm_client.generate_response( + prompt_library.extract_nodes.v3(context) + ) + new_nodes_data = llm_response.get("new_nodes", []) + + end = time() + logger.info(f"Extracted new nodes: {new_nodes_data} in {(end - start) * 1000} ms") + # Convert the extracted data into EntityNode objects + new_nodes = [] + for node_data in new_nodes_data: + new_node = EntityNode( + name=node_data["name"], + labels=node_data["labels"], + summary=node_data["summary"], + created_at=datetime.now(), + ) + new_nodes.append(new_node) + logger.info(f"Created new node: {new_node.name} (UUID: {new_node.uuid})") + + return new_nodes async def dedupe_extracted_nodes( - llm_client: LLMClient, - extracted_nodes: list[EntityNode], - existing_nodes: list[EntityNode], + llm_client: LLMClient, + extracted_nodes: list[EntityNode], + existing_nodes: list[EntityNode], ) -> tuple[list[EntityNode], dict[str, str]]: - start = time() + start = time() - # build existing node map - node_map = {} - for node in existing_nodes: - node_map[node.name] = node + # build existing node map + node_map = {} + for node in existing_nodes: + node_map[node.name] = node - # Temp hack + # Temp hack new_nodes_map = {} for node in extracted_nodes: new_nodes_map[node.name] = node # Prepare context for LLM - 
existing_nodes_context = [ - {'name': node.name, 'summary': node.summary} for node in existing_nodes - ] + existing_nodes_context = [ + {"name": node.name, "summary": node.summary} for node in existing_nodes + ] - extracted_nodes_context = [ - {'name': node.name, 'summary': node.summary} for node in extracted_nodes - ] + extracted_nodes_context = [ + {"name": node.name, "summary": node.summary} for node in extracted_nodes + ] - context = { - 'existing_nodes': existing_nodes_context, - 'extracted_nodes': extracted_nodes_context, - } + context = { + "existing_nodes": existing_nodes_context, + "extracted_nodes": extracted_nodes_context, + } - llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.v2(context)) + llm_response = await llm_client.generate_response( + prompt_library.dedupe_nodes.v2(context) + ) - duplicate_data = llm_response.get('duplicates', []) + duplicate_data = llm_response.get("duplicates", []) - end = time() - logger.info(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms') + end = time() + logger.info(f"Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms") - uuid_map = {} - for duplicate in duplicate_data: - uuid = new_nodes_map[duplicate['name']].uuid - uuid_value = node_map[duplicate['duplicate_of']].uuid - uuid_map[uuid] = uuid_value + uuid_map = {} + for duplicate in duplicate_data: + uuid = new_nodes_map[duplicate["name"]].uuid + uuid_value = node_map[duplicate["duplicate_of"]].uuid + uuid_map[uuid] = uuid_value - nodes = [] - brand_new_nodes = [] + nodes = [] + brand_new_nodes = [] for node in extracted_nodes: if node.uuid in uuid_map: existing_uuid = uuid_map[node.uuid] # TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes, # can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please? 
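# Hedged test sketch for the coverage the TODO above requests; the stub and
# test names are invented, and the EntityNode kwargs mirror extract_nodes()
# in this file. Run it with asyncio.run(_test_duplicate_maps_to_existing()).
class _StubLLM:
    async def generate_response(self, messages):
        # Pretend the model matched the extracted "Paul P." to existing "Paul".
        return {"duplicates": [{"name": "Paul P.", "duplicate_of": "Paul"}]}

async def _test_duplicate_maps_to_existing():
    existing = EntityNode(name="Paul", labels=[], summary="", created_at=datetime.now())
    extracted = EntityNode(name="Paul P.", labels=[], summary="", created_at=datetime.now())
    nodes, uuid_map, brand_new = await dedupe_extracted_nodes(
        _StubLLM(), [extracted], [existing]
    )
    assert uuid_map == {extracted.uuid: existing.uuid}
    assert nodes == [existing] and brand_new == []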
# find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value) - existing_node = next( + existing_node = next( (v for k, v in node_map.items() if v.uuid == existing_uuid), None ) nodes.append(existing_node) continue brand_new_nodes.append(node) - nodes.append(node) + nodes.append(node) - return nodes, uuid_map, brand_new_nodes + return nodes, uuid_map, brand_new_nodes async def dedupe_node_list( - llm_client: LLMClient, - nodes: list[EntityNode], + llm_client: LLMClient, + nodes: list[EntityNode], ) -> tuple[list[EntityNode], dict[str, str]]: - start = time() + start = time() - # build node map - node_map = {} - for node in nodes: - node_map[node.name] = node + # build node map + node_map = {} + for node in nodes: + node_map[node.name] = node - # Prepare context for LLM - nodes_context = [{'name': node.name, 'summary': node.summary} for node in nodes] + # Prepare context for LLM + nodes_context = [{"name": node.name, "summary": node.summary} for node in nodes] - context = { - 'nodes': nodes_context, - } + context = { + "nodes": nodes_context, + } - llm_response = await llm_client.generate_response( - prompt_library.dedupe_nodes.node_list(context) - ) + llm_response = await llm_client.generate_response( + prompt_library.dedupe_nodes.node_list(context) + ) - nodes_data = llm_response.get('nodes', []) + nodes_data = llm_response.get("nodes", []) - end = time() - logger.info(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms') + end = time() + logger.info(f"Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms") - # Get full node data - unique_nodes = [] - uuid_map: dict[str, str] = {} - for node_data in nodes_data: - node = node_map[node_data['names'][0]] - unique_nodes.append(node) + # Get full node data + unique_nodes = [] + uuid_map: dict[str, str] = {} + for node_data in nodes_data: + node = node_map[node_data["names"][0]] + unique_nodes.append(node) - for name in node_data['names'][1:]: - uuid = node_map[name].uuid - uuid_value = node_map[node_data['names'][0]].uuid - uuid_map[uuid] = uuid_value + for name in node_data["names"][1:]: + uuid = node_map[name].uuid + uuid_value = node_map[node_data["names"][0]].uuid + uuid_map[uuid] = uuid_value - return unique_nodes, uuid_map + return unique_nodes, uuid_map diff --git a/core/utils/maintenance/temporal_operations.py b/core/utils/maintenance/temporal_operations.py index a89cb0b8..87b2fd40 100644 --- a/core/utils/maintenance/temporal_operations.py +++ b/core/utils/maintenance/temporal_operations.py @@ -31,38 +31,42 @@ def extract_node_edge_node_triplet( def prepare_edges_for_invalidation( - existing_edges: list[EntityEdge], - new_edges: list[EntityEdge], - nodes: list[EntityNode], + existing_edges: list[EntityEdge], + new_edges: list[EntityEdge], + nodes: list[EntityNode], ) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]: - existing_edges_pending_invalidation = [] # TODO: this is not yet used? - new_edges_with_nodes = [] # TODO: this is not yet used? - - existing_edges_pending_invalidation = [] - new_edges_with_nodes = [] - - for edge_list, result_list in [ - (existing_edges, existing_edges_pending_invalidation), - (new_edges, new_edges_with_nodes), - ]: - for edge in edge_list: - source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None) - target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None) + existing_edges_pending_invalidation = [] # TODO: this is not yet used? 
+ new_edges_with_nodes = [] # TODO: this is not yet used? + + existing_edges_pending_invalidation = [] + new_edges_with_nodes = [] + + for edge_list, result_list in [ + (existing_edges, existing_edges_pending_invalidation), + (new_edges, new_edges_with_nodes), + ]: + for edge in edge_list: + source_node = next( + (node for node in nodes if node.uuid == edge.source_node_uuid), None + ) + target_node = next( + (node for node in nodes if node.uuid == edge.target_node_uuid), None + ) - if source_node and target_node: - result_list.append((source_node, edge, target_node)) + if source_node and target_node: + result_list.append((source_node, edge, target_node)) - return existing_edges_pending_invalidation, new_edges_with_nodes + return existing_edges_pending_invalidation, new_edges_with_nodes async def invalidate_edges( - llm_client: LLMClient, - existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet], - new_edges: list[NodeEdgeNodeTriplet], + llm_client: LLMClient, + existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet], + new_edges: list[NodeEdgeNodeTriplet], current_episode: EpisodicNode, previous_episodes: list[EpisodicNode], ) -> list[EntityEdge]: - invalidated_edges = [] # TODO: this is not yet used? + invalidated_edges = [] # TODO: this is not yet used? context = prepare_invalidation_context( existing_edges_pending_invalidation, @@ -76,29 +80,29 @@ async def invalidate_edges( ) logger.info(f"invalidate_edges LLM response: {llm_response}") - edges_to_invalidate = llm_response.get('invalidated_edges', []) - invalidated_edges = process_edge_invalidation_llm_response( - edges_to_invalidate, existing_edges_pending_invalidation - ) + edges_to_invalidate = llm_response.get("invalidated_edges", []) + invalidated_edges = process_edge_invalidation_llm_response( + edges_to_invalidate, existing_edges_pending_invalidation + ) - return invalidated_edges + return invalidated_edges def prepare_invalidation_context( - existing_edges: list[NodeEdgeNodeTriplet], + existing_edges: list[NodeEdgeNodeTriplet], new_edges: list[NodeEdgeNodeTriplet], current_episode: EpisodicNode, previous_episodes: list[EpisodicNode], ) -> dict: - return { - 'existing_edges': [ - f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})' - for source_node, edge, target_node in sorted( - existing_edges, key=lambda x: x[1].created_at, reverse=True - ) - ], - 'new_edges': [ - f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})' + return { + "existing_edges": [ + f"{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})" + for source_node, edge, target_node in sorted( + existing_edges, key=lambda x: x[1].created_at, reverse=True + ) + ], + "new_edges": [ + f"{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})" for source_node, edge, target_node in sorted( new_edges, key=lambda x: x[1].created_at, reverse=True ) @@ -109,20 +113,20 @@ def prepare_invalidation_context( def process_edge_invalidation_llm_response( - edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet] + edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet] ) -> List[EntityEdge]: - invalidated_edges = [] - for edge_to_invalidate in edges_to_invalidate: - edge_uuid = edge_to_invalidate['edge_uuid'] - edge_to_update = next( - (edge for _, edge, _ in 
existing_edges if edge.uuid == edge_uuid), - None, - ) - if edge_to_update: - edge_to_update.expired_at = datetime.now() - edge_to_update.fact = edge_to_invalidate["fact"] + invalidated_edges = [] + for edge_to_invalidate in edges_to_invalidate: + edge_uuid = edge_to_invalidate["edge_uuid"] + edge_to_update = next( + (edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid), + None, + ) + if edge_to_update: + edge_to_update.expired_at = datetime.now() + edge_to_update.fact = edge_to_invalidate["fact"] invalidated_edges.append(edge_to_update) logger.info( f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {edge_to_invalidate['fact']}" - ) - return invalidated_edges + ) + return invalidated_edges diff --git a/core/utils/search/search_utils.py b/core/utils/search/search_utils.py index 15cc5239..dcd36281 100644 --- a/core/utils/search/search_utils.py +++ b/core/utils/search/search_utils.py @@ -3,7 +3,8 @@ from datetime import datetime from time import time -from neo4j import AsyncDriver, time as neo4j_time +from neo4j import AsyncDriver +from neo4j import time as neo4j_time from core.edges import EntityEdge from core.nodes import EntityNode @@ -42,7 +43,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver): for record in records: n_uuid = record["source_node_uuid"] - if n_uuid in context.keys(): + if n_uuid in context: context[n_uuid]["facts"].append(record["fact"]) else: context[n_uuid] = { diff --git a/runner.py b/runner.py index 63804013..5de65d0e 100644 --- a/runner.py +++ b/runner.py @@ -11,36 +11,38 @@ load_dotenv() -neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687' -neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j' -neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password' +neo4j_uri = os.environ.get("NEO4J_URI") or "bolt://localhost:7687" +neo4j_user = os.environ.get("NEO4J_USER") or "neo4j" +neo4j_password = os.environ.get("NEO4J_PASSWORD") or "password" def setup_logging(): - # Create a logger - logger = logging.getLogger() - logger.setLevel(logging.INFO) # Set the logging level to INFO + # Create a logger + logger = logging.getLogger() + logger.setLevel(logging.INFO) # Set the logging level to INFO - # Create console handler and set level to INFO - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(logging.INFO) + # Create console handler and set level to INFO + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) - # Create formatter - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + # Create formatter + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) - # Add formatter to console handler - console_handler.setFormatter(formatter) + # Add formatter to console handler + console_handler.setFormatter(formatter) - # Add console handler to logger - logger.addHandler(console_handler) + # Add console handler to logger + logger.addHandler(console_handler) - return logger + return logger async def main(): - setup_logging() - client = Graphiti(neo4j_uri, neo4j_user, neo4j_password) - await clear_data(client.driver) + setup_logging() + client = Graphiti(neo4j_uri, neo4j_user, neo4j_password) + await clear_data(client.driver) # await client.build_indices() await client.add_episode( From 03ad8cf045f88410b48273f832c5445b73d9475d Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Thu, 22 Aug 2024 14:52:04 -0700 Subject: [PATCH 
05/10] fix formatting --- core/prompts/dedupe_edges.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/core/prompts/dedupe_edges.py b/core/prompts/dedupe_edges.py index 0c750dce..b3b1069a 100644 --- a/core/prompts/dedupe_edges.py +++ b/core/prompts/dedupe_edges.py @@ -6,13 +6,13 @@ class Prompt(Protocol): v1: PromptVersion - v2: PromptVersion + v2: PromptVersion class Versions(TypedDict): v1: PromptFunction v2: PromptFunction - edge_list: PromptFunction + edge_list: PromptFunction def v1(context: dict[str, any]) -> list[Message]: @@ -56,14 +56,14 @@ def v1(context: dict[str, any]) -> list[Message]: def v2(context: dict[str, any]) -> list[Message]: - return [ - Message( - role="system", - content="You are a helpful assistant that de-duplicates relationship from edge lists.", - ), - Message( - role="user", - content=f""" + return [ + Message( + role='system', + content='You are a helpful assistant that de-duplicates relationship from edge lists.', + ), + Message( + role='user', + content=f""" Given the following context, deduplicate edges from a list of new edges given a list of existing edges: Existing Edges: @@ -93,8 +93,8 @@ def v2(context: dict[str, any]) -> list[Message]: ] }} """, - ), - ] + ), + ] def edge_list(context: dict[str, any]) -> list[Message]: @@ -133,4 +133,4 @@ def edge_list(context: dict[str, any]) -> list[Message]: ] -versions: Versions = {"v1": v1, "v2": v2, "edge_list": edge_list} +versions: Versions = {'v1': v1, 'v2': v2, 'edge_list': edge_list} From d561ded74553f3febd8254dfd350f5e77f6872c0 Mon Sep 17 00:00:00 2001 From: paulpaliychuk Date: Thu, 22 Aug 2024 17:52:43 -0400 Subject: [PATCH 06/10] chore: fix ruff --- core/graphiti.py | 598 +++++++++--------- core/prompts/dedupe_edges.py | 26 +- core/utils/maintenance/edge_operations.py | 76 ++- core/utils/maintenance/node_operations.py | 345 +++++----- core/utils/maintenance/temporal_operations.py | 168 +++-- core/utils/search/search_utils.py | 312 +++++---- runner.py | 114 ++-- 7 files changed, 789 insertions(+), 850 deletions(-) diff --git a/core/graphiti.py b/core/graphiti.py index 28cb92cf..d52ef5b7 100644 --- a/core/graphiti.py +++ b/core/graphiti.py @@ -13,33 +13,33 @@ from core.nodes import EntityNode, EpisodicNode from core.search.search import SearchConfig, hybrid_search from core.search.search_utils import ( - get_relevant_edges, - get_relevant_nodes, + get_relevant_edges, + get_relevant_nodes, ) from core.utils import ( - build_episodic_edges, - retrieve_episodes, + build_episodic_edges, + retrieve_episodes, ) from core.utils.bulk_utils import ( - BulkEpisode, - dedupe_edges_bulk, - dedupe_nodes_bulk, - extract_nodes_and_edges_bulk, - resolve_edge_pointers, - retrieve_previous_episodes_bulk, + BulkEpisode, + dedupe_edges_bulk, + dedupe_nodes_bulk, + extract_nodes_and_edges_bulk, + resolve_edge_pointers, + retrieve_previous_episodes_bulk, ) from core.utils.maintenance.edge_operations import ( - dedupe_extracted_edges, - extract_edges, + dedupe_extracted_edges, + extract_edges, ) from core.utils.maintenance.graph_data_operations import ( - EPISODE_WINDOW_LEN, - build_indices_and_constraints, + EPISODE_WINDOW_LEN, + build_indices_and_constraints, ) from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes from core.utils.maintenance.temporal_operations import ( - invalidate_edges, - prepare_edges_for_invalidation, + invalidate_edges, + prepare_edges_for_invalidation, ) logger = logging.getLogger(__name__) @@ -48,300 +48,272 @@ class 
Graphiti: - def __init__( - self, uri: str, user: str, password: str, llm_client: LLMClient | None = None - ): - self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password)) - self.database = "neo4j" - if llm_client: - self.llm_client = llm_client - else: - self.llm_client = OpenAIClient( - LLMConfig( - api_key=os.getenv("OPENAI_API_KEY"), - model="gpt-4o-mini", - base_url="https://api.openai.com/v1", - ) - ) - - def close(self): - self.driver.close() - - async def build_indices_and_constraints(self): - await build_indices_and_constraints(self.driver) - - async def retrieve_episodes( - self, - reference_time: datetime, - last_n: int = EPISODE_WINDOW_LEN, - sources: list[str] | None = "messages", - ) -> list[EpisodicNode]: - """Retrieve the last n episodic nodes from the graph""" - return await retrieve_episodes(self.driver, reference_time, last_n, sources) - - # Invalidate edges that are no longer valid - async def invalidate_edges( - self, - episode: EpisodicNode, - new_nodes: list[EntityNode], - new_edges: list[EntityEdge], - relevant_schema: dict[str, any], - previous_episodes: list[EpisodicNode], - ): ... - - async def add_episode( - self, - name: str, - episode_body: str, - source_description: str, - reference_time: datetime | None = None, - episode_type: str | None = "string", # TODO: this field isn't used yet? - success_callback: Callable | None = None, - error_callback: Callable | None = None, - ): - """Process an episode and update the graph""" - try: - start = time() - - nodes: list[EntityNode] = [] - entity_edges: list[EntityEdge] = [] - episodic_edges: list[EpisodicEdge] = [] - embedder = self.llm_client.client.embeddings - now = datetime.now() - - previous_episodes = await self.retrieve_episodes(reference_time) - episode = EpisodicNode( - name=name, - labels=[], - source="messages", - content=episode_body, - source_description=source_description, - created_at=now, - valid_at=reference_time, - ) - - extracted_nodes = await extract_nodes( - self.llm_client, episode, previous_episodes - ) - logger.info( - f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}" - ) - - # Calculate Embeddings - - await asyncio.gather( - *[node.generate_name_embedding(embedder) for node in extracted_nodes] - ) - existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver) - logger.info( - f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}" - ) - touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes( - self.llm_client, extracted_nodes, existing_nodes - ) - logger.info( - f"Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}" - ) - nodes.extend(touched_nodes) - - extracted_edges = await extract_edges( - self.llm_client, episode, touched_nodes, previous_episodes - ) - - await asyncio.gather( - *[edge.generate_embedding(embedder) for edge in extracted_edges] - ) - - existing_edges = await get_relevant_edges(extracted_edges, self.driver) - logger.info(f"Existing edges: {[(e.name, e.uuid) for e in existing_edges]}") - logger.info( - f"Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}" - ) - - # deduped_edges = await dedupe_extracted_edges_v2( - # self.llm_client, - # extract_node_and_edge_triplets(extracted_edges, nodes), - # extract_node_and_edge_triplets(existing_edges, nodes), - # ) - - deduped_edges = await dedupe_extracted_edges( - self.llm_client, - extracted_edges, - existing_edges, - ) - - edge_touched_node_uuids = [n.uuid for n in brand_new_nodes] - for edge in deduped_edges: - 
edge_touched_node_uuids.append(edge.source_node_uuid) - edge_touched_node_uuids.append(edge.target_node_uuid) - - ( - old_edges_with_nodes_pending_invalidation, - new_edges_with_nodes, - ) = prepare_edges_for_invalidation( - existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes - ) - - invalidated_edges = await invalidate_edges( - self.llm_client, - old_edges_with_nodes_pending_invalidation, - new_edges_with_nodes, - episode, - previous_episodes, - ) - - for edge in invalidated_edges: - edge_touched_node_uuids.append(edge.source_node_uuid) - edge_touched_node_uuids.append(edge.target_node_uuid) - - entity_edges.extend(invalidated_edges) - - edge_touched_node_uuids = list(set(edge_touched_node_uuids)) - involved_nodes = [ - node for node in nodes if node.uuid in edge_touched_node_uuids - ] - - logger.info( - f"Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}" - ) - - logger.info( - f"Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}" - ) - - logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}") - entity_edges.extend(deduped_edges) - - episodic_edges.extend( - build_episodic_edges( - # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them - involved_nodes, - episode, - now, - ) - ) - # Important to append the episode to the nodes at the end so that self referencing episodic edges are not built - logger.info(f"Built episodic edges: {episodic_edges}") - - # invalidated_edges = await self.invalidate_edges( - # episode, new_nodes, new_edges, relevant_schema, previous_episodes - # ) - - # edges.extend(invalidated_edges) - - # Future optimization would be using batch operations to save nodes and edges - await episode.save(self.driver) - await asyncio.gather(*[node.save(self.driver) for node in nodes]) - await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges]) - await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges]) - - end = time() - logger.info(f"Completed add_episode in {(end-start) * 1000} ms") - # for node in nodes: - # if isinstance(node, EntityNode): - # await node.update_summary(self.driver) - if success_callback: - await success_callback(episode) - except Exception as e: - if error_callback: - await error_callback(episode, e) - else: - raise e - - async def add_episode_bulk( - self, - bulk_episodes: list[BulkEpisode], - ): - try: - start = time() - embedder = self.llm_client.client.embeddings - now = datetime.now() - - episodes = [ - EpisodicNode( - name=episode.name, - labels=[], - source="messages", - content=episode.content, - source_description=episode.source_description, - created_at=now, - valid_at=episode.reference_time, - ) - for episode in bulk_episodes - ] - - # Save all the episodes - await asyncio.gather(*[episode.save(self.driver) for episode in episodes]) - - # Get previous episode context for each episode - episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes) - - # Extract all nodes and edges - ( - extracted_nodes, - extracted_edges, - episodic_edges, - ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs) - - # Generate embeddings - await asyncio.gather( - *[node.generate_name_embedding(embedder) for node in extracted_nodes], - *[edge.generate_embedding(embedder) for edge in extracted_edges], - ) - - # Dedupe extracted nodes - nodes, uuid_map = await dedupe_nodes_bulk( - self.driver, self.llm_client, extracted_nodes - ) - - # save nodes to KG - await asyncio.gather(*[node.save(self.driver) for 
node in nodes]) - - # re-map edge pointers so that they don't point to discard dupe nodes - extracted_edges: list[EntityEdge] = resolve_edge_pointers( - extracted_edges, uuid_map - ) - episodic_edges: list[EpisodicEdge] = resolve_edge_pointers( - episodic_edges, uuid_map - ) - - # save episodic edges to KG - await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges]) - - # Dedupe extracted edges - edges = await dedupe_edges_bulk( - self.driver, self.llm_client, extracted_edges - ) - logger.info(f"extracted edge length: {len(edges)}") - - # invalidate edges - - # save edges to KG - await asyncio.gather(*[edge.save(self.driver) for edge in edges]) - - end = time() - logger.info(f"Completed add_episode_bulk in {(end-start) * 1000} ms") - - except Exception as e: - raise e - - async def search(self, query: str, num_results=10): - search_config = SearchConfig(num_episodes=0, num_results=num_results) - edges = ( - await hybrid_search( - self.driver, - self.llm_client.client.embeddings, - query, - datetime.now(), - search_config, - ) - )["edges"] - - facts = [edge.fact for edge in edges] - - return facts - - async def _search(self, query: str, timestamp: datetime, config: SearchConfig): - return await hybrid_search( - self.driver, self.llm_client.client.embeddings, query, timestamp, config - ) + def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None): + self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password)) + self.database = 'neo4j' + if llm_client: + self.llm_client = llm_client + else: + self.llm_client = OpenAIClient( + LLMConfig( + api_key=os.getenv('OPENAI_API_KEY'), + model='gpt-4o-mini', + base_url='https://api.openai.com/v1', + ) + ) + + def close(self): + self.driver.close() + + async def build_indices_and_constraints(self): + await build_indices_and_constraints(self.driver) + + async def retrieve_episodes( + self, + reference_time: datetime, + last_n: int = EPISODE_WINDOW_LEN, + sources: list[str] | None = 'messages', + ) -> list[EpisodicNode]: + """Retrieve the last n episodic nodes from the graph""" + return await retrieve_episodes(self.driver, reference_time, last_n, sources) + + # Invalidate edges that are no longer valid + async def invalidate_edges( + self, + episode: EpisodicNode, + new_nodes: list[EntityNode], + new_edges: list[EntityEdge], + relevant_schema: dict[str, any], + previous_episodes: list[EpisodicNode], + ): ... + + async def add_episode( + self, + name: str, + episode_body: str, + source_description: str, + reference_time: datetime | None = None, + episode_type: str | None = 'string', # TODO: this field isn't used yet? 
+ success_callback: Callable | None = None, + error_callback: Callable | None = None, + ): + """Process an episode and update the graph""" + try: + start = time() + + nodes: list[EntityNode] = [] + entity_edges: list[EntityEdge] = [] + episodic_edges: list[EpisodicEdge] = [] + embedder = self.llm_client.client.embeddings + now = datetime.now() + + previous_episodes = await self.retrieve_episodes(reference_time) + episode = EpisodicNode( + name=name, + labels=[], + source='messages', + content=episode_body, + source_description=source_description, + created_at=now, + valid_at=reference_time, + ) + + extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes) + logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') + + # Calculate Embeddings + + await asyncio.gather( + *[node.generate_name_embedding(embedder) for node in extracted_nodes] + ) + existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver) + logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') + touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes( + self.llm_client, extracted_nodes, existing_nodes + ) + logger.info(f'Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}') + nodes.extend(touched_nodes) + + extracted_edges = await extract_edges( + self.llm_client, episode, touched_nodes, previous_episodes + ) + + await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges]) + + existing_edges = await get_relevant_edges(extracted_edges, self.driver) + logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}') + logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}') + + # deduped_edges = await dedupe_extracted_edges_v2( + # self.llm_client, + # extract_node_and_edge_triplets(extracted_edges, nodes), + # extract_node_and_edge_triplets(existing_edges, nodes), + # ) + + deduped_edges = await dedupe_extracted_edges( + self.llm_client, + extracted_edges, + existing_edges, + ) + + edge_touched_node_uuids = [n.uuid for n in brand_new_nodes] + for edge in deduped_edges: + edge_touched_node_uuids.append(edge.source_node_uuid) + edge_touched_node_uuids.append(edge.target_node_uuid) + + ( + old_edges_with_nodes_pending_invalidation, + new_edges_with_nodes, + ) = prepare_edges_for_invalidation( + existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes + ) + + invalidated_edges = await invalidate_edges( + self.llm_client, + old_edges_with_nodes_pending_invalidation, + new_edges_with_nodes, + episode, + previous_episodes, + ) + + for edge in invalidated_edges: + edge_touched_node_uuids.append(edge.source_node_uuid) + edge_touched_node_uuids.append(edge.target_node_uuid) + + entity_edges.extend(invalidated_edges) + + edge_touched_node_uuids = list(set(edge_touched_node_uuids)) + involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids] + + logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}') + + logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}') + + logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}') + entity_edges.extend(deduped_edges) + + episodic_edges.extend( + build_episodic_edges( + # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them + involved_nodes, + episode, + now, + ) + ) + # Important to append the episode to the nodes at the end so that self referencing episodic edges are not built + logger.info(f'Built 
episodic edges: {episodic_edges}') + + # invalidated_edges = await self.invalidate_edges( + # episode, new_nodes, new_edges, relevant_schema, previous_episodes + # ) + + # edges.extend(invalidated_edges) + + # Future optimization would be using batch operations to save nodes and edges + await episode.save(self.driver) + await asyncio.gather(*[node.save(self.driver) for node in nodes]) + await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges]) + await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges]) + + end = time() + logger.info(f'Completed add_episode in {(end-start) * 1000} ms') + # for node in nodes: + # if isinstance(node, EntityNode): + # await node.update_summary(self.driver) + if success_callback: + await success_callback(episode) + except Exception as e: + if error_callback: + await error_callback(episode, e) + else: + raise e + + async def add_episode_bulk( + self, + bulk_episodes: list[BulkEpisode], + ): + try: + start = time() + embedder = self.llm_client.client.embeddings + now = datetime.now() + + episodes = [ + EpisodicNode( + name=episode.name, + labels=[], + source='messages', + content=episode.content, + source_description=episode.source_description, + created_at=now, + valid_at=episode.reference_time, + ) + for episode in bulk_episodes + ] + + # Save all the episodes + await asyncio.gather(*[episode.save(self.driver) for episode in episodes]) + + # Get previous episode context for each episode + episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes) + + # Extract all nodes and edges + ( + extracted_nodes, + extracted_edges, + episodic_edges, + ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs) + + # Generate embeddings + await asyncio.gather( + *[node.generate_name_embedding(embedder) for node in extracted_nodes], + *[edge.generate_embedding(embedder) for edge in extracted_edges], + ) + + # Dedupe extracted nodes + nodes, uuid_map = await dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes) + + # save nodes to KG + await asyncio.gather(*[node.save(self.driver) for node in nodes]) + + # re-map edge pointers so that they don't point to discard dupe nodes + extracted_edges: list[EntityEdge] = resolve_edge_pointers(extracted_edges, uuid_map) + episodic_edges: list[EpisodicEdge] = resolve_edge_pointers(episodic_edges, uuid_map) + + # save episodic edges to KG + await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges]) + + # Dedupe extracted edges + edges = await dedupe_edges_bulk(self.driver, self.llm_client, extracted_edges) + logger.info(f'extracted edge length: {len(edges)}') + + # invalidate edges + + # save edges to KG + await asyncio.gather(*[edge.save(self.driver) for edge in edges]) + + end = time() + logger.info(f'Completed add_episode_bulk in {(end-start) * 1000} ms') + + except Exception as e: + raise e + + async def search(self, query: str, num_results=10): + search_config = SearchConfig(num_episodes=0, num_results=num_results) + edges = ( + await hybrid_search( + self.driver, + self.llm_client.client.embeddings, + query, + datetime.now(), + search_config, + ) + )['edges'] + + facts = [edge.fact for edge in edges] + + return facts + + async def _search(self, query: str, timestamp: datetime, config: SearchConfig): + return await hybrid_search( + self.driver, self.llm_client.client.embeddings, query, timestamp, config + ) diff --git a/core/prompts/dedupe_edges.py b/core/prompts/dedupe_edges.py index 0c750dce..b3b1069a 100644 --- a/core/prompts/dedupe_edges.py 
+++ b/core/prompts/dedupe_edges.py @@ -6,13 +6,13 @@ class Prompt(Protocol): v1: PromptVersion - v2: PromptVersion + v2: PromptVersion class Versions(TypedDict): v1: PromptFunction v2: PromptFunction - edge_list: PromptFunction + edge_list: PromptFunction def v1(context: dict[str, any]) -> list[Message]: @@ -56,14 +56,14 @@ def v1(context: dict[str, any]) -> list[Message]: def v2(context: dict[str, any]) -> list[Message]: - return [ - Message( - role="system", - content="You are a helpful assistant that de-duplicates relationship from edge lists.", - ), - Message( - role="user", - content=f""" + return [ + Message( + role='system', + content='You are a helpful assistant that de-duplicates relationship from edge lists.', + ), + Message( + role='user', + content=f""" Given the following context, deduplicate edges from a list of new edges given a list of existing edges: Existing Edges: @@ -93,8 +93,8 @@ def v2(context: dict[str, any]) -> list[Message]: ] }} """, - ), - ] + ), + ] def edge_list(context: dict[str, any]) -> list[Message]: @@ -133,4 +133,4 @@ def edge_list(context: dict[str, any]) -> list[Message]: ] -versions: Versions = {"v1": v1, "v2": v2, "edge_list": edge_list} +versions: Versions = {'v1': v1, 'v2': v2, 'edge_list': edge_list} diff --git a/core/utils/maintenance/edge_operations.py b/core/utils/maintenance/edge_operations.py index dede92d4..16e8c431 100644 --- a/core/utils/maintenance/edge_operations.py +++ b/core/utils/maintenance/edge_operations.py @@ -181,50 +181,48 @@ async def extract_edges( def create_edge_identifier( - source_node: EntityNode, edge: EntityEdge, target_node: EntityNode + source_node: EntityNode, edge: EntityEdge, target_node: EntityNode ) -> str: - return f"{source_node.name}-{edge.name}-{target_node.name}" + return f'{source_node.name}-{edge.name}-{target_node.name}' async def dedupe_extracted_edges_v2( - llm_client: LLMClient, - extracted_edges: list[NodeEdgeNodeTriplet], - existing_edges: list[NodeEdgeNodeTriplet], + llm_client: LLMClient, + extracted_edges: list[NodeEdgeNodeTriplet], + existing_edges: list[NodeEdgeNodeTriplet], ) -> list[NodeEdgeNodeTriplet]: - # Create edge map - edge_map = {} - for n1, edge, n2 in existing_edges: - edge_map[create_edge_identifier(n1, edge, n2)] = edge - for n1, edge, n2 in extracted_edges: - if create_edge_identifier(n1, edge, n2) in edge_map: - continue - edge_map[create_edge_identifier(n1, edge, n2)] = edge - - # Prepare context for LLM - context = { - "extracted_edges": [ - {"triplet": create_edge_identifier(n1, edge, n2), "fact": edge.fact} - for n1, edge, n2 in extracted_edges - ], - "existing_edges": [ - {"triplet": create_edge_identifier(n1, edge, n2), "fact": edge.fact} - for n1, edge, n2 in extracted_edges - ], - } - logger.info(prompt_library.dedupe_edges.v2(context)) - llm_response = await llm_client.generate_response( - prompt_library.dedupe_edges.v2(context) - ) - new_edges_data = llm_response.get("new_edges", []) - logger.info(f"Extracted new edges: {new_edges_data}") - - # Get full edge data - edges = [] - for edge_data in new_edges_data: - edge = edge_map[edge_data["triplet"]] - edges.append(edge) - - return edges + # Create edge map + edge_map = {} + for n1, edge, n2 in existing_edges: + edge_map[create_edge_identifier(n1, edge, n2)] = edge + for n1, edge, n2 in extracted_edges: + if create_edge_identifier(n1, edge, n2) in edge_map: + continue + edge_map[create_edge_identifier(n1, edge, n2)] = edge + + # Prepare context for LLM + context = { + 'extracted_edges': [ + {'triplet': 
create_edge_identifier(n1, edge, n2), 'fact': edge.fact} + for n1, edge, n2 in extracted_edges + ], + 'existing_edges': [ + {'triplet': create_edge_identifier(n1, edge, n2), 'fact': edge.fact} + for n1, edge, n2 in extracted_edges + ], + } + logger.info(prompt_library.dedupe_edges.v2(context)) + llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v2(context)) + new_edges_data = llm_response.get('new_edges', []) + logger.info(f'Extracted new edges: {new_edges_data}') + + # Get full edge data + edges = [] + for edge_data in new_edges_data: + edge = edge_map[edge_data['triplet']] + edges.append(edge) + + return edges async def dedupe_extracted_edges( diff --git a/core/utils/maintenance/node_operations.py b/core/utils/maintenance/node_operations.py index ab8c34f0..b9ecab53 100644 --- a/core/utils/maintenance/node_operations.py +++ b/core/utils/maintenance/node_operations.py @@ -10,205 +10,190 @@ async def extract_new_nodes( - llm_client: LLMClient, - episode: EpisodicNode, - relevant_schema: dict[str, any], - previous_episodes: list[EpisodicNode], + llm_client: LLMClient, + episode: EpisodicNode, + relevant_schema: dict[str, any], + previous_episodes: list[EpisodicNode], ) -> list[EntityNode]: - # Prepare context for LLM - existing_nodes = [ - {"name": node_name, "label": node_info["label"], "uuid": node_info["uuid"]} - for node_name, node_info in relevant_schema["nodes"].items() - ] - - context = { - "episode_content": episode.content, - "episode_timestamp": ( - episode.valid_at.isoformat() if episode.valid_at else None - ), - "existing_nodes": existing_nodes, - "previous_episodes": [ - { - "content": ep.content, - "timestamp": ep.valid_at.isoformat() if ep.valid_at else None, - } - for ep in previous_episodes - ], - } - - llm_response = await llm_client.generate_response( - prompt_library.extract_nodes.v1(context) - ) - new_nodes_data = llm_response.get("new_nodes", []) - logger.info(f"Extracted new nodes: {new_nodes_data}") - # Convert the extracted data into EntityNode objects - new_nodes = [] - for node_data in new_nodes_data: - # Check if the node already exists - if not any( - existing_node["name"] == node_data["name"] - for existing_node in existing_nodes - ): - new_node = EntityNode( - name=node_data["name"], - labels=node_data["labels"], - summary=node_data["summary"], - created_at=datetime.now(), - ) - new_nodes.append(new_node) - logger.info(f"Created new node: {new_node.name} (UUID: {new_node.uuid})") - else: - logger.info(f"Node {node_data['name']} already exists, skipping creation.") - - return new_nodes + # Prepare context for LLM + existing_nodes = [ + {'name': node_name, 'label': node_info['label'], 'uuid': node_info['uuid']} + for node_name, node_info in relevant_schema['nodes'].items() + ] + + context = { + 'episode_content': episode.content, + 'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None), + 'existing_nodes': existing_nodes, + 'previous_episodes': [ + { + 'content': ep.content, + 'timestamp': ep.valid_at.isoformat() if ep.valid_at else None, + } + for ep in previous_episodes + ], + } + + llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v1(context)) + new_nodes_data = llm_response.get('new_nodes', []) + logger.info(f'Extracted new nodes: {new_nodes_data}') + # Convert the extracted data into EntityNode objects + new_nodes = [] + for node_data in new_nodes_data: + # Check if the node already exists + if not any(existing_node['name'] == node_data['name'] for existing_node in 
existing_nodes): + new_node = EntityNode( + name=node_data['name'], + labels=node_data['labels'], + summary=node_data['summary'], + created_at=datetime.now(), + ) + new_nodes.append(new_node) + logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})') + else: + logger.info(f"Node {node_data['name']} already exists, skipping creation.") + + return new_nodes async def extract_nodes( - llm_client: LLMClient, - episode: EpisodicNode, - previous_episodes: list[EpisodicNode], + llm_client: LLMClient, + episode: EpisodicNode, + previous_episodes: list[EpisodicNode], ) -> list[EntityNode]: - start = time() - - # Prepare context for LLM - context = { - "episode_content": episode.content, - "episode_timestamp": ( - episode.valid_at.isoformat() if episode.valid_at else None - ), - "previous_episodes": [ - { - "content": ep.content, - "timestamp": ep.valid_at.isoformat() if ep.valid_at else None, - } - for ep in previous_episodes - ], - } - - llm_response = await llm_client.generate_response( - prompt_library.extract_nodes.v3(context) - ) - new_nodes_data = llm_response.get("new_nodes", []) - - end = time() - logger.info(f"Extracted new nodes: {new_nodes_data} in {(end - start) * 1000} ms") - # Convert the extracted data into EntityNode objects - new_nodes = [] - for node_data in new_nodes_data: - new_node = EntityNode( - name=node_data["name"], - labels=node_data["labels"], - summary=node_data["summary"], - created_at=datetime.now(), - ) - new_nodes.append(new_node) - logger.info(f"Created new node: {new_node.name} (UUID: {new_node.uuid})") - - return new_nodes + start = time() + + # Prepare context for LLM + context = { + 'episode_content': episode.content, + 'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None), + 'previous_episodes': [ + { + 'content': ep.content, + 'timestamp': ep.valid_at.isoformat() if ep.valid_at else None, + } + for ep in previous_episodes + ], + } + + llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v3(context)) + new_nodes_data = llm_response.get('new_nodes', []) + + end = time() + logger.info(f'Extracted new nodes: {new_nodes_data} in {(end - start) * 1000} ms') + # Convert the extracted data into EntityNode objects + new_nodes = [] + for node_data in new_nodes_data: + new_node = EntityNode( + name=node_data['name'], + labels=node_data['labels'], + summary=node_data['summary'], + created_at=datetime.now(), + ) + new_nodes.append(new_node) + logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})') + + return new_nodes async def dedupe_extracted_nodes( - llm_client: LLMClient, - extracted_nodes: list[EntityNode], - existing_nodes: list[EntityNode], + llm_client: LLMClient, + extracted_nodes: list[EntityNode], + existing_nodes: list[EntityNode], ) -> tuple[list[EntityNode], dict[str, str]]: - start = time() - - # build existing node map - node_map = {} - for node in existing_nodes: - node_map[node.name] = node - - # Temp hack - new_nodes_map = {} - for node in extracted_nodes: - new_nodes_map[node.name] = node - - # Prepare context for LLM - existing_nodes_context = [ - {"name": node.name, "summary": node.summary} for node in existing_nodes - ] - - extracted_nodes_context = [ - {"name": node.name, "summary": node.summary} for node in extracted_nodes - ] - - context = { - "existing_nodes": existing_nodes_context, - "extracted_nodes": extracted_nodes_context, - } - - llm_response = await llm_client.generate_response( - prompt_library.dedupe_nodes.v2(context) - ) - - duplicate_data = 
llm_response.get("duplicates", []) - - end = time() - logger.info(f"Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms") - - uuid_map = {} - for duplicate in duplicate_data: - uuid = new_nodes_map[duplicate["name"]].uuid - uuid_value = node_map[duplicate["duplicate_of"]].uuid - uuid_map[uuid] = uuid_value - - nodes = [] - brand_new_nodes = [] - for node in extracted_nodes: - if node.uuid in uuid_map: - existing_uuid = uuid_map[node.uuid] - # TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes, - # can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please? - # find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value) - existing_node = next( - (v for k, v in node_map.items() if v.uuid == existing_uuid), None - ) - nodes.append(existing_node) - continue - brand_new_nodes.append(node) - nodes.append(node) - - return nodes, uuid_map, brand_new_nodes + start = time() + + # build existing node map + node_map = {} + for node in existing_nodes: + node_map[node.name] = node + + # Temp hack + new_nodes_map = {} + for node in extracted_nodes: + new_nodes_map[node.name] = node + + # Prepare context for LLM + existing_nodes_context = [ + {'name': node.name, 'summary': node.summary} for node in existing_nodes + ] + + extracted_nodes_context = [ + {'name': node.name, 'summary': node.summary} for node in extracted_nodes + ] + + context = { + 'existing_nodes': existing_nodes_context, + 'extracted_nodes': extracted_nodes_context, + } + + llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.v2(context)) + + duplicate_data = llm_response.get('duplicates', []) + + end = time() + logger.info(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms') + + uuid_map = {} + for duplicate in duplicate_data: + uuid = new_nodes_map[duplicate['name']].uuid + uuid_value = node_map[duplicate['duplicate_of']].uuid + uuid_map[uuid] = uuid_value + + nodes = [] + brand_new_nodes = [] + for node in extracted_nodes: + if node.uuid in uuid_map: + existing_uuid = uuid_map[node.uuid] + # TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes, + # can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please? 
+ # find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value) + existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None) + nodes.append(existing_node) + continue + brand_new_nodes.append(node) + nodes.append(node) + + return nodes, uuid_map, brand_new_nodes async def dedupe_node_list( - llm_client: LLMClient, - nodes: list[EntityNode], + llm_client: LLMClient, + nodes: list[EntityNode], ) -> tuple[list[EntityNode], dict[str, str]]: - start = time() + start = time() - # build node map - node_map = {} - for node in nodes: - node_map[node.name] = node + # build node map + node_map = {} + for node in nodes: + node_map[node.name] = node - # Prepare context for LLM - nodes_context = [{"name": node.name, "summary": node.summary} for node in nodes] + # Prepare context for LLM + nodes_context = [{'name': node.name, 'summary': node.summary} for node in nodes] - context = { - "nodes": nodes_context, - } + context = { + 'nodes': nodes_context, + } - llm_response = await llm_client.generate_response( - prompt_library.dedupe_nodes.node_list(context) - ) + llm_response = await llm_client.generate_response( + prompt_library.dedupe_nodes.node_list(context) + ) - nodes_data = llm_response.get("nodes", []) + nodes_data = llm_response.get('nodes', []) - end = time() - logger.info(f"Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms") + end = time() + logger.info(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms') - # Get full node data - unique_nodes = [] - uuid_map: dict[str, str] = {} - for node_data in nodes_data: - node = node_map[node_data["names"][0]] - unique_nodes.append(node) + # Get full node data + unique_nodes = [] + uuid_map: dict[str, str] = {} + for node_data in nodes_data: + node = node_map[node_data['names'][0]] + unique_nodes.append(node) - for name in node_data["names"][1:]: - uuid = node_map[name].uuid - uuid_value = node_map[node_data["names"][0]].uuid - uuid_map[uuid] = uuid_value + for name in node_data['names'][1:]: + uuid = node_map[name].uuid + uuid_value = node_map[node_data['names'][0]].uuid + uuid_map[uuid] = uuid_value - return unique_nodes, uuid_map + return unique_nodes, uuid_map diff --git a/core/utils/maintenance/temporal_operations.py b/core/utils/maintenance/temporal_operations.py index 87b2fd40..37e2d7cb 100644 --- a/core/utils/maintenance/temporal_operations.py +++ b/core/utils/maintenance/temporal_operations.py @@ -13,120 +13,110 @@ def extract_node_and_edge_triplets( - edges: list[EntityEdge], nodes: list[EntityNode] + edges: list[EntityEdge], nodes: list[EntityNode] ) -> list[NodeEdgeNodeTriplet]: - return [extract_node_edge_node_triplet(edge, nodes) for edge in edges] + return [extract_node_edge_node_triplet(edge, nodes) for edge in edges] def extract_node_edge_node_triplet( - edge: EntityEdge, nodes: list[EntityNode] + edge: EntityEdge, nodes: list[EntityNode] ) -> NodeEdgeNodeTriplet: - source_node = next( - (node for node in nodes if node.uuid == edge.source_node_uuid), None - ) - target_node = next( - (node for node in nodes if node.uuid == edge.target_node_uuid), None - ) - return (source_node, edge, target_node) + source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None) + target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None) + return (source_node, edge, target_node) def prepare_edges_for_invalidation( - existing_edges: list[EntityEdge], - new_edges: list[EntityEdge], - nodes: 
list[EntityNode], + existing_edges: list[EntityEdge], + new_edges: list[EntityEdge], + nodes: list[EntityNode], ) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]: - existing_edges_pending_invalidation = [] # TODO: this is not yet used? - new_edges_with_nodes = [] # TODO: this is not yet used? + existing_edges_pending_invalidation = [] # TODO: this is not yet used? + new_edges_with_nodes = [] # TODO: this is not yet used? - existing_edges_pending_invalidation = [] - new_edges_with_nodes = [] + existing_edges_pending_invalidation = [] + new_edges_with_nodes = [] - for edge_list, result_list in [ - (existing_edges, existing_edges_pending_invalidation), - (new_edges, new_edges_with_nodes), - ]: - for edge in edge_list: - source_node = next( - (node for node in nodes if node.uuid == edge.source_node_uuid), None - ) - target_node = next( - (node for node in nodes if node.uuid == edge.target_node_uuid), None - ) + for edge_list, result_list in [ + (existing_edges, existing_edges_pending_invalidation), + (new_edges, new_edges_with_nodes), + ]: + for edge in edge_list: + source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None) + target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None) - if source_node and target_node: - result_list.append((source_node, edge, target_node)) + if source_node and target_node: + result_list.append((source_node, edge, target_node)) - return existing_edges_pending_invalidation, new_edges_with_nodes + return existing_edges_pending_invalidation, new_edges_with_nodes async def invalidate_edges( - llm_client: LLMClient, - existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet], - new_edges: list[NodeEdgeNodeTriplet], - current_episode: EpisodicNode, - previous_episodes: list[EpisodicNode], + llm_client: LLMClient, + existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet], + new_edges: list[NodeEdgeNodeTriplet], + current_episode: EpisodicNode, + previous_episodes: list[EpisodicNode], ) -> list[EntityEdge]: - invalidated_edges = [] # TODO: this is not yet used? + invalidated_edges = [] # TODO: this is not yet used? 
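
# Illustrative sketch (hypothetical caller): how prepare_edges_for_invalidation
# and invalidate_edges compose, mirroring the wiring in graphiti.py's
# add_episode. The argument names are assumptions standing in for local state.
async def run_invalidation_sketch(
    llm_client, existing_edges, deduped_edges, nodes, episode, previous_episodes
):
    # Pair each edge with its resolved source and target nodes so the prompt can
    # render "source - edge - target (Fact: ...)" lines for the LLM.
    pending_triplets, new_triplets = prepare_edges_for_invalidation(
        existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
    )
    # The LLM decides which existing edges the new information invalidates;
    # matched edges come back with expired_at set and their fact text updated.
    return await invalidate_edges(
        llm_client, pending_triplets, new_triplets, episode, previous_episodes
    )
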
- context = prepare_invalidation_context( - existing_edges_pending_invalidation, - new_edges, - current_episode, - previous_episodes, - ) - logger.info(prompt_library.invalidate_edges.v1(context)) - llm_response = await llm_client.generate_response( - prompt_library.invalidate_edges.v1(context) - ) - logger.info(f"invalidate_edges LLM response: {llm_response}") + context = prepare_invalidation_context( + existing_edges_pending_invalidation, + new_edges, + current_episode, + previous_episodes, + ) + logger.info(prompt_library.invalidate_edges.v1(context)) + llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context)) + logger.info(f'invalidate_edges LLM response: {llm_response}') - edges_to_invalidate = llm_response.get("invalidated_edges", []) - invalidated_edges = process_edge_invalidation_llm_response( - edges_to_invalidate, existing_edges_pending_invalidation - ) + edges_to_invalidate = llm_response.get('invalidated_edges', []) + invalidated_edges = process_edge_invalidation_llm_response( + edges_to_invalidate, existing_edges_pending_invalidation + ) - return invalidated_edges + return invalidated_edges def prepare_invalidation_context( - existing_edges: list[NodeEdgeNodeTriplet], - new_edges: list[NodeEdgeNodeTriplet], - current_episode: EpisodicNode, - previous_episodes: list[EpisodicNode], + existing_edges: list[NodeEdgeNodeTriplet], + new_edges: list[NodeEdgeNodeTriplet], + current_episode: EpisodicNode, + previous_episodes: list[EpisodicNode], ) -> dict: - return { - "existing_edges": [ - f"{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})" - for source_node, edge, target_node in sorted( - existing_edges, key=lambda x: x[1].created_at, reverse=True - ) - ], - "new_edges": [ - f"{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})" - for source_node, edge, target_node in sorted( - new_edges, key=lambda x: x[1].created_at, reverse=True - ) - ], - "current_episode": current_episode.content, - "previous_episodes": [episode.content for episode in previous_episodes], - } + return { + 'existing_edges': [ + f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})' + for source_node, edge, target_node in sorted( + existing_edges, key=lambda x: x[1].created_at, reverse=True + ) + ], + 'new_edges': [ + f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})' + for source_node, edge, target_node in sorted( + new_edges, key=lambda x: x[1].created_at, reverse=True + ) + ], + 'current_episode': current_episode.content, + 'previous_episodes': [episode.content for episode in previous_episodes], + } def process_edge_invalidation_llm_response( - edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet] + edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet] ) -> List[EntityEdge]: - invalidated_edges = [] - for edge_to_invalidate in edges_to_invalidate: - edge_uuid = edge_to_invalidate["edge_uuid"] - edge_to_update = next( - (edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid), - None, - ) - if edge_to_update: - edge_to_update.expired_at = datetime.now() - edge_to_update.fact = edge_to_invalidate["fact"] - invalidated_edges.append(edge_to_update) - logger.info( - f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). 
Updated Fact: {edge_to_invalidate['fact']}" - ) - return invalidated_edges + invalidated_edges = [] + for edge_to_invalidate in edges_to_invalidate: + edge_uuid = edge_to_invalidate['edge_uuid'] + edge_to_update = next( + (edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid), + None, + ) + if edge_to_update: + edge_to_update.expired_at = datetime.now() + edge_to_update.fact = edge_to_invalidate['fact'] + invalidated_edges.append(edge_to_update) + logger.info( + f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {edge_to_invalidate['fact']}" + ) + return invalidated_edges diff --git a/core/utils/search/search_utils.py b/core/utils/search/search_utils.py index dcd36281..e34a3314 100644 --- a/core/utils/search/search_utils.py +++ b/core/utils/search/search_utils.py @@ -15,8 +15,8 @@ async def bfs(node_ids: list[str], driver: AsyncDriver): - records, _, _ = await driver.execute_query( - """ + records, _, _ = await driver.execute_query( + """ MATCH (n WHERE n.uuid in $node_ids)-[r]->(m) RETURN n.uuid AS source_node_uuid, @@ -36,39 +36,39 @@ async def bfs(node_ids: list[str], driver: AsyncDriver): r.invalid_at AS invalid_at """, - node_ids=node_ids, - ) - - context = {} - - for record in records: - n_uuid = record["source_node_uuid"] - if n_uuid in context: - context[n_uuid]["facts"].append(record["fact"]) - else: - context[n_uuid] = { - "name": record["source_name"], - "summary": record["source_summary"], - "facts": [record["fact"]], - } - - m_uuid = record["target_node_uuid"] - if m_uuid not in context: - context[m_uuid] = { - "name": record["target_name"], - "summary": record["target_summary"], - "facts": [], - } - logger.info(f"bfs search returned context: {context}") - return context + node_ids=node_ids, + ) + + context = {} + + for record in records: + n_uuid = record['source_node_uuid'] + if n_uuid in context: + context[n_uuid]['facts'].append(record['fact']) + else: + context[n_uuid] = { + 'name': record['source_name'], + 'summary': record['source_summary'], + 'facts': [record['fact']], + } + + m_uuid = record['target_node_uuid'] + if m_uuid not in context: + context[m_uuid] = { + 'name': record['target_name'], + 'summary': record['target_summary'], + 'facts': [], + } + logger.info(f'bfs search returned context: {context}') + return context async def edge_similarity_search( - search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityEdge]: - # vector similarity search over embedded facts - records, _, _ = await driver.execute_query( - """ + # vector similarity search over embedded facts + records, _, _ = await driver.execute_query( + """ CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector) YIELD relationship AS r, score MATCH (n)-[r:RELATES_TO]->(m) @@ -86,38 +86,38 @@ async def edge_similarity_search( r.invalid_at AS invalid_at ORDER BY score DESC LIMIT $limit """, - search_vector=search_vector, - limit=limit, - ) + search_vector=search_vector, + limit=limit, + ) - edges: list[EntityEdge] = [] + edges: list[EntityEdge] = [] - for record in records: - edge = EntityEdge( - uuid=record["uuid"], - source_node_uuid=record["source_node_uuid"], - target_node_uuid=record["target_node_uuid"], - fact=record["fact"], - name=record["name"], - episodes=record["episodes"], - fact_embedding=record["fact_embedding"], - created_at=safely_parse_db_date(record["created_at"]), - 
expired_at=safely_parse_db_date(record["expired_at"]), - valid_at=safely_parse_db_date(record["valid_at"]), - invalid_At=safely_parse_db_date(record["invalid_at"]), - ) + for record in records: + edge = EntityEdge( + uuid=record['uuid'], + source_node_uuid=record['source_node_uuid'], + target_node_uuid=record['target_node_uuid'], + fact=record['fact'], + name=record['name'], + episodes=record['episodes'], + fact_embedding=record['fact_embedding'], + created_at=safely_parse_db_date(record['created_at']), + expired_at=safely_parse_db_date(record['expired_at']), + valid_at=safely_parse_db_date(record['valid_at']), + invalid_At=safely_parse_db_date(record['invalid_at']), + ) - edges.append(edge) + edges.append(edge) - return edges + return edges async def entity_similarity_search( - search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityNode]: - # vector similarity search over entity names - records, _, _ = await driver.execute_query( - """ + # vector similarity search over entity names + records, _, _ = await driver.execute_query( + """ CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector) YIELD node AS n, score RETURN @@ -127,32 +127,32 @@ async def entity_similarity_search( n.summary AS summary ORDER BY score DESC """, - search_vector=search_vector, - limit=limit, - ) - nodes: list[EntityNode] = [] + search_vector=search_vector, + limit=limit, + ) + nodes: list[EntityNode] = [] - for record in records: - nodes.append( - EntityNode( - uuid=record["uuid"], - name=record["name"], - labels=[], - created_at=safely_parse_db_date(record["created_at"]), - summary=record["summary"], - ) - ) + for record in records: + nodes.append( + EntityNode( + uuid=record['uuid'], + name=record['name'], + labels=[], + created_at=safely_parse_db_date(record['created_at']), + summary=record['summary'], + ) + ) - return nodes + return nodes async def entity_fulltext_search( - query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityNode]: - # BM25 search to get top nodes - fuzzy_query = query + "~" - records, _, _ = await driver.execute_query( - """ + # BM25 search to get top nodes + fuzzy_query = query + '~' + records, _, _ = await driver.execute_query( + """ CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score RETURN node.uuid As uuid, @@ -162,33 +162,33 @@ async def entity_fulltext_search( ORDER BY score DESC LIMIT $limit """, - query=fuzzy_query, - limit=limit, - ) - nodes: list[EntityNode] = [] + query=fuzzy_query, + limit=limit, + ) + nodes: list[EntityNode] = [] - for record in records: - nodes.append( - EntityNode( - uuid=record["uuid"], - name=record["name"], - labels=[], - created_at=safely_parse_db_date(record["created_at"]), - summary=record["summary"], - ) - ) + for record in records: + nodes.append( + EntityNode( + uuid=record['uuid'], + name=record['name'], + labels=[], + created_at=safely_parse_db_date(record['created_at']), + summary=record['summary'], + ) + ) - return nodes + return nodes async def edge_fulltext_search( - query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityEdge]: - # fulltext search over facts - fuzzy_query = query + "~" + # fulltext search over facts + fuzzy_query = query + '~' - records, _, _ = await driver.execute_query( - """ + records, _, _ = await 
driver.execute_query( + """ CALL db.index.fulltext.queryRelationships("name_and_fact", $query) YIELD relationship AS r, score MATCH (n:Entity)-[r]->(m:Entity) @@ -206,91 +206,87 @@ async def edge_fulltext_search( r.invalid_at AS invalid_at ORDER BY score DESC LIMIT $limit """, - query=fuzzy_query, - limit=limit, - ) + query=fuzzy_query, + limit=limit, + ) - edges: list[EntityEdge] = [] + edges: list[EntityEdge] = [] - for record in records: - edge = EntityEdge( - uuid=record["uuid"], - source_node_uuid=record["source_node_uuid"], - target_node_uuid=record["target_node_uuid"], - fact=record["fact"], - name=record["name"], - episodes=record["episodes"], - fact_embedding=record["fact_embedding"], - created_at=safely_parse_db_date(record["created_at"]), - expired_at=safely_parse_db_date(record["expired_at"]), - valid_at=safely_parse_db_date(record["valid_at"]), - invalid_At=safely_parse_db_date(record["invalid_at"]), - ) + for record in records: + edge = EntityEdge( + uuid=record['uuid'], + source_node_uuid=record['source_node_uuid'], + target_node_uuid=record['target_node_uuid'], + fact=record['fact'], + name=record['name'], + episodes=record['episodes'], + fact_embedding=record['fact_embedding'], + created_at=safely_parse_db_date(record['created_at']), + expired_at=safely_parse_db_date(record['expired_at']), + valid_at=safely_parse_db_date(record['valid_at']), + invalid_At=safely_parse_db_date(record['invalid_at']), + ) - edges.append(edge) + edges.append(edge) - return edges + return edges def safely_parse_db_date(date_str: neo4j_time.Date) -> datetime: - if date_str: - return datetime.fromisoformat(date_str.iso_format()) - return None + if date_str: + return datetime.fromisoformat(date_str.iso_format()) + return None async def get_relevant_nodes( - nodes: list[EntityNode], - driver: AsyncDriver, + nodes: list[EntityNode], + driver: AsyncDriver, ) -> list[EntityNode]: - start = time() - relevant_nodes: list[EntityNode] = [] - relevant_node_uuids = set() + start = time() + relevant_nodes: list[EntityNode] = [] + relevant_node_uuids = set() - results = await asyncio.gather( - *[entity_fulltext_search(node.name, driver) for node in nodes], - *[entity_similarity_search(node.name_embedding, driver) for node in nodes], - ) + results = await asyncio.gather( + *[entity_fulltext_search(node.name, driver) for node in nodes], + *[entity_similarity_search(node.name_embedding, driver) for node in nodes], + ) - for result in results: - for node in result: - if node.uuid in relevant_node_uuids: - continue + for result in results: + for node in result: + if node.uuid in relevant_node_uuids: + continue - relevant_node_uuids.add(node.uuid) - relevant_nodes.append(node) + relevant_node_uuids.add(node.uuid) + relevant_nodes.append(node) - end = time() - logger.info( - f"Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms" - ) + end = time() + logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms') - return relevant_nodes + return relevant_nodes async def get_relevant_edges( - edges: list[EntityEdge], - driver: AsyncDriver, + edges: list[EntityEdge], + driver: AsyncDriver, ) -> list[EntityEdge]: - start = time() - relevant_edges: list[EntityEdge] = [] - relevant_edge_uuids = set() + start = time() + relevant_edges: list[EntityEdge] = [] + relevant_edge_uuids = set() - results = await asyncio.gather( - *[edge_similarity_search(edge.fact_embedding, driver) for edge in edges], - *[edge_fulltext_search(edge.fact, driver) for edge in edges], - ) + results = await 
+		*[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
+		*[edge_fulltext_search(edge.fact, driver) for edge in edges],
+	)
 
-    for result in results:
-        for edge in result:
-            if edge.uuid in relevant_edge_uuids:
-                continue
+	for result in results:
+		for edge in result:
+			if edge.uuid in relevant_edge_uuids:
+				continue
 
-            relevant_edge_uuids.add(edge.uuid)
-            relevant_edges.append(edge)
+			relevant_edge_uuids.add(edge.uuid)
+			relevant_edges.append(edge)
 
-    end = time()
-    logger.info(
-        f"Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms"
-    )
+	end = time()
+	logger.info(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')
 
-    return relevant_edges
+	return relevant_edges
diff --git a/runner.py b/runner.py
index 5de65d0e..ce150ffb 100644
--- a/runner.py
+++ b/runner.py
@@ -11,75 +11,73 @@
 load_dotenv()
 
-neo4j_uri = os.environ.get("NEO4J_URI") or "bolt://localhost:7687"
-neo4j_user = os.environ.get("NEO4J_USER") or "neo4j"
-neo4j_password = os.environ.get("NEO4J_PASSWORD") or "password"
+neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
+neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
+neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
 
 
 def setup_logging():
-    # Create a logger
-    logger = logging.getLogger()
-    logger.setLevel(logging.INFO)  # Set the logging level to INFO
+	# Create a logger
+	logger = logging.getLogger()
+	logger.setLevel(logging.INFO)  # Set the logging level to INFO
 
-    # Create console handler and set level to INFO
-    console_handler = logging.StreamHandler(sys.stdout)
-    console_handler.setLevel(logging.INFO)
+	# Create console handler and set level to INFO
+	console_handler = logging.StreamHandler(sys.stdout)
+	console_handler.setLevel(logging.INFO)
 
-    # Create formatter
-    formatter = logging.Formatter(
-        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-    )
+	# Create formatter
+	formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
-    # Add formatter to console handler
-    console_handler.setFormatter(formatter)
+	# Add formatter to console handler
+	console_handler.setFormatter(formatter)
 
-    # Add console handler to logger
-    logger.addHandler(console_handler)
+	# Add console handler to logger
+	logger.addHandler(console_handler)
 
-    return logger
+	return logger
 
 
 async def main():
-    setup_logging()
-    client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
-    await clear_data(client.driver)
-
-    # await client.build_indices()
-    await client.add_episode(
-        name="Message 3",
-        episode_body="Jane: I am married to Paul",
-        source_description="WhatsApp Message",
-        reference_time=datetime.now(),
-    )
-    await client.add_episode(
-        name="Message 4",
-        episode_body="Paul: I have divorced Jane",
-        source_description="WhatsApp Message",
-        reference_time=datetime.now(),
-    )
-    await client.add_episode(
-        name="Message 5",
-        episode_body="Jane: I miss Paul",
-        source_description="WhatsApp Message",
-        reference_time=datetime.now(),
-    )
-    await client.add_episode(
-        name="Message 6",
-        episode_body="Jane: I dont miss Paul anymore, I hate him",
-        source_description="WhatsApp Message",
-        reference_time=datetime.now(),
-    )
-
-    # await client.add_episode(
-    #     name="Message 3",
-    #     episode_body="Assistant: The best type of apples available are Fuji apples",
-    #     source_description="WhatsApp Message",
-    # )
-    # await client.add_episode(
-    #     name="Message 4",
-    #     episode_body="Paul: Oh, I actually hate those",
-    #     source_description="WhatsApp Message",
-    # )
+	setup_logging()
+	client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
+	await clear_data(client.driver)
+
+	# await client.build_indices()
+	await client.add_episode(
+		name='Message 3',
+		episode_body='Jane: I am married to Paul',
+		source_description='WhatsApp Message',
+		reference_time=datetime.now(),
+	)
+	await client.add_episode(
+		name='Message 4',
+		episode_body='Paul: I have divorced Jane',
+		source_description='WhatsApp Message',
+		reference_time=datetime.now(),
+	)
+	await client.add_episode(
+		name='Message 5',
+		episode_body='Jane: I miss Paul',
+		source_description='WhatsApp Message',
+		reference_time=datetime.now(),
+	)
+	await client.add_episode(
+		name='Message 6',
+		episode_body='Jane: I dont miss Paul anymore, I hate him',
+		source_description='WhatsApp Message',
+		reference_time=datetime.now(),
+	)
+
+	# await client.add_episode(
+	#     name="Message 3",
+	#     episode_body="Assistant: The best type of apples available are Fuji apples",
+	#     source_description="WhatsApp Message",
+	# )
+	# await client.add_episode(
+	#     name="Message 4",
+	#     episode_body="Paul: Oh, I actually hate those",
+	#     source_description="WhatsApp Message",
+	# )
 
 
 asyncio.run(main())

From 6a97e9f1a3dca3155835271c4f5d3d9160913f19 Mon Sep 17 00:00:00 2001
From: paulpaliychuk
Date: Thu, 22 Aug 2024 18:05:48 -0400
Subject: [PATCH 07/10] fix: Duplication

---
 core/graphiti.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/core/graphiti.py b/core/graphiti.py
index d52ef5b7..bdfaafbd 100644
--- a/core/graphiti.py
+++ b/core/graphiti.py
@@ -180,7 +180,14 @@ async def add_episode(
 			edge_touched_node_uuids.append(edge.source_node_uuid)
 			edge_touched_node_uuids.append(edge.target_node_uuid)
 
-		entity_edges.extend(invalidated_edges)
+		edges_to_save = invalidated_edges
+
+		# There may be an overlap between deduped and invalidated edges; make sure the invalidated version is the one saved
+		for deduped_edge in deduped_edges:
+			if deduped_edge.uuid not in [edge.uuid for edge in invalidated_edges]:
+				edges_to_save.append(deduped_edge)
+
+		entity_edges.extend(edges_to_save)
 
 		edge_touched_node_uuids = list(set(edge_touched_node_uuids))
 		involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids]
@@ -190,7 +197,6 @@ async def add_episode(
 		logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
 
 		logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
-		entity_edges.extend(deduped_edges)
 
 		episodic_edges.extend(
 			build_episodic_edges(
@@ -203,12 +209,6 @@ async def add_episode(
 		# Important to append the episode to the nodes at the end so that self referencing episodic edges are not built
 		logger.info(f'Built episodic edges: {episodic_edges}')
 
-		# invalidated_edges = await self.invalidate_edges(
-		#     episode, new_nodes, new_edges, relevant_schema, previous_episodes
-		# )
-
-		# edges.extend(invalidated_edges)
-
 		# Future optimization would be using batch operations to save nodes and edges
 		await episode.save(self.driver)
 		await asyncio.gather(*[node.save(self.driver) for node in nodes])

From d2eac942fb2c469af04366911614bc3299e668e7 Mon Sep 17 00:00:00 2001
From: paulpaliychuk
Date: Thu, 22 Aug 2024 18:17:59 -0400
Subject: [PATCH 08/10] chore: Fix unit tests for temporal invalidation

---
 .../maintenance/test_temporal_operations.py   | 79 +++++++++++++++++--
 1 file changed, 74 insertions(+), 5 deletions(-)

diff --git a/tests/utils/maintenance/test_temporal_operations.py b/tests/utils/maintenance/test_temporal_operations.py
index 9fbcf2af..7a3b94cc 100644
--- a/tests/utils/maintenance/test_temporal_operations.py
+++ b/tests/utils/maintenance/test_temporal_operations.py
@@ -3,7 +3,7 @@
 import pytest
 
 from core.edges import EntityEdge
-from core.nodes import EntityNode
+from core.nodes import EntityNode, EpisodicNode
 from core.utils.maintenance.temporal_operations import (
 	prepare_edges_for_invalidation,
 	prepare_invalidation_context,
@@ -114,7 +114,6 @@ def test_prepare_edges_for_invalidation_missing_nodes():
 
 
 def test_prepare_invalidation_context():
-	# Create test data
 	now = datetime.now()
 
 	# Create nodes
@@ -148,15 +147,49 @@ def test_prepare_invalidation_context():
 	existing_edges = [existing_edge]
 	new_edges = [new_edge]
 
+	# Create a current episode and previous episodes
+	current_episode = EpisodicNode(
+		name='Current Episode',
+		content='This is the current episode content.',
+		created_at=now,
+		valid_at=now,
+		source='test',
+		source_description='Test episode for unit testing',
+	)
+	previous_episodes = [
+		EpisodicNode(
+			name='Previous Episode 1',
+			content='This is the content of previous episode 1.',
+			created_at=now - timedelta(days=1),
+			valid_at=now - timedelta(days=1),
+			source='test',
+			source_description='Test previous episode 1 for unit testing',
+		),
+		EpisodicNode(
+			name='Previous Episode 2',
+			content='This is the content of previous episode 2.',
+			created_at=now - timedelta(days=2),
+			valid_at=now - timedelta(days=2),
+			source='test',
+			source_description='Test previous episode 2 for unit testing',
+		),
+	]
+
 	# Call the function
-	result = prepare_invalidation_context(existing_edges, new_edges)
+	result = prepare_invalidation_context(
+		existing_edges, new_edges, current_episode, previous_episodes
+	)
 
 	# Assert the result
 	assert isinstance(result, dict)
 	assert 'existing_edges' in result
 	assert 'new_edges' in result
+	assert 'current_episode' in result
+	assert 'previous_episodes' in result
 	assert len(result['existing_edges']) == 1
 	assert len(result['new_edges']) == 1
+	assert result['current_episode'] == current_episode.content
+	assert len(result['previous_episodes']) == 2
 
 	# Check the format of the existing edge
 	existing_edge_str = result['existing_edges'][0]
@@ -176,12 +209,25 @@
 
 
 def test_prepare_invalidation_context_empty_input():
-	result = prepare_invalidation_context([], [])
+	now = datetime.now()
+	current_episode = EpisodicNode(
+		name='Current Episode',
+		content='Empty episode',
+		created_at=now,
+		valid_at=now,
+		source='test',
+		source_description='Test empty episode for unit testing',
+	)
+	result = prepare_invalidation_context([], [], current_episode, [])
 	assert isinstance(result, dict)
 	assert 'existing_edges' in result
 	assert 'new_edges' in result
+	assert 'current_episode' in result
+	assert 'previous_episodes' in result
 	assert len(result['existing_edges']) == 0
 	assert len(result['new_edges']) == 0
+	assert result['current_episode'] == current_episode.content
+	assert len(result['previous_episodes']) == 0
 
 
 def test_prepare_invalidation_context_sorting():
@@ -215,13 +261,36 @@ def test_prepare_invalidation_context_sorting():
 	# Prepare test input
 	existing_edges = [edge_with_nodes1, edge_with_nodes2]
 
+	# Create a current episode and previous episodes
+	current_episode = EpisodicNode(
+		name='Current Episode',
+		content='This is the current episode content.',
+		created_at=now,
+		valid_at=now,
+		source='test',
+		source_description='Test episode for unit testing',
+	)
+	previous_episodes = [
+		EpisodicNode(
+			name='Previous Episode',
+			content='This is the content of a previous episode.',
+			created_at=now - timedelta(days=1),
+			valid_at=now - timedelta(days=1),
+			source='test',
+			source_description='Test previous episode for unit testing',
+		),
+	]
+
 	# Call the function
-	result = prepare_invalidation_context(existing_edges, [])
+	result = prepare_invalidation_context(existing_edges, [], current_episode, previous_episodes)
 
 	# Assert the result
 	assert len(result['existing_edges']) == 2
 	assert edge2.uuid in result['existing_edges'][0]  # The newer edge should be first
 	assert edge1.uuid in result['existing_edges'][1]  # The older edge should be second
+	assert result['current_episode'] == current_episode.content
+	assert len(result['previous_episodes']) == 1
+	assert result['previous_episodes'][0] == previous_episodes[0].content
 
 
 # Run the tests

From 868215cc1290b92071a1569fd74200f30aa0c4d4 Mon Sep 17 00:00:00 2001
From: paulpaliychuk
Date: Thu, 22 Aug 2024 18:25:40 -0400
Subject: [PATCH 09/10] attempt to fix unit tests

---
 .../maintenance/graph_data_operations.py      |  2 +-
 .../test_temporal_operations_int.py           | 39 +++++++++++++++----
 2 files changed, 33 insertions(+), 8 deletions(-)

diff --git a/core/utils/maintenance/graph_data_operations.py b/core/utils/maintenance/graph_data_operations.py
index 9ee91a64..919c92ac 100644
--- a/core/utils/maintenance/graph_data_operations.py
+++ b/core/utils/maintenance/graph_data_operations.py
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from datetime import datetime, timezone
-from typing import LiteralString
+from typing_extensions import LiteralString
 
 from neo4j import AsyncDriver
 
diff --git a/tests/utils/maintenance/test_temporal_operations_int.py b/tests/utils/maintenance/test_temporal_operations_int.py
index 6ea46d4d..37baf0cc 100644
--- a/tests/utils/maintenance/test_temporal_operations_int.py
+++ b/tests/utils/maintenance/test_temporal_operations_int.py
@@ -6,7 +6,7 @@
 
 from core.edges import EntityEdge
 from core.llm_client import LLMConfig, OpenAIClient
-from core.nodes import EntityNode
+from core.nodes import EntityNode, EpisodicNode
 from core.utils.maintenance.temporal_operations import (
 	invalidate_edges,
 )
@@ -24,7 +24,6 @@ def setup_llm_client():
 	)
 
 
-# Helper function to create test data
 def create_test_data():
 	now = datetime.now()
 
@@ -53,15 +52,39 @@ def create_test_data():
 	existing_edge = (node1, edge1, node2)
 	new_edge = (node1, edge2, node2)
 
-	return existing_edge, new_edge
+	# Create current episode
+	current_episode = EpisodicNode(
+		name='Current Episode',
+		content='Alice now dislikes Bob',
+		created_at=now,
+		valid_at=now,
+		source='test',
+		source_description='Test episode for unit testing',
+	)
+
+	# Create previous episodes
+	previous_episodes = [
+		EpisodicNode(
+			name='Previous Episode',
+			content='Alice liked Bob',
+			created_at=now - timedelta(days=1),
+			valid_at=now - timedelta(days=1),
+			source='test',
+			source_description='Test previous episode for unit testing',
+		)
+	]
+
+	return existing_edge, new_edge, current_episode, previous_episodes
 
 
 @pytest.mark.asyncio
 @pytest.mark.integration
 async def test_invalidate_edges():
-	existing_edge, new_edge = create_test_data()
+	existing_edge, new_edge, current_episode, previous_episodes = create_test_data()
 
-	invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], [new_edge])
+	invalidated_edges = await invalidate_edges(
+		setup_llm_client(), [existing_edge], [new_edge], current_episode, previous_episodes
+	)
 
 	assert len(invalidated_edges) == 1
 	assert invalidated_edges[0].uuid == existing_edge[1].uuid
@@ -71,9 +94,11 @@ async def test_invalidate_edges():
 
 @pytest.mark.asyncio
 @pytest.mark.integration
 async def test_invalidate_edges_no_invalidation():
-	existing_edge, _ = create_test_data()
+	existing_edge, _, current_episode, previous_episodes = create_test_data()
 
-	invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], [])
+	invalidated_edges = await invalidate_edges(
+		setup_llm_client(), [existing_edge], [], current_episode, previous_episodes
+	)
 
 	assert len(invalidated_edges) == 0

From dd047253d7a636cabbf7b96a9497d3a60b88a25f Mon Sep 17 00:00:00 2001
From: paulpaliychuk
Date: Thu, 22 Aug 2024 18:25:56 -0400
Subject: [PATCH 10/10] fix: format

---
 core/utils/maintenance/graph_data_operations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/core/utils/maintenance/graph_data_operations.py b/core/utils/maintenance/graph_data_operations.py
index 919c92ac..368ddb5f 100644
--- a/core/utils/maintenance/graph_data_operations.py
+++ b/core/utils/maintenance/graph_data_operations.py
@@ -1,9 +1,9 @@
 import asyncio
 import logging
 from datetime import datetime, timezone
-from typing_extensions import LiteralString
 
 from neo4j import AsyncDriver
+from typing_extensions import LiteralString
 
 from core.nodes import EpisodicNode
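
With the full series applied, edge invalidation is episode-aware end to end: prepare_invalidation_context and invalidate_edges now take the current episode plus its preceding episodes, so the LLM can judge edge validity against the conversation itself. A minimal sketch of the new call shape follows; the EpisodicNode field values are illustrative (they mirror the unit tests above) and are not prescribed inputs:

    from datetime import datetime, timedelta

    from core.nodes import EpisodicNode
    from core.utils.maintenance.temporal_operations import prepare_invalidation_context

    now = datetime.now()
    current_episode = EpisodicNode(
        name='Current Episode',
        content='Paul: I have divorced Jane',
        created_at=now,
        valid_at=now,
        source='test',
        source_description='Illustrative episode',
    )
    previous_episodes = [
        EpisodicNode(
            name='Previous Episode',
            content='Jane: I am married to Paul',
            created_at=now - timedelta(days=1),
            valid_at=now - timedelta(days=1),
            source='test',
            source_description='Illustrative previous episode',
        )
    ]

    # Both edge lists may be empty; the context then carries only the episode text
    context = prepare_invalidation_context([], [], current_episode, previous_episodes)

    assert context['current_episode'] == current_episode.content
    assert context['previous_episodes'] == [e.content for e in previous_episodes]

Callers that already build the (source_node, edge, target_node) triples only need to thread the two episode arguments through, as the integration tests above do.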