From 610005f7fbef10ecd064ca352a5941dd911fdac7 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Fri, 6 Sep 2024 09:51:30 -0400 Subject: [PATCH 01/12] set and retrieve group ids --- graphiti_core/edges.py | 71 +++++----- graphiti_core/graphiti.py | 25 ++-- graphiti_core/nodes.py | 76 +++++----- graphiti_core/search/search_utils.py | 133 +++--------------- .../utils/maintenance/edge_operations.py | 13 +- .../utils/maintenance/node_operations.py | 2 + 6 files changed, 114 insertions(+), 206 deletions(-) diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 645a3b32..8856fec7 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -18,6 +18,7 @@ from abc import ABC, abstractmethod from datetime import datetime from time import time +from typing import Any from uuid import uuid4 from neo4j import AsyncDriver @@ -32,6 +33,7 @@ class Edge(BaseModel, ABC): uuid: str = Field(default_factory=lambda: uuid4().hex) + group_id: str = Field(description='partition of the graph') source_node_uuid: str target_node_uuid: str created_at: datetime @@ -61,11 +63,12 @@ async def save(self, driver: AsyncDriver): MATCH (episode:Episodic {uuid: $episode_uuid}) MATCH (node:Entity {uuid: $entity_uuid}) MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node) - SET r = {uuid: $uuid, created_at: $created_at} + SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at} RETURN r.uuid AS uuid""", episode_uuid=self.source_node_uuid, entity_uuid=self.target_node_uuid, uuid=self.uuid, + group_id=self.group_id, created_at=self.created_at, ) @@ -92,7 +95,8 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): """ MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity) RETURN - e.uuid As uuid, + e.uuid As uuid, + e.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, e.created_at AS created_at @@ -100,17 +104,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): uuid=uuid, ) - edges: list[EpisodicEdge] = [] - - for record in records: - edges.append( - EpisodicEdge( - uuid=record['uuid'], - source_node_uuid=record['source_node_uuid'], - target_node_uuid=record['target_node_uuid'], - created_at=record['created_at'].to_native(), - ) - ) + edges = [get_episodic_edge_from_record(record) for record in records] logger.info(f'Found Edge: {uuid}') @@ -153,7 +147,7 @@ async def save(self, driver: AsyncDriver): MATCH (source:Entity {uuid: $source_uuid}) MATCH (target:Entity {uuid: $target_uuid}) MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target) - SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding, + SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding, episodes: $episodes, created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at} RETURN r.uuid AS uuid""", @@ -161,6 +155,7 @@ async def save(self, driver: AsyncDriver): target_uuid=self.target_node_uuid, uuid=self.uuid, name=self.name, + group_id=self.group_id, fact=self.fact, fact_embedding=self.fact_embedding, episodes=self.episodes, @@ -198,6 +193,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): m.uuid AS target_node_uuid, e.created_at AS created_at, e.name AS name, + e.group_id AS group_id, e.fact AS fact, e.fact_embedding AS fact_embedding, e.episodes AS episodes, @@ -208,25 +204,36 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): uuid=uuid, ) - edges: list[EntityEdge] = [] - - for record in records: - edges.append( - EntityEdge( - uuid=record['uuid'], - source_node_uuid=record['source_node_uuid'], - target_node_uuid=record['target_node_uuid'], - fact=record['fact'], - name=record['name'], - episodes=record['episodes'], - fact_embedding=record['fact_embedding'], - created_at=record['created_at'].to_native(), - expired_at=parse_db_date(record['expired_at']), - valid_at=parse_db_date(record['valid_at']), - invalid_at=parse_db_date(record['invalid_at']), - ) - ) + edges = [get_entity_edge_from_record(record) for record in records] logger.info(f'Found Edge: {uuid}') return edges[0] + + +# Edge helpers +def get_episodic_edge_from_record(record: Any) -> EpisodicEdge: + return EpisodicEdge( + uuid=record['uuid'], + group_id=record['group_id'], + source_node_uuid=record['source_node_uuid'], + target_node_uuid=record['target_node_uuid'], + created_at=record['created_at'].to_native(), + ) + + +def get_entity_edge_from_record(record: Any) -> EntityEdge: + return EntityEdge( + uuid=record['uuid'], + source_node_uuid=record['source_node_uuid'], + target_node_uuid=record['target_node_uuid'], + fact=record['fact'], + name=record['name'], + group_id=record['group_id'], + episodes=record['episodes'], + fact_embedding=record['fact_embedding'], + created_at=record['created_at'].to_native(), + expired_at=parse_db_date(record['expired_at']), + valid_at=parse_db_date(record['valid_at']), + invalid_at=parse_db_date(record['invalid_at']), + ) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 6684fcc0..b0f866f5 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -211,8 +211,7 @@ async def add_episode( source_description: str, reference_time: datetime, source: EpisodeType = EpisodeType.message, - success_callback: Callable | None = None, - error_callback: Callable | None = None, + group_id: str | None = None, ): """ Process an episode and update the graph. @@ -232,10 +231,8 @@ async def add_episode( The reference time for the episode. source : EpisodeType, optional The type of the episode. Defaults to EpisodeType.message. - success_callback : Callable | None, optional - A callback function to be called upon successful processing. - error_callback : Callable | None, optional - A callback function to be called if an error occurs during processing. + group_id : str | None + An id for the graph partition the episode is a part of. Returns ------- @@ -269,6 +266,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): previous_episodes = await self.retrieve_episodes(reference_time, last_n=3) episode = EpisodicNode( name=name, + group_id=group_id, labels=[], source=source, content=episode_body, @@ -279,7 +277,9 @@ async def add_episode_endpoint(episode_data: EpisodeData): # Extract entities as nodes - extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes) + extracted_nodes = await extract_nodes( + self.llm_client, episode, previous_episodes, group_id + ) logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') # Calculate Embeddings @@ -389,9 +389,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}') episodic_edges: list[EpisodicEdge] = build_episodic_edges( - mentioned_nodes, - episode, - now, + mentioned_nodes, episode, now, group_id ) logger.info(f'Built episodic edges: {episodic_edges}') @@ -405,13 +403,8 @@ async def add_episode_endpoint(episode_data: EpisodeData): end = time() logger.info(f'Completed add_episode in {(end - start) * 1000} ms') - if success_callback: - await success_callback(episode) except Exception as e: - if error_callback: - await error_callback(episode, e) - else: - raise e + raise e async def add_episode_bulk( self, diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index f30d001d..60b52f62 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -19,6 +19,7 @@ from datetime import datetime from enum import Enum from time import time +from typing import Any from uuid import uuid4 from neo4j import AsyncDriver @@ -69,6 +70,7 @@ def from_str(episode_type: str): class Node(BaseModel, ABC): uuid: str = Field(default_factory=lambda: uuid4().hex) name: str = Field(description='name of the node') + group_id: str = Field(description='partition of the graph') labels: list[str] = Field(default_factory=list) created_at: datetime = Field(default_factory=lambda: datetime.now()) @@ -106,11 +108,12 @@ async def save(self, driver: AsyncDriver): result = await driver.execute_query( """ MERGE (n:Episodic {uuid: $uuid}) - SET n = {uuid: $uuid, name: $name, source_description: $source_description, source: $source, content: $content, + SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content, entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} RETURN n.uuid AS uuid""", uuid=self.uuid, name=self.name, + group_id=self.group_id, source_description=self.source_description, content=self.content, entity_edges=self.entity_edges, @@ -141,29 +144,19 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): records, _, _ = await driver.execute_query( """ MATCH (e:Episodic {uuid: $uuid}) - RETURN e.content as content, - e.created_at as created_at, - e.valid_at as valid_at, - e.uuid as uuid, - e.name as name, - e.source_description as source_description, - e.source as source + RETURN e.content AS content, + e.created_at AS created_at, + e.valid_at AS valid_at, + e.uuid AS uuid, + e.name AS name, + e.group_id AS group_id + e.source_description AS source_description, + e.source AS source """, uuid=uuid, ) - episodes = [ - EpisodicNode( - content=record['content'], - created_at=record['created_at'].to_native().timestamp(), - valid_at=(record['valid_at'].to_native()), - uuid=record['uuid'], - source=EpisodeType.from_str(record['source']), - name=record['name'], - source_description=record['source_description'], - ) - for record in records - ] + episodes = [get_episodic_node_from_record(record) for record in records] logger.info(f'Found Node: {uuid}') @@ -174,10 +167,6 @@ class EntityNode(Node): name_embedding: list[float] | None = Field(default=None, description='embedding of the name') summary: str = Field(description='regional summary of surrounding edges', default_factory=str) - async def update_summary(self, driver: AsyncDriver): ... - - async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ... - async def generate_name_embedding(self, embedder, model='text-embedding-3-small'): start = time() text = self.name.replace('\n', ' ') @@ -192,10 +181,11 @@ async def save(self, driver: AsyncDriver): result = await driver.execute_query( """ MERGE (n:Entity {uuid: $uuid}) - SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at} + SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at} RETURN n.uuid AS uuid""", uuid=self.uuid, name=self.name, + group_id=self.group_id, summary=self.summary, name_embedding=self.name_embedding, created_at=self.created_at, @@ -227,25 +217,14 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): n.uuid As uuid, n.name AS name, n.name_embedding AS name_embedding, + n.group_id AS group_id n.created_at AS created_at, n.summary AS summary """, uuid=uuid, ) - nodes: list[EntityNode] = [] - - for record in records: - nodes.append( - EntityNode( - uuid=record['uuid'], - name=record['name'], - name_embedding=record['name_embedding'], - labels=['Entity'], - created_at=record['created_at'].to_native(), - summary=record['summary'], - ) - ) + nodes = [get_entity_node_from_record(record) for record in records] logger.info(f'Found Node: {uuid}') @@ -253,3 +232,24 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): # Node helpers +def get_episodic_node_from_record(record: Any) -> EpisodicNode: + return EpisodicNode( + content=record['content'], + created_at=record['created_at'].to_native().timestamp(), + valid_at=(record['valid_at'].to_native()), + uuid=record['uuid'], + source=EpisodeType.from_str(record['source']), + name=record['name'], + source_description=record['source_description'], + ) + + +def get_entity_node_from_record(record: Any) -> EntityNode: + return EntityNode( + uuid=record['uuid'], + name=record['name'], + name_embedding=record['name_embedding'], + labels=['Entity'], + created_at=record['created_at'].to_native(), + summary=record['summary'], + ) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 3f7987b1..83800126 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -7,9 +7,9 @@ from neo4j import AsyncDriver, Query -from graphiti_core.edges import EntityEdge +from graphiti_core.edges import EntityEdge, get_entity_edge_from_record from graphiti_core.helpers import parse_db_date -from graphiti_core.nodes import EntityNode, EpisodicNode +from graphiti_core.nodes import EntityNode, EpisodicNode, get_entity_node_from_record logger = logging.getLogger(__name__) @@ -48,55 +48,6 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]) return nodes -async def bfs(node_ids: list[str], driver: AsyncDriver): - records, _, _ = await driver.execute_query( - """ - MATCH (n WHERE n.uuid in $node_ids)-[r]->(m) - RETURN DISTINCT - n.uuid AS source_node_uuid, - n.name AS source_name, - n.summary AS source_summary, - m.uuid AS target_node_uuid, - m.name AS target_name, - m.summary AS target_summary, - r.uuid AS uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - r.fact_embedding AS fact_embedding, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at - - """, - node_ids=node_ids, - ) - - context: dict[str, Any] = {} - - for record in records: - n_uuid = record['source_node_uuid'] - if n_uuid in context: - context[n_uuid]['facts'].append(record['fact']) - else: - context[n_uuid] = { - 'name': record['source_name'], - 'summary': record['source_summary'], - 'facts': [record['fact']], - } - - m_uuid = record['target_node_uuid'] - if m_uuid not in context: - context[m_uuid] = { - 'name': record['target_name'], - 'summary': record['target_summary'], - 'facts': [], - } - logger.info(f'bfs search returned context: {context}') - return context - - async def edge_similarity_search( driver: AsyncDriver, search_vector: list[float], @@ -111,6 +62,7 @@ async def edge_similarity_search( MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) RETURN r.uuid AS uuid, + r.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, r.created_at AS created_at, @@ -131,6 +83,7 @@ async def edge_similarity_search( MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) RETURN r.uuid AS uuid, + r.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, r.created_at AS created_at, @@ -150,6 +103,7 @@ async def edge_similarity_search( MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) RETURN r.uuid AS uuid, + r.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, r.created_at AS created_at, @@ -169,6 +123,7 @@ async def edge_similarity_search( MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) RETURN r.uuid AS uuid, + r.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, r.created_at AS created_at, @@ -190,24 +145,7 @@ async def edge_similarity_search( limit=limit, ) - edges: list[EntityEdge] = [] - - for record in records: - edge = EntityEdge( - uuid=record['uuid'], - source_node_uuid=record['source_node_uuid'], - target_node_uuid=record['target_node_uuid'], - fact=record['fact'], - name=record['name'], - episodes=record['episodes'], - fact_embedding=record['fact_embedding'], - created_at=record['created_at'].to_native(), - expired_at=parse_db_date(record['expired_at']), - valid_at=parse_db_date(record['valid_at']), - invalid_at=parse_db_date(record['invalid_at']), - ) - - edges.append(edge) + edges = [get_entity_edge_from_record(record) for record in records] return edges @@ -221,7 +159,8 @@ async def entity_similarity_search( CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector) YIELD node AS n, score RETURN - n.uuid As uuid, + n.uuid As uuid, + n.group_id AS group_id, n.name AS name, n.name_embedding AS name_embedding, n.created_at AS created_at, @@ -231,19 +170,7 @@ async def entity_similarity_search( search_vector=search_vector, limit=limit, ) - nodes: list[EntityNode] = [] - - for record in records: - nodes.append( - EntityNode( - uuid=record['uuid'], - name=record['name'], - name_embedding=record['name_embedding'], - labels=['Entity'], - created_at=record['created_at'].to_native(), - summary=record['summary'], - ) - ) + nodes = [get_entity_node_from_record(record) for record in records] return nodes @@ -257,7 +184,8 @@ async def entity_fulltext_search( """ CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score RETURN - node.uuid AS uuid, + node.uuid AS uuid, + node.group_id AS group_id, node.name AS name, node.name_embedding AS name_embedding, node.created_at AS created_at, @@ -268,19 +196,7 @@ async def entity_fulltext_search( query=fuzzy_query, limit=limit, ) - nodes: list[EntityNode] = [] - - for record in records: - nodes.append( - EntityNode( - uuid=record['uuid'], - name=record['name'], - name_embedding=record['name_embedding'], - labels=['Entity'], - created_at=record['created_at'].to_native(), - summary=record['summary'], - ) - ) + nodes = [get_entity_node_from_record(record) for record in records] return nodes @@ -299,6 +215,7 @@ async def edge_fulltext_search( MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) RETURN r.uuid AS uuid, + r.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, r.created_at AS created_at, @@ -319,6 +236,7 @@ async def edge_fulltext_search( MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) RETURN r.uuid AS uuid, + r.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, r.created_at AS created_at, @@ -338,6 +256,7 @@ async def edge_fulltext_search( MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) RETURN r.uuid AS uuid, + r.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, r.created_at AS created_at, @@ -357,6 +276,7 @@ async def edge_fulltext_search( MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) RETURN r.uuid AS uuid, + r.group_id AS group_id, n.uuid AS source_node_uuid, m.uuid AS target_node_uuid, r.created_at AS created_at, @@ -380,24 +300,7 @@ async def edge_fulltext_search( limit=limit, ) - edges: list[EntityEdge] = [] - - for record in records: - edge = EntityEdge( - uuid=record['uuid'], - source_node_uuid=record['source_node_uuid'], - target_node_uuid=record['target_node_uuid'], - fact=record['fact'], - name=record['name'], - episodes=record['episodes'], - fact_embedding=record['fact_embedding'], - created_at=record['created_at'].to_native(), - expired_at=parse_db_date(record['expired_at']), - valid_at=parse_db_date(record['valid_at']), - invalid_at=parse_db_date(record['invalid_at']), - ) - - edges.append(edge) + edges = [get_entity_edge_from_record(record) for record in records] return edges diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 0d6aa9eb..d45718af 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -36,16 +36,17 @@ def build_episodic_edges( entity_nodes: List[EntityNode], episode: EpisodicNode, created_at: datetime, + group_id: str | None, ) -> List[EpisodicEdge]: - edges: List[EpisodicEdge] = [] - - for node in entity_nodes: - edge = EpisodicEdge( + edges: List[EpisodicEdge] = [ + EpisodicEdge( source_node_uuid=episode.uuid, target_node_uuid=node.uuid, created_at=created_at, + group_id=group_id, ) - edges.append(edge) + for node in entity_nodes + ] return edges @@ -55,6 +56,7 @@ async def extract_edges( episode: EpisodicNode, nodes: list[EntityNode], previous_episodes: list[EpisodicNode], + group_id: str | None, ) -> list[EntityEdge]: start = time() @@ -88,6 +90,7 @@ async def extract_edges( source_node_uuid=edge_data['source_node_uuid'], target_node_uuid=edge_data['target_node_uuid'], name=edge_data['relation_type'], + group_id=group_id, fact=edge_data['fact'], episodes=[episode.uuid], created_at=datetime.now(), diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 30673eef..6d857158 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -70,6 +70,7 @@ async def extract_nodes( llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode], + group_id: str | None, ) -> list[EntityNode]: start = time() extracted_node_data: list[dict[str, Any]] = [] @@ -85,6 +86,7 @@ async def extract_nodes( for node_data in extracted_node_data: new_node = EntityNode( name=node_data['name'], + group_id=group_id, labels=node_data['labels'], summary=node_data['summary'], created_at=datetime.now(), From 04f77b155f8b29a77b519205849d5bd4140e6f7c Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Fri, 6 Sep 2024 10:38:04 -0400 Subject: [PATCH 02/12] update add episode with group id support --- graphiti_core/edges.py | 2 +- graphiti_core/graphiti.py | 20 ++++- graphiti_core/nodes.py | 2 +- graphiti_core/search/search.py | 7 +- graphiti_core/search/search_utils.py | 86 +++++++++++++------ .../maintenance/graph_data_operations.py | 25 ++++-- 6 files changed, 99 insertions(+), 43 deletions(-) diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 8856fec7..1e60c944 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -33,7 +33,7 @@ class Edge(BaseModel, ABC): uuid: str = Field(default_factory=lambda: uuid4().hex) - group_id: str = Field(description='partition of the graph') + group_id: str | None = Field(description='partition of the graph') source_node_uuid: str target_node_uuid: str created_at: datetime diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index b0f866f5..ec4aaca6 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -178,6 +178,7 @@ async def retrieve_episodes( self, reference_time: datetime, last_n: int = EPISODE_WINDOW_LEN, + group_ids: list[str] | None = None, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -191,6 +192,8 @@ async def retrieve_episodes( The reference time to retrieve episodes before. last_n : int, optional The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN. + group_ids : list[str], optional + The group ids to return data from. Returns ------- @@ -202,7 +205,7 @@ async def retrieve_episodes( The actual retrieval is performed by the `retrieve_episodes` function from the `graphiti_core.utils` module. """ - return await retrieve_episodes(self.driver, reference_time, last_n) + return await retrieve_episodes(self.driver, reference_time, last_n, group_ids) async def add_episode( self, @@ -263,7 +266,9 @@ async def add_episode_endpoint(episode_data: EpisodeData): embedder = self.llm_client.get_embedder() now = datetime.now() - previous_episodes = await self.retrieve_episodes(reference_time, last_n=3) + previous_episodes = await self.retrieve_episodes( + reference_time, last_n=3, group_ids=[group_id] + ) episode = EpisodicNode( name=name, group_id=group_id, @@ -520,7 +525,13 @@ async def add_episode_bulk( except Exception as e: raise e - async def search(self, query: str, center_node_uuid: str | None = None, num_results=10): + async def search( + self, + query: str, + center_node_uuid: str | None = None, + group_ids: list[str] | None = None, + num_results=10, + ): """ Perform a hybrid search on the knowledge graph. @@ -533,6 +544,8 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_resu The search query string. center_node_uuid: str, optional Facts will be reranked based on proximity to this node + group_ids : list[str] | None, optional + The graph partitions to return data from. num_results : int, optional The maximum number of results to return. Defaults to 10. @@ -555,6 +568,7 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_resu num_episodes=0, num_edges=num_results, num_nodes=0, + group_ids=group_ids, search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity], reranker=reranker, ) diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 60b52f62..cde7c87d 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -70,7 +70,7 @@ def from_str(episode_type: str): class Node(BaseModel, ABC): uuid: str = Field(default_factory=lambda: uuid4().hex) name: str = Field(description='name of the node') - group_id: str = Field(description='partition of the graph') + group_id: str | None = Field(description='partition of the graph') labels: list[str] = Field(default_factory=list) created_at: datetime = Field(default_factory=lambda: datetime.now()) diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 172cd123..87836765 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -52,6 +52,7 @@ class SearchConfig(BaseModel): num_edges: int = Field(default=10) num_nodes: int = Field(default=10) num_episodes: int = EPISODE_WINDOW_LEN + group_ids: list[str] | None search_methods: list[SearchMethod] reranker: Reranker | None @@ -83,7 +84,9 @@ async def hybrid_search( nodes.extend(await get_mentioned_nodes(driver, episodes)) if SearchMethod.bm25 in config.search_methods: - text_search = await edge_fulltext_search(driver, query, None, None, 2 * config.num_edges) + text_search = await edge_fulltext_search( + driver, query, None, None, config.group_ids, 2 * config.num_edges + ) search_results.append(text_search) if SearchMethod.cosine_similarity in config.search_methods: @@ -95,7 +98,7 @@ async def hybrid_search( ) similarity_search = await edge_similarity_search( - driver, search_vector, None, None, 2 * config.num_edges + driver, search_vector, None, None, config.group_ids, 2 * config.num_edges ) search_results.append(similarity_search) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 83800126..2a2b41bf 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -23,6 +23,7 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]) MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids RETURN DISTINCT n.uuid As uuid, + n.group_id AS group_id, n.name AS name, n.name_embedding AS name_embedding n.created_at AS created_at, @@ -31,19 +32,7 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]) uuids=episode_uuids, ) - nodes: list[EntityNode] = [] - - for record in records: - nodes.append( - EntityNode( - uuid=record['uuid'], - name=record['name'], - name_embedding=record['name_embedding'], - labels=['Entity'], - created_at=record['created_at'].to_native(), - summary=record['summary'], - ) - ) + nodes = [get_entity_node_from_record(record) for record in records] return nodes @@ -53,13 +42,16 @@ async def edge_similarity_search( search_vector: list[float], source_node_uuid: str | None, target_node_uuid: str | None, + group_ids: list[str | None] | None = None, limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: + group_ids = group_ids if group_ids is not None else [None] # vector similarity search over embedded facts query = Query(""" CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) YIELD relationship AS rel, score MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) + WHERE r.group_id IN $group_ids RETURN r.uuid AS uuid, r.group_id AS group_id, @@ -81,6 +73,7 @@ async def edge_similarity_search( CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) YIELD relationship AS rel, score MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) + WHERE r.group_id IN $group_ids RETURN r.uuid AS uuid, r.group_id AS group_id, @@ -101,6 +94,7 @@ async def edge_similarity_search( CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) YIELD relationship AS rel, score MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) + WHERE r.group_id IN $group_ids RETURN r.uuid AS uuid, r.group_id AS group_id, @@ -121,6 +115,7 @@ async def edge_similarity_search( CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) YIELD relationship AS rel, score MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) + WHERE r.group_id IN $group_ids RETURN r.uuid AS uuid, r.group_id AS group_id, @@ -142,6 +137,7 @@ async def edge_similarity_search( search_vector=search_vector, source_uuid=source_node_uuid, target_uuid=target_node_uuid, + group_ids=group_ids, limit=limit, ) @@ -151,13 +147,19 @@ async def edge_similarity_search( async def entity_similarity_search( - search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + search_vector: list[float], + driver: AsyncDriver, + group_ids: list[str | None] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: + group_ids = group_ids if group_ids is not None else [None] + # vector similarity search over entity names records, _, _ = await driver.execute_query( """ CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector) YIELD node AS n, score + MATCH (n WHERE n.group_id IN $group_ids) RETURN n.uuid As uuid, n.group_id AS group_id, @@ -168,6 +170,7 @@ async def entity_similarity_search( ORDER BY score DESC """, search_vector=search_vector, + group_ids=group_ids, limit=limit, ) nodes = [get_entity_node_from_record(record) for record in records] @@ -176,24 +179,32 @@ async def entity_similarity_search( async def entity_fulltext_search( - query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + query: str, + driver: AsyncDriver, + group_ids: list[str | None] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: + group_ids = group_ids if group_ids is not None else [None] + # BM25 search to get top nodes fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' records, _, _ = await driver.execute_query( """ - CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score + CALL db.index.fulltext.queryNodes("name_and_summary", $query) + YIELD node AS n, score + MATCH (n WHERE n.group_id in $group_ids) RETURN - node.uuid AS uuid, - node.group_id AS group_id, - node.name AS name, - node.name_embedding AS name_embedding, - node.created_at AS created_at, - node.summary AS summary + n.uuid AS uuid, + n.group_id AS group_id, + n.name AS name, + n.name_embedding AS name_embedding, + n.created_at AS created_at, + n.summary AS summary ORDER BY score DESC LIMIT $limit """, query=fuzzy_query, + group_ids=group_ids, limit=limit, ) nodes = [get_entity_node_from_record(record) for record in records] @@ -206,13 +217,17 @@ async def edge_fulltext_search( query: str, source_node_uuid: str | None, target_node_uuid: str | None, + group_ids: list[str | None] | None = None, limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: + group_ids = group_ids if group_ids is not None else [None] + # fulltext search over facts cypher_query = Query(""" CALL db.index.fulltext.queryRelationships("name_and_fact", $query) YIELD relationship AS rel, score MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) + WHERE r.group_id IN $group_ids RETURN r.uuid AS uuid, r.group_id AS group_id, @@ -234,6 +249,7 @@ async def edge_fulltext_search( CALL db.index.fulltext.queryRelationships("name_and_fact", $query) YIELD relationship AS rel, score MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) + WHERE r.group_id IN $group_ids RETURN r.uuid AS uuid, r.group_id AS group_id, @@ -253,7 +269,8 @@ async def edge_fulltext_search( cypher_query = Query(""" CALL db.index.fulltext.queryRelationships("name_and_fact", $query) YIELD relationship AS rel, score - MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) + MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) + WHERE r.group_id IN $group_ids RETURN r.uuid AS uuid, r.group_id AS group_id, @@ -273,7 +290,8 @@ async def edge_fulltext_search( cypher_query = Query(""" CALL db.index.fulltext.queryRelationships("name_and_fact", $query) YIELD relationship AS rel, score - MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) + MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) + WHERE r.group_id IN $group_ids RETURN r.uuid AS uuid, r.group_id AS group_id, @@ -297,6 +315,7 @@ async def edge_fulltext_search( query=fuzzy_query, source_uuid=source_node_uuid, target_uuid=target_node_uuid, + group_ids=group_ids, limit=limit, ) @@ -309,6 +328,7 @@ async def hybrid_node_search( queries: list[str], embeddings: list[list[float]], driver: AsyncDriver, + group_ids: list[str | None] | None = None, limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: """ @@ -325,6 +345,8 @@ async def hybrid_node_search( A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed. driver : AsyncDriver The Neo4j driver instance for database operations. + group_ids : list[str] | None, optional + The list of group ids to retrieve nodes from. limit : int | None, optional The maximum number of results to return per search method. If None, a default limit will be applied. @@ -351,8 +373,8 @@ async def hybrid_node_search( results: list[list[EntityNode]] = list( await asyncio.gather( - *[entity_fulltext_search(q, driver, 2 * limit) for q in queries], - *[entity_similarity_search(e, driver, 2 * limit) for e in embeddings], + *[entity_fulltext_search(q, driver, group_ids, 2 * limit) for q in queries], + *[entity_similarity_search(e, driver, group_ids, 2 * limit) for e in embeddings], ) ) @@ -403,6 +425,7 @@ async def get_relevant_nodes( [node.name for node in nodes], [node.name_embedding for node in nodes if node.name_embedding is not None], driver, + [node.group_id for node in nodes], ) return relevant_nodes @@ -421,13 +444,20 @@ async def get_relevant_edges( results = await asyncio.gather( *[ edge_similarity_search( - driver, edge.fact_embedding, source_node_uuid, target_node_uuid, limit + driver, + edge.fact_embedding, + source_node_uuid, + target_node_uuid, + [edge.group_id], + limit, ) for edge in edges if edge.fact_embedding is not None ], *[ - edge_fulltext_search(driver, edge.fact, source_node_uuid, target_node_uuid, limit) + edge_fulltext_search( + driver, edge.fact, source_node_uuid, target_node_uuid, [edge.group_id], limit + ) for edge in edges ], ) diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index 38620a8d..84e9d68e 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -34,6 +34,10 @@ async def build_indices_and_constraints(driver: AsyncDriver): 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)', 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)', 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)', + 'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)', + 'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)', + 'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)', + 'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)', 'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)', 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)', 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)', @@ -86,6 +90,7 @@ async def retrieve_episodes( driver: AsyncDriver, reference_time: datetime, last_n: int = EPISODE_WINDOW_LEN, + group_ids: list[str] | None = None, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -96,25 +101,28 @@ async def retrieve_episodes( less than or equal to this reference_time will be retrieved. This allows for querying the graph's state at a specific point in time. last_n (int, optional): The number of most recent episodes to retrieve, relative to the reference_time. + group_ids (list[str], optional): The list of group ids to return data from. Returns: list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes. """ result = await driver.execute_query( """ - MATCH (e:Episodic) WHERE e.valid_at <= $reference_time - RETURN e.content as content, - e.created_at as created_at, - e.valid_at as valid_at, - e.uuid as uuid, - e.name as name, - e.source_description as source_description, - e.source as source + MATCH (e:Episodic) WHERE e.valid_at <= $reference_time AND e.group_id in $group_ids + RETURN e.content AS content, + e.created_at AS created_at, + e.valid_at AS valid_at, + e.uuid AS uuid, + e.group_id AS group_id + e.name AS name, + e.source_description AS source_description, + e.source AS source ORDER BY e.created_at DESC LIMIT $num_episodes """, reference_time=reference_time, num_episodes=last_n, + group_ids=group_ids, ) episodes = [ EpisodicNode( @@ -124,6 +132,7 @@ async def retrieve_episodes( ), valid_at=(record['valid_at'].to_native()), uuid=record['uuid'], + group_id=record['group_id'], source=EpisodeType.from_str(record['source']), name=record['name'], source_description=record['source_description'], From eaff76ed36f5e01b29c553e7732891574b3f2e89 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Fri, 6 Sep 2024 10:52:45 -0400 Subject: [PATCH 03/12] add episode and search functional --- examples/podcast/podcast_runner.py | 1 + graphiti_core/graphiti.py | 4 +++- graphiti_core/nodes.py | 1 + graphiti_core/utils/maintenance/graph_data_operations.py | 2 +- tests/test_graphiti_int.py | 6 +++--- 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index f100926e..ec144987 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -69,6 +69,7 @@ async def main(use_bulk: bool = True): episode_body=f'{message.speaker_name} ({message.role}): {message.content}', reference_time=message.actual_timestamp, source_description='Podcast Transcript', + group_id='1', ) return diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index ec4aaca6..475962d9 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -304,7 +304,9 @@ async def add_episode_endpoint(episode_data: EpisodeData): (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather( resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists), - extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes), + extract_edges( + self.llm_client, episode, extracted_nodes, previous_episodes, group_id + ), ) logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}') nodes.extend(mentioned_nodes) diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index cde7c87d..1a6d6268 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -248,6 +248,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode: return EntityNode( uuid=record['uuid'], name=record['name'], + group_id=record['group_id'], name_embedding=record['name_embedding'], labels=['Entity'], created_at=record['created_at'].to_native(), diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index 84e9d68e..febe7035 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -113,7 +113,7 @@ async def retrieve_episodes( e.created_at AS created_at, e.valid_at AS valid_at, e.uuid AS uuid, - e.group_id AS group_id + e.group_id AS group_id, e.name AS name, e.source_description AS source_description, e.source AS source diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index 68c54a06..2c2ebc35 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -74,15 +74,15 @@ async def test_graphiti_init(): logger = setup_logging() graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) - edges = await graphiti.search('Freakenomics guest') + edges = await graphiti.search('Freakenomics guest', group_ids=['1']) logger.info('\nQUERY: Freakenomics guest\n' + format_context([edge.fact for edge in edges])) - edges = await graphiti.search('tania tetlow\n') + edges = await graphiti.search('tania tetlow', group_ids=['1']) logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges])) - edges = await graphiti.search('issues with higher ed') + edges = await graphiti.search('issues with higher ed', group_ids=['1']) logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges])) graphiti.close() From 9caf8c4b882bb66220ef4e882eff3a1ae2a24304 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Fri, 6 Sep 2024 11:03:30 -0400 Subject: [PATCH 04/12] update bulk --- graphiti_core/graphiti.py | 29 +++++++++---------- graphiti_core/utils/bulk_utils.py | 12 ++++++-- .../utils/maintenance/edge_operations.py | 3 +- .../utils/maintenance/node_operations.py | 3 +- 4 files changed, 26 insertions(+), 21 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 475962d9..952e570c 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -18,7 +18,6 @@ import logging from datetime import datetime from time import time -from typing import Callable from dotenv import load_dotenv from neo4j import AsyncGraphDatabase @@ -120,7 +119,7 @@ def close(self): Parameters ---------- - None + self Returns ------- @@ -151,7 +150,7 @@ async def build_indices_and_constraints(self): Parameters ---------- - None + self Returns ------- @@ -282,9 +281,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): # Extract entities as nodes - extracted_nodes = await extract_nodes( - self.llm_client, episode, previous_episodes, group_id - ) + extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes) logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') # Calculate Embeddings @@ -395,9 +392,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}') - episodic_edges: list[EpisodicEdge] = build_episodic_edges( - mentioned_nodes, episode, now, group_id - ) + episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now) logger.info(f'Built episodic edges: {episodic_edges}') @@ -413,10 +408,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): except Exception as e: raise e - async def add_episode_bulk( - self, - bulk_episodes: list[RawEpisode], - ): + async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None): """ Process multiple episodes in bulk and update the graph. @@ -427,6 +419,8 @@ async def add_episode_bulk( ---------- bulk_episodes : list[RawEpisode] A list of RawEpisode objects to be processed and added to the graph. + group_id : str | None + An id for the graph partition the episode is a part of. Returns ------- @@ -463,6 +457,7 @@ async def add_episode_bulk( source=episode.source, content=episode.content, source_description=episode.source_description, + group_id=group_id, created_at=now, valid_at=episode.reference_time, ) @@ -599,7 +594,7 @@ async def _search( ) async def get_nodes_by_query( - self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT + self, query: str, group_ids: list[str] | None = None, limit: int = RELEVANT_SCHEMA_LIMIT ) -> list[EntityNode]: """ Retrieve nodes from the graph database based on a text query. @@ -611,6 +606,8 @@ async def get_nodes_by_query( ---------- query : str The text query to search for in the graph. + group_ids : list[str] | None, optional + The graph partitions to return data from. limit : int | None, optional The maximum number of results to return per search method. If None, a default limit will be applied. @@ -635,5 +632,7 @@ async def get_nodes_by_query( """ embedder = self.llm_client.get_embedder() query_embedding = await generate_embedding(embedder, query) - relevant_nodes = await hybrid_node_search([query], [query_embedding], self.driver, limit) + relevant_nodes = await hybrid_node_search( + [query], [query_embedding], self.driver, group_ids, limit + ) return relevant_nodes diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 4c8f12a4..b3c2bfda 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -62,7 +62,9 @@ async def retrieve_previous_episodes_bulk( ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]: previous_episodes_list = await asyncio.gather( *[ - retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN) + retrieve_episodes( + driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id] + ) for episode in episodes ] ) @@ -90,7 +92,13 @@ async def extract_nodes_and_edges_bulk( extracted_edges_bulk = await asyncio.gather( *[ - extract_edges(llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i]) + extract_edges( + llm_client, + episode, + extracted_nodes_bulk[i], + previous_episodes_list[i], + episode.group_id, + ) for i, episode in enumerate(episodes) ] ) diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index d45718af..4518c8da 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -36,14 +36,13 @@ def build_episodic_edges( entity_nodes: List[EntityNode], episode: EpisodicNode, created_at: datetime, - group_id: str | None, ) -> List[EpisodicEdge]: edges: List[EpisodicEdge] = [ EpisodicEdge( source_node_uuid=episode.uuid, target_node_uuid=node.uuid, created_at=created_at, - group_id=group_id, + group_id=episode.group_id, ) for node in entity_nodes ] diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 6d857158..1aa6c757 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -70,7 +70,6 @@ async def extract_nodes( llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode], - group_id: str | None, ) -> list[EntityNode]: start = time() extracted_node_data: list[dict[str, Any]] = [] @@ -86,7 +85,7 @@ async def extract_nodes( for node_data in extracted_node_data: new_node = EntityNode( name=node_data['name'], - group_id=group_id, + group_id=episode.group_id, labels=node_data['labels'], summary=node_data['summary'], created_at=datetime.now(), From b946def64ac2a9103d1129b49a34ecc13befd338 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Fri, 6 Sep 2024 11:15:26 -0400 Subject: [PATCH 05/12] mypy updates --- graphiti_core/graphiti.py | 15 +++-- graphiti_core/nodes.py | 2 +- graphiti_core/search/search.py | 2 +- graphiti_core/utils/bulk_utils.py | 22 ++++++- .../maintenance/graph_data_operations.py | 2 +- graphiti_core/utils/utils.py | 60 ------------------- 6 files changed, 33 insertions(+), 70 deletions(-) delete mode 100644 graphiti_core/utils/utils.py diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 952e570c..ad358d27 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -177,7 +177,7 @@ async def retrieve_episodes( self, reference_time: datetime, last_n: int = EPISODE_WINDOW_LEN, - group_ids: list[str] | None = None, + group_ids: list[str | None] | None = None, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -191,7 +191,7 @@ async def retrieve_episodes( The reference time to retrieve episodes before. last_n : int, optional The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN. - group_ids : list[str], optional + group_ids : list[str | None], optional The group ids to return data from. Returns @@ -526,7 +526,7 @@ async def search( self, query: str, center_node_uuid: str | None = None, - group_ids: list[str] | None = None, + group_ids: list[str | None] | None = None, num_results=10, ): """ @@ -541,7 +541,7 @@ async def search( The search query string. center_node_uuid: str, optional Facts will be reranked based on proximity to this node - group_ids : list[str] | None, optional + group_ids : list[str | None] | None, optional The graph partitions to return data from. num_results : int, optional The maximum number of results to return. Defaults to 10. @@ -594,7 +594,10 @@ async def _search( ) async def get_nodes_by_query( - self, query: str, group_ids: list[str] | None = None, limit: int = RELEVANT_SCHEMA_LIMIT + self, + query: str, + group_ids: list[str | None] | None = None, + limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: """ Retrieve nodes from the graph database based on a text query. @@ -606,7 +609,7 @@ async def get_nodes_by_query( ---------- query : str The text query to search for in the graph. - group_ids : list[str] | None, optional + group_ids : list[str | None] | None, optional The graph partitions to return data from. limit : int | None, optional The maximum number of results to return per search method. diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 1a6d6268..907d52b4 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -23,7 +23,6 @@ from uuid import uuid4 from neo4j import AsyncDriver -from openai import OpenAI from pydantic import BaseModel, Field from graphiti_core.llm_client.config import EMBEDDING_DIM @@ -238,6 +237,7 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode: created_at=record['created_at'].to_native().timestamp(), valid_at=(record['valid_at'].to_native()), uuid=record['uuid'], + group_id=record['group_id'], source=EpisodeType.from_str(record['source']), name=record['name'], source_description=record['source_description'], diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 87836765..3e4c59f1 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -52,7 +52,7 @@ class SearchConfig(BaseModel): num_edges: int = Field(default=10) num_nodes: int = Field(default=10) num_episodes: int = EPISODE_WINDOW_LEN - group_ids: list[str] | None + group_ids: list[str | None] | None search_methods: list[SearchMethod] reranker: Reranker | None diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index b3c2bfda..49bc2c6a 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -17,6 +17,7 @@ import asyncio import logging import typing +from collections import defaultdict from datetime import datetime from math import ceil @@ -42,7 +43,6 @@ extract_nodes, ) from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates -from graphiti_core.utils.utils import chunk_edges_by_nodes logger = logging.getLogger(__name__) @@ -351,3 +351,23 @@ async def extract_edge_dates_bulk( edge.expired_at = datetime.now() return edges + + +def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]: + # We only want to dedupe edges that are between the same pair of nodes + # We build a map of the edges based on their source and target nodes. + edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list) + for edge in edges: + # We drop loop edges + if edge.source_node_uuid == edge.target_node_uuid: + continue + + # Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution + pointers = [edge.source_node_uuid, edge.target_node_uuid] + pointers.sort() + + edge_chunk_map[pointers[0] + pointers[1]].append(edge) + + edge_chunks = [chunk for chunk in edge_chunk_map.values()] + + return edge_chunks diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index febe7035..a942a00b 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -90,7 +90,7 @@ async def retrieve_episodes( driver: AsyncDriver, reference_time: datetime, last_n: int = EPISODE_WINDOW_LEN, - group_ids: list[str] | None = None, + group_ids: list[str | None] | None = None, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. diff --git a/graphiti_core/utils/utils.py b/graphiti_core/utils/utils.py deleted file mode 100644 index 97821279..00000000 --- a/graphiti_core/utils/utils.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Copyright 2024, Zep Software, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import logging -from collections import defaultdict - -from graphiti_core.edges import EntityEdge, EpisodicEdge -from graphiti_core.nodes import EntityNode, EpisodicNode - -logger = logging.getLogger(__name__) - - -def build_episodic_edges( - entity_nodes: list[EntityNode], episode: EpisodicNode -) -> list[EpisodicEdge]: - edges: list[EpisodicEdge] = [] - - for node in entity_nodes: - edges.append( - EpisodicEdge( - source_node_uuid=episode.uuid, - target_node_uuid=node.uuid, - created_at=episode.created_at, - ) - ) - - return edges - - -def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]: - # We only want to dedupe edges that are between the same pair of nodes - # We build a map of the edges based on their source and target nodes. - edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list) - for edge in edges: - # We drop loop edges - if edge.source_node_uuid == edge.target_node_uuid: - continue - - # Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution - pointers = [edge.source_node_uuid, edge.target_node_uuid] - pointers.sort() - - edge_chunk_map[pointers[0] + pointers[1]].append(edge) - - edge_chunks = [chunk for chunk in edge_chunk_map.values()] - - return edge_chunks From eef15e1a3d077f14a132e0888db4d67d277140c1 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Fri, 6 Sep 2024 11:20:15 -0400 Subject: [PATCH 06/12] remove unused imports --- graphiti_core/search/search_utils.py | 72 ++++++++++++++-------------- 1 file changed, 35 insertions(+), 37 deletions(-) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 2a2b41bf..96af4015 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -3,12 +3,10 @@ import re from collections import defaultdict from time import time -from typing import Any from neo4j import AsyncDriver, Query from graphiti_core.edges import EntityEdge, get_entity_edge_from_record -from graphiti_core.helpers import parse_db_date from graphiti_core.nodes import EntityNode, EpisodicNode, get_entity_node_from_record logger = logging.getLogger(__name__) @@ -38,12 +36,12 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]) async def edge_similarity_search( - driver: AsyncDriver, - search_vector: list[float], - source_node_uuid: str | None, - target_node_uuid: str | None, - group_ids: list[str | None] | None = None, - limit: int = RELEVANT_SCHEMA_LIMIT, + driver: AsyncDriver, + search_vector: list[float], + source_node_uuid: str | None, + target_node_uuid: str | None, + group_ids: list[str | None] | None = None, + limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: group_ids = group_ids if group_ids is not None else [None] # vector similarity search over embedded facts @@ -147,10 +145,10 @@ async def edge_similarity_search( async def entity_similarity_search( - search_vector: list[float], - driver: AsyncDriver, - group_ids: list[str | None] | None = None, - limit=RELEVANT_SCHEMA_LIMIT, + search_vector: list[float], + driver: AsyncDriver, + group_ids: list[str | None] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: group_ids = group_ids if group_ids is not None else [None] @@ -179,10 +177,10 @@ async def entity_similarity_search( async def entity_fulltext_search( - query: str, - driver: AsyncDriver, - group_ids: list[str | None] | None = None, - limit=RELEVANT_SCHEMA_LIMIT, + query: str, + driver: AsyncDriver, + group_ids: list[str | None] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: group_ids = group_ids if group_ids is not None else [None] @@ -213,12 +211,12 @@ async def entity_fulltext_search( async def edge_fulltext_search( - driver: AsyncDriver, - query: str, - source_node_uuid: str | None, - target_node_uuid: str | None, - group_ids: list[str | None] | None = None, - limit=RELEVANT_SCHEMA_LIMIT, + driver: AsyncDriver, + query: str, + source_node_uuid: str | None, + target_node_uuid: str | None, + group_ids: list[str | None] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: group_ids = group_ids if group_ids is not None else [None] @@ -325,11 +323,11 @@ async def edge_fulltext_search( async def hybrid_node_search( - queries: list[str], - embeddings: list[list[float]], - driver: AsyncDriver, - group_ids: list[str | None] | None = None, - limit: int = RELEVANT_SCHEMA_LIMIT, + queries: list[str], + embeddings: list[list[float]], + driver: AsyncDriver, + group_ids: list[str | None] | None = None, + limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: """ Perform a hybrid search for nodes using both text queries and embeddings. @@ -393,8 +391,8 @@ async def hybrid_node_search( async def get_relevant_nodes( - nodes: list[EntityNode], - driver: AsyncDriver, + nodes: list[EntityNode], + driver: AsyncDriver, ) -> list[EntityNode]: """ Retrieve relevant nodes based on the provided list of EntityNodes. @@ -431,11 +429,11 @@ async def get_relevant_nodes( async def get_relevant_edges( - driver: AsyncDriver, - edges: list[EntityEdge], - source_node_uuid: str | None, - target_node_uuid: str | None, - limit: int = RELEVANT_SCHEMA_LIMIT, + driver: AsyncDriver, + edges: list[EntityEdge], + source_node_uuid: str | None, + target_node_uuid: str | None, + limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: start = time() relevant_edges: list[EntityEdge] = [] @@ -492,7 +490,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]: async def node_distance_reranker( - driver: AsyncDriver, results: list[list[str]], center_node_uuid: str + driver: AsyncDriver, results: list[list[str]], center_node_uuid: str ) -> list[str]: # use rrf as a preliminary ranker sorted_uuids = rrf(results) @@ -514,8 +512,8 @@ async def node_distance_reranker( for record in records: if ( - record['source_uuid'] == center_node_uuid - or record['target_uuid'] == center_node_uuid + record['source_uuid'] == center_node_uuid + or record['target_uuid'] == center_node_uuid ): continue distance = record['score'] From 5b669f337873c4ff29cace1a8ecbd6fc75763ce3 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Fri, 6 Sep 2024 11:25:38 -0400 Subject: [PATCH 07/12] update unit tests --- graphiti_core/search/search_utils.py | 70 +++++++++---------- .../maintenance/test_temporal_operations.py | 17 ++++- .../test_temporal_operations_int.py | 21 ++++-- tests/utils/search/search_utils_test.py | 29 +++++--- 4 files changed, 84 insertions(+), 53 deletions(-) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 96af4015..5b63d300 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -36,12 +36,12 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]) async def edge_similarity_search( - driver: AsyncDriver, - search_vector: list[float], - source_node_uuid: str | None, - target_node_uuid: str | None, - group_ids: list[str | None] | None = None, - limit: int = RELEVANT_SCHEMA_LIMIT, + driver: AsyncDriver, + search_vector: list[float], + source_node_uuid: str | None, + target_node_uuid: str | None, + group_ids: list[str | None] | None = None, + limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: group_ids = group_ids if group_ids is not None else [None] # vector similarity search over embedded facts @@ -145,10 +145,10 @@ async def edge_similarity_search( async def entity_similarity_search( - search_vector: list[float], - driver: AsyncDriver, - group_ids: list[str | None] | None = None, - limit=RELEVANT_SCHEMA_LIMIT, + search_vector: list[float], + driver: AsyncDriver, + group_ids: list[str | None] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: group_ids = group_ids if group_ids is not None else [None] @@ -177,10 +177,10 @@ async def entity_similarity_search( async def entity_fulltext_search( - query: str, - driver: AsyncDriver, - group_ids: list[str | None] | None = None, - limit=RELEVANT_SCHEMA_LIMIT, + query: str, + driver: AsyncDriver, + group_ids: list[str | None] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: group_ids = group_ids if group_ids is not None else [None] @@ -211,12 +211,12 @@ async def entity_fulltext_search( async def edge_fulltext_search( - driver: AsyncDriver, - query: str, - source_node_uuid: str | None, - target_node_uuid: str | None, - group_ids: list[str | None] | None = None, - limit=RELEVANT_SCHEMA_LIMIT, + driver: AsyncDriver, + query: str, + source_node_uuid: str | None, + target_node_uuid: str | None, + group_ids: list[str | None] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: group_ids = group_ids if group_ids is not None else [None] @@ -323,11 +323,11 @@ async def edge_fulltext_search( async def hybrid_node_search( - queries: list[str], - embeddings: list[list[float]], - driver: AsyncDriver, - group_ids: list[str | None] | None = None, - limit: int = RELEVANT_SCHEMA_LIMIT, + queries: list[str], + embeddings: list[list[float]], + driver: AsyncDriver, + group_ids: list[str | None] | None = None, + limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: """ Perform a hybrid search for nodes using both text queries and embeddings. @@ -391,8 +391,8 @@ async def hybrid_node_search( async def get_relevant_nodes( - nodes: list[EntityNode], - driver: AsyncDriver, + nodes: list[EntityNode], + driver: AsyncDriver, ) -> list[EntityNode]: """ Retrieve relevant nodes based on the provided list of EntityNodes. @@ -429,11 +429,11 @@ async def get_relevant_nodes( async def get_relevant_edges( - driver: AsyncDriver, - edges: list[EntityEdge], - source_node_uuid: str | None, - target_node_uuid: str | None, - limit: int = RELEVANT_SCHEMA_LIMIT, + driver: AsyncDriver, + edges: list[EntityEdge], + source_node_uuid: str | None, + target_node_uuid: str | None, + limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: start = time() relevant_edges: list[EntityEdge] = [] @@ -490,7 +490,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]: async def node_distance_reranker( - driver: AsyncDriver, results: list[list[str]], center_node_uuid: str + driver: AsyncDriver, results: list[list[str]], center_node_uuid: str ) -> list[str]: # use rrf as a preliminary ranker sorted_uuids = rrf(results) @@ -512,8 +512,8 @@ async def node_distance_reranker( for record in records: if ( - record['source_uuid'] == center_node_uuid - or record['target_uuid'] == center_node_uuid + record['source_uuid'] == center_node_uuid + or record['target_uuid'] == center_node_uuid ): continue distance = record['score'] diff --git a/tests/utils/maintenance/test_temporal_operations.py b/tests/utils/maintenance/test_temporal_operations.py index 76224bc5..72497d56 100644 --- a/tests/utils/maintenance/test_temporal_operations.py +++ b/tests/utils/maintenance/test_temporal_operations.py @@ -135,9 +135,9 @@ def test_prepare_invalidation_context(): now = datetime.now() # Create nodes - node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now) - node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now) - node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now) + node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1') + node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1') + node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1') # Create edges edge1 = EntityEdge( @@ -147,6 +147,7 @@ def test_prepare_invalidation_context(): name='KNOWS', fact='Node1 knows Node2', created_at=now, + group_id='1', ) edge2 = EntityEdge( uuid='e2', @@ -155,6 +156,7 @@ def test_prepare_invalidation_context(): name='LIKES', fact='Node2 likes Node3', created_at=now, + group_id='1', ) # Create NodeEdgeNodeTriplet objects @@ -173,6 +175,7 @@ def test_prepare_invalidation_context(): valid_at=now, source=EpisodeType.message, source_description='Test episode for unit testing', + group_id='1', ) previous_episodes = [ EpisodicNode( @@ -182,6 +185,7 @@ def test_prepare_invalidation_context(): valid_at=now - timedelta(days=1), source=EpisodeType.message, source_description='Test previous episode 1 for unit testing', + group_id='1', ), EpisodicNode( name='Previous Episode 2', @@ -190,6 +194,7 @@ def test_prepare_invalidation_context(): valid_at=now - timedelta(days=2), source=EpisodeType.message, source_description='Test previous episode 2 for unit testing', + group_id='1', ), ] @@ -235,6 +240,7 @@ def test_prepare_invalidation_context_empty_input(): valid_at=now, source=EpisodeType.message, source_description='Test empty episode for unit testing', + group_id='1', ) result = prepare_invalidation_context([], [], current_episode, []) assert isinstance(result, dict) @@ -263,6 +269,7 @@ def test_prepare_invalidation_context_sorting(): name='KNOWS', fact='Node1 knows Node2', created_at=now, + group_id='1', ) edge2 = EntityEdge( uuid='e2', @@ -271,6 +278,7 @@ def test_prepare_invalidation_context_sorting(): name='LIKES', fact='Node2 likes Node1', created_at=now + timedelta(hours=1), + group_id='1', ) edge_with_nodes1 = (node1, edge1, node2) @@ -287,6 +295,7 @@ def test_prepare_invalidation_context_sorting(): valid_at=now, source=EpisodeType.message, source_description='Test episode for unit testing', + group_id='1', ) previous_episodes = [ EpisodicNode( @@ -296,6 +305,7 @@ def test_prepare_invalidation_context_sorting(): valid_at=now - timedelta(days=1), source=EpisodeType.message, source_description='Test previous episode for unit testing', + group_id='1', ), ] @@ -321,6 +331,7 @@ def generate_entity_edge(self, valid_at, invalid_at): created_at=datetime.now(), valid_at=valid_at, invalid_at=invalid_at, + group_id='1', ) def test_both_dates_present(self): diff --git a/tests/utils/maintenance/test_temporal_operations_int.py b/tests/utils/maintenance/test_temporal_operations_int.py index 9e6b2953..b08689fb 100644 --- a/tests/utils/maintenance/test_temporal_operations_int.py +++ b/tests/utils/maintenance/test_temporal_operations_int.py @@ -76,6 +76,7 @@ def create_test_data(): valid_at=now, source=EpisodeType.message, source_description='Test episode for unit testing', + group_id='1', ) # Create previous episodes @@ -87,6 +88,7 @@ def create_test_data(): valid_at=now - timedelta(days=1), source=EpisodeType.message, source_description='Test previous episode for unit testing', + group_id='1', ) ] @@ -142,10 +144,12 @@ def create_complex_test_data(): now = datetime.now() # Create nodes - node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now) - node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now) - node3 = EntityNode(uuid='3', name='Charlie', labels=['Person'], created_at=now) - node4 = EntityNode(uuid='4', name='Company XYZ', labels=['Organization'], created_at=now) + node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1') + node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now, group_id='1') + node3 = EntityNode(uuid='3', name='Charlie', labels=['Person'], created_at=now, group_id='1') + node4 = EntityNode( + uuid='4', name='Company XYZ', labels=['Organization'], created_at=now, group_id='1' + ) # Create edges edge1 = EntityEdge( @@ -154,6 +158,7 @@ def create_complex_test_data(): target_node_uuid='2', name='LIKES', fact='Alice likes Bob', + group_id='1', created_at=now - timedelta(days=5), ) edge2 = EntityEdge( @@ -162,6 +167,7 @@ def create_complex_test_data(): target_node_uuid='3', name='FRIENDS_WITH', fact='Alice is friends with Charlie', + group_id='1', created_at=now - timedelta(days=3), ) edge3 = EntityEdge( @@ -170,6 +176,7 @@ def create_complex_test_data(): target_node_uuid='4', name='WORKS_FOR', fact='Bob works for Company XYZ', + group_id='1', created_at=now - timedelta(days=2), ) @@ -199,6 +206,7 @@ async def test_invalidate_edges_complex(): target_node_uuid='2', name='DISLIKES', fact='Alice dislikes Bob', + group_id='1', created_at=datetime.now(), ), nodes[1], @@ -225,6 +233,7 @@ async def test_invalidate_edges_temporal_update(): target_node_uuid='4', name='LEFT_JOB', fact='Bob left his job at Company XYZ', + group_id='1', created_at=datetime.now(), ), nodes[3], @@ -251,6 +260,7 @@ async def test_invalidate_edges_multiple_invalidations(): target_node_uuid='2', name='ENEMIES_WITH', fact='Alice and Bob are now enemies', + group_id='1', created_at=datetime.now(), ), nodes[1], @@ -263,6 +273,7 @@ async def test_invalidate_edges_multiple_invalidations(): target_node_uuid='3', name='ENDED_FRIENDSHIP', fact='Alice ended her friendship with Charlie', + group_id='1', created_at=datetime.now(), ), nodes[2], @@ -292,6 +303,7 @@ async def test_invalidate_edges_no_effect(): target_node_uuid='4', name='APPLIED_TO', fact='Charlie applied to Company XYZ', + group_id='1', created_at=datetime.now(), ), nodes[3], @@ -316,6 +328,7 @@ async def test_invalidate_edges_partial_update(): target_node_uuid='4', name='CHANGED_POSITION', fact='Bob changed his position at Company XYZ', + group_id='1', created_at=datetime.now(), ), nodes[3], diff --git a/tests/utils/search/search_utils_test.py b/tests/utils/search/search_utils_test.py index e4760976..2ce0a6b9 100644 --- a/tests/utils/search/search_utils_test.py +++ b/tests/utils/search/search_utils_test.py @@ -70,7 +70,9 @@ async def test_hybrid_node_search_only_fulltext(): ) as mock_fulltext_search, patch( 'graphiti_core.search.search_utils.entity_similarity_search' ) as mock_similarity_search: - mock_fulltext_search.return_value = [EntityNode(uuid='1', name='Alice', labels=['Entity'])] + mock_fulltext_search.return_value = [ + EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1') + ] mock_similarity_search.return_value = [] queries = ['Alice'] @@ -93,18 +95,23 @@ async def test_hybrid_node_search_with_limit(): 'graphiti_core.search.search_utils.entity_similarity_search' ) as mock_similarity_search: mock_fulltext_search.return_value = [ - EntityNode(uuid='1', name='Alice', labels=['Entity']), - EntityNode(uuid='2', name='Bob', labels=['Entity']), + EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'), + EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1'), ] mock_similarity_search.return_value = [ - EntityNode(uuid='3', name='Charlie', labels=['Entity']), - EntityNode(uuid='4', name='David', labels=['Entity']), + EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1'), + EntityNode( + uuid='4', + name='David', + labels=['Entity'], + group_id='1', + ), ] queries = ['Test'] embeddings = [[0.1, 0.2, 0.3]] limit = 1 - results = await hybrid_node_search(queries, embeddings, mock_driver, limit) + results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit) # We expect 4 results because the limit is applied per search method # before deduplication, and we're not actually limiting the results @@ -127,18 +134,18 @@ async def test_hybrid_node_search_with_limit_and_duplicates(): 'graphiti_core.search.search_utils.entity_similarity_search' ) as mock_similarity_search: mock_fulltext_search.return_value = [ - EntityNode(uuid='1', name='Alice', labels=['Entity']), - EntityNode(uuid='2', name='Bob', labels=['Entity']), + EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'), + EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1'), ] mock_similarity_search.return_value = [ - EntityNode(uuid='1', name='Alice', labels=['Entity']), # Duplicate - EntityNode(uuid='3', name='Charlie', labels=['Entity']), + EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'), # Duplicate + EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1'), ] queries = ['Test'] embeddings = [[0.1, 0.2, 0.3]] limit = 2 - results = await hybrid_node_search(queries, embeddings, mock_driver, limit) + results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit) # We expect 3 results because: # 1. The limit of 2 is applied to each search method From cd36604176dce84070817a6b734d30cd85af8f46 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Fri, 6 Sep 2024 11:36:28 -0400 Subject: [PATCH 08/12] unit tests --- .../maintenance/test_temporal_operations.py | 14 +++++++++----- tests/utils/search/search_utils_test.py | 16 ++++++++-------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/utils/maintenance/test_temporal_operations.py b/tests/utils/maintenance/test_temporal_operations.py index 72497d56..0e86bd89 100644 --- a/tests/utils/maintenance/test_temporal_operations.py +++ b/tests/utils/maintenance/test_temporal_operations.py @@ -33,9 +33,9 @@ def create_test_data(): now = datetime.now() # Create nodes - node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now) - node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now) - node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now) + node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1') + node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1') + node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1') # Create edges existing_edge1 = EntityEdge( @@ -45,6 +45,7 @@ def create_test_data(): name='KNOWS', fact='Node1 knows Node2', created_at=now, + group_id='1', ) existing_edge2 = EntityEdge( uuid='e2', @@ -53,6 +54,7 @@ def create_test_data(): name='LIKES', fact='Node2 likes Node3', created_at=now, + group_id='1', ) new_edge1 = EntityEdge( uuid='e3', @@ -61,6 +63,7 @@ def create_test_data(): name='WORKS_WITH', fact='Node1 works with Node3', created_at=now, + group_id='1', ) new_edge2 = EntityEdge( uuid='e4', @@ -69,6 +72,7 @@ def create_test_data(): name='DISLIKES', fact='Node1 dislikes Node2', created_at=now, + group_id='1', ) return { @@ -258,8 +262,8 @@ def test_prepare_invalidation_context_sorting(): now = datetime.now() # Create nodes - node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now) - node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now) + node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1') + node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1') # Create edges with different timestamps edge1 = EntityEdge( diff --git a/tests/utils/search/search_utils_test.py b/tests/utils/search/search_utils_test.py index 2ce0a6b9..38837f0d 100644 --- a/tests/utils/search/search_utils_test.py +++ b/tests/utils/search/search_utils_test.py @@ -19,12 +19,12 @@ async def test_hybrid_node_search_deduplication(): ) as mock_similarity_search: # Set up mock return values mock_fulltext_search.side_effect = [ - [EntityNode(uuid='1', name='Alice', labels=['Entity'])], - [EntityNode(uuid='2', name='Bob', labels=['Entity'])], + [EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')], + [EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1')], ] mock_similarity_search.side_effect = [ - [EntityNode(uuid='1', name='Alice', labels=['Entity'])], - [EntityNode(uuid='3', name='Charlie', labels=['Entity'])], + [EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')], + [EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1')], ] # Call the function with test data @@ -120,8 +120,8 @@ async def test_hybrid_node_search_with_limit(): assert mock_fulltext_search.call_count == 1 assert mock_similarity_search.call_count == 1 # Verify that the limit was passed to the search functions - mock_fulltext_search.assert_called_with('Test', mock_driver, 2) - mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 2) + mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 2) + mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 2) @pytest.mark.asyncio @@ -155,5 +155,5 @@ async def test_hybrid_node_search_with_limit_and_duplicates(): assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'} assert mock_fulltext_search.call_count == 1 assert mock_similarity_search.call_count == 1 - mock_fulltext_search.assert_called_with('Test', mock_driver, 4) - mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 4) + mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 4) + mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 4) From 81b69ad5583010cfd9fce10375b581720a5ba340 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Fri, 6 Sep 2024 12:18:42 -0400 Subject: [PATCH 09/12] add optional uuid field --- graphiti_core/graphiti.py | 52 ++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index ad358d27..e8e1a5ee 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -174,10 +174,10 @@ async def build_indices_and_constraints(self): await build_indices_and_constraints(self.driver) async def retrieve_episodes( - self, - reference_time: datetime, - last_n: int = EPISODE_WINDOW_LEN, - group_ids: list[str | None] | None = None, + self, + reference_time: datetime, + last_n: int = EPISODE_WINDOW_LEN, + group_ids: list[str | None] | None = None, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -207,13 +207,14 @@ async def retrieve_episodes( return await retrieve_episodes(self.driver, reference_time, last_n, group_ids) async def add_episode( - self, - name: str, - episode_body: str, - source_description: str, - reference_time: datetime, - source: EpisodeType = EpisodeType.message, - group_id: str | None = None, + self, + name: str, + episode_body: str, + source_description: str, + reference_time: datetime, + source: EpisodeType = EpisodeType.message, + group_id: str | None = None, + uuid: str = None ): """ Process an episode and update the graph. @@ -278,6 +279,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): created_at=now, valid_at=reference_time, ) + episode.uuid = episode.uuid if uuid is None else uuid # Extract entities as nodes @@ -523,11 +525,11 @@ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str raise e async def search( - self, - query: str, - center_node_uuid: str | None = None, - group_ids: list[str | None] | None = None, - num_results=10, + self, + query: str, + center_node_uuid: str | None = None, + group_ids: list[str | None] | None = None, + num_results=10, ): """ Perform a hybrid search on the knowledge graph. @@ -583,21 +585,21 @@ async def search( return edges async def _search( - self, - query: str, - timestamp: datetime, - config: SearchConfig, - center_node_uuid: str | None = None, + self, + query: str, + timestamp: datetime, + config: SearchConfig, + center_node_uuid: str | None = None, ): return await hybrid_search( self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid ) async def get_nodes_by_query( - self, - query: str, - group_ids: list[str | None] | None = None, - limit: int = RELEVANT_SCHEMA_LIMIT, + self, + query: str, + group_ids: list[str | None] | None = None, + limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: """ Retrieve nodes from the graph database based on a text query. From ea04b83a1526e91b85afd31201d26d38f4d23725 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Fri, 6 Sep 2024 12:19:17 -0400 Subject: [PATCH 10/12] format --- graphiti_core/graphiti.py | 52 +++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index e8e1a5ee..bcac2ddd 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -174,10 +174,10 @@ async def build_indices_and_constraints(self): await build_indices_and_constraints(self.driver) async def retrieve_episodes( - self, - reference_time: datetime, - last_n: int = EPISODE_WINDOW_LEN, - group_ids: list[str | None] | None = None, + self, + reference_time: datetime, + last_n: int = EPISODE_WINDOW_LEN, + group_ids: list[str | None] | None = None, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -207,14 +207,14 @@ async def retrieve_episodes( return await retrieve_episodes(self.driver, reference_time, last_n, group_ids) async def add_episode( - self, - name: str, - episode_body: str, - source_description: str, - reference_time: datetime, - source: EpisodeType = EpisodeType.message, - group_id: str | None = None, - uuid: str = None + self, + name: str, + episode_body: str, + source_description: str, + reference_time: datetime, + source: EpisodeType = EpisodeType.message, + group_id: str | None = None, + uuid: str = None, ): """ Process an episode and update the graph. @@ -525,11 +525,11 @@ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str raise e async def search( - self, - query: str, - center_node_uuid: str | None = None, - group_ids: list[str | None] | None = None, - num_results=10, + self, + query: str, + center_node_uuid: str | None = None, + group_ids: list[str | None] | None = None, + num_results=10, ): """ Perform a hybrid search on the knowledge graph. @@ -585,21 +585,21 @@ async def search( return edges async def _search( - self, - query: str, - timestamp: datetime, - config: SearchConfig, - center_node_uuid: str | None = None, + self, + query: str, + timestamp: datetime, + config: SearchConfig, + center_node_uuid: str | None = None, ): return await hybrid_search( self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid ) async def get_nodes_by_query( - self, - query: str, - group_ids: list[str | None] | None = None, - limit: int = RELEVANT_SCHEMA_LIMIT, + self, + query: str, + group_ids: list[str | None] | None = None, + limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: """ Retrieve nodes from the graph database based on a text query. From 7bfa768aa8bae55be1a67fec0ac3ae820c594929 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Fri, 6 Sep 2024 12:25:48 -0400 Subject: [PATCH 11/12] mypy --- graphiti_core/graphiti.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index bcac2ddd..a2a79d0c 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -214,7 +214,7 @@ async def add_episode( reference_time: datetime, source: EpisodeType = EpisodeType.message, group_id: str | None = None, - uuid: str = None, + uuid: str | None = None, ): """ Process an episode and update the graph. @@ -236,6 +236,8 @@ async def add_episode( The type of the episode. Defaults to EpisodeType.message. group_id : str | None An id for the graph partition the episode is a part of. + uuid : str | None + Optional uuid of the episode. Returns ------- From 9478e4fd39cde2a78d54b0d8acb477e244572c57 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Fri, 6 Sep 2024 12:29:53 -0400 Subject: [PATCH 12/12] ellipsis --- graphiti_core/graphiti.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index a2a79d0c..0b4e6325 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -281,7 +281,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): created_at=now, valid_at=reference_time, ) - episode.uuid = episode.uuid if uuid is None else uuid + episode.uuid = uuid if uuid is not None else episode.uuid # Extract entities as nodes