Add group ids #89

Merged · 12 commits · Sep 6, 2024
1 change: 1 addition & 0 deletions examples/podcast/podcast_runner.py
@@ -69,6 +69,7 @@ async def main(use_bulk: bool = True):
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
+ group_id='1',
)
return
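For orientation, a minimal sketch of the updated runner call, assuming `client` is an initialized Graphiti instance and `message` comes from the parsed transcript (the `name` argument is hypothetical, since the hunk starts below it):

    # Sketch only: mirrors the call shown in the hunk above.
    await client.add_episode(
        name='podcast-message',  # hypothetical; the real name sits above the hunk
        episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
        reference_time=message.actual_timestamp,
        source_description='Podcast Transcript',
        group_id='1',  # every episode in this run lands in partition '1'
    )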

71 changes: 39 additions & 32 deletions graphiti_core/edges.py
@@ -18,6 +18,7 @@
from abc import ABC, abstractmethod
from datetime import datetime
from time import time
+ from typing import Any
from uuid import uuid4

from neo4j import AsyncDriver
@@ -32,6 +33,7 @@

class Edge(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: uuid4().hex)
+ group_id: str | None = Field(description='partition of the graph')
source_node_uuid: str
target_node_uuid: str
created_at: datetime
@@ -61,11 +63,12 @@ async def save(self, driver: AsyncDriver):
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
- SET r = {uuid: $uuid, created_at: $created_at}
+ SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid""",
episode_uuid=self.source_node_uuid,
entity_uuid=self.target_node_uuid,
uuid=self.uuid,
+ group_id=self.group_id,
created_at=self.created_at,
)
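As a usage sketch, saving an edge that carries the new field might look like this (assumes `driver` is a `neo4j.AsyncDriver` and both endpoint nodes already exist):

    from datetime import datetime, timezone

    edge = EpisodicEdge(
        source_node_uuid=episode_uuid,  # assumed uuid of an existing Episodic node
        target_node_uuid=entity_uuid,   # assumed uuid of an existing Entity node
        group_id='1',
        created_at=datetime.now(timezone.utc),
    )
    await edge.save(driver)

Note that `SET r = {...}` in the query above replaces the relationship's entire property map, so every property worth keeping, including group_id, must appear in the map.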

@@ -92,25 +95,16 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
"""
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
RETURN
- e.uuid As uuid,
+ e.uuid As uuid,
+ e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
""",
uuid=uuid,
)

- edges: list[EpisodicEdge] = []
-
- for record in records:
-     edges.append(
-         EpisodicEdge(
-             uuid=record['uuid'],
-             source_node_uuid=record['source_node_uuid'],
-             target_node_uuid=record['target_node_uuid'],
-             created_at=record['created_at'].to_native(),
-         )
-     )
+ edges = [get_episodic_edge_from_record(record) for record in records]

logger.info(f'Found Edge: {uuid}')

@@ -153,14 +147,15 @@ async def save(self, driver: AsyncDriver):
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
- SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
+ SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
valid_at: $valid_at, invalid_at: $invalid_at}
RETURN r.uuid AS uuid""",
source_uuid=self.source_node_uuid,
target_uuid=self.target_node_uuid,
uuid=self.uuid,
name=self.name,
+ group_id=self.group_id,
fact=self.fact,
fact_embedding=self.fact_embedding,
episodes=self.episodes,
Expand Down Expand Up @@ -198,6 +193,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
m.uuid AS target_node_uuid,
e.created_at AS created_at,
e.name AS name,
+ e.group_id AS group_id,
e.fact AS fact,
e.fact_embedding AS fact_embedding,
e.episodes AS episodes,
@@ -208,25 +204,36 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
uuid=uuid,
)

- edges: list[EntityEdge] = []
-
- for record in records:
-     edges.append(
-         EntityEdge(
-             uuid=record['uuid'],
-             source_node_uuid=record['source_node_uuid'],
-             target_node_uuid=record['target_node_uuid'],
-             fact=record['fact'],
-             name=record['name'],
-             episodes=record['episodes'],
-             fact_embedding=record['fact_embedding'],
-             created_at=record['created_at'].to_native(),
-             expired_at=parse_db_date(record['expired_at']),
-             valid_at=parse_db_date(record['valid_at']),
-             invalid_at=parse_db_date(record['invalid_at']),
-         )
-     )
+ edges = [get_entity_edge_from_record(record) for record in records]

logger.info(f'Found Edge: {uuid}')

return edges[0]


+ # Edge helpers
+ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
+     return EpisodicEdge(
+         uuid=record['uuid'],
+         group_id=record['group_id'],
+         source_node_uuid=record['source_node_uuid'],
+         target_node_uuid=record['target_node_uuid'],
+         created_at=record['created_at'].to_native(),
+     )
+
+
+ def get_entity_edge_from_record(record: Any) -> EntityEdge:
+     return EntityEdge(
+         uuid=record['uuid'],
+         source_node_uuid=record['source_node_uuid'],
+         target_node_uuid=record['target_node_uuid'],
+         fact=record['fact'],
+         name=record['name'],
+         group_id=record['group_id'],
+         episodes=record['episodes'],
+         fact_embedding=record['fact_embedding'],
+         created_at=record['created_at'].to_native(),
+         expired_at=parse_db_date(record['expired_at']),
+         valid_at=parse_db_date(record['valid_at']),
+         invalid_at=parse_db_date(record['invalid_at']),
+     )
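A short sketch of the new helper in use, with the record shape mirroring the RETURN clauses above (the partition filter is illustrative, not part of this PR):

    # Assumes `driver` is a neo4j.AsyncDriver wired to the same schema.
    records, _, _ = await driver.execute_query(
        """
        MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
        WHERE e.group_id = $group_id
        RETURN e.uuid AS uuid, e.group_id AS group_id,
            n.uuid AS source_node_uuid, m.uuid AS target_node_uuid,
            e.created_at AS created_at
        """,
        group_id='1',
    )
    edges = [get_episodic_edge_from_record(record) for record in records]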
75 changes: 45 additions & 30 deletions graphiti_core/graphiti.py
@@ -18,7 +18,6 @@
import logging
from datetime import datetime
from time import time
- from typing import Callable

from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
@@ -120,7 +119,7 @@ def close(self):

Parameters
----------
- None
+ self

Returns
-------
@@ -151,7 +150,7 @@ async def build_indices_and_constraints(self):

Parameters
----------
- None
+ self

Returns
-------
@@ -178,6 +177,7 @@ async def retrieve_episodes(
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
+ group_ids: list[str | None] | None = None,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
@@ -191,6 +191,8 @@
The reference time to retrieve episodes before.
last_n : int, optional
The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
+ group_ids : list[str | None], optional
+     The group ids to return data from.

Returns
-------
@@ -202,7 +204,7 @@
The actual retrieval is performed by the `retrieve_episodes` function
from the `graphiti_core.utils` module.
"""
- return await retrieve_episodes(self.driver, reference_time, last_n)
+ return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
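A hedged usage sketch of the partition-aware retrieval (assumes `graphiti` is an initialized Graphiti client):

    from datetime import datetime, timezone

    episodes = await graphiti.retrieve_episodes(
        reference_time=datetime.now(timezone.utc),  # only episodes before this time
        last_n=5,
        group_ids=['1'],  # restrict the lookup to one partition
    )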

async def add_episode(
self,
@@ -211,8 +213,8 @@ async def add_episode(
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
- success_callback: Callable | None = None,
- error_callback: Callable | None = None,
+ group_id: str | None = None,
+ uuid: str | None = None,
):
Contributor review comment:

The assignment `episode.uuid = episode.uuid if uuid is None else uuid` is redundant and potentially incorrect. It should be `episode.uuid = uuid if uuid is not None else episode.uuid` to ensure `episode.uuid` is only overwritten if `uuid` is provided.
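For clarity, the two forms side by side; both keep the generated uuid when `uuid` is None and override it otherwise, but the suggested form states that intent directly:

    episode.uuid = episode.uuid if uuid is None else uuid      # original form
    episode.uuid = uuid if uuid is not None else episode.uuid  # suggested form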

"""
Process an episode and update the graph.
Expand All @@ -232,10 +234,10 @@ async def add_episode(
The reference time for the episode.
source : EpisodeType, optional
The type of the episode. Defaults to EpisodeType.message.
- success_callback : Callable | None, optional
-     A callback function to be called upon successful processing.
- error_callback : Callable | None, optional
-     A callback function to be called if an error occurs during processing.
+ group_id : str | None
+     An id for the graph partition the episode is a part of.
+ uuid : str | None
+     Optional uuid of the episode.

Returns
-------
@@ -266,16 +268,20 @@ async def add_episode_endpoint(episode_data: EpisodeData):
embedder = self.llm_client.get_embedder()
now = datetime.now()

- previous_episodes = await self.retrieve_episodes(reference_time, last_n=3)
+ previous_episodes = await self.retrieve_episodes(
+     reference_time, last_n=3, group_ids=[group_id]
+ )
episode = EpisodicNode(
name=name,
+ group_id=group_id,
labels=[],
source=source,
content=episode_body,
source_description=source_description,
created_at=now,
valid_at=reference_time,
)
+ episode.uuid = uuid if uuid is not None else episode.uuid

# Extract entities as nodes

@@ -299,7 +305,9 @@ async def add_episode_endpoint(episode_data: EpisodeData):

(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
- extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes),
+ extract_edges(
+     self.llm_client, episode, extracted_nodes, previous_episodes, group_id
+ ),
)
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
nodes.extend(mentioned_nodes)
@@ -388,11 +396,7 @@ async def add_episode_endpoint(episode_data: EpisodeData):

logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')

- episodic_edges: list[EpisodicEdge] = build_episodic_edges(
-     mentioned_nodes,
-     episode,
-     now,
- )
+ episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)

logger.info(f'Built episodic edges: {episodic_edges}')

@@ -405,18 +409,10 @@ async def add_episode_endpoint(episode_data: EpisodeData):
end = time()
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')

-     if success_callback:
-         await success_callback(episode)
  except Exception as e:
-     if error_callback:
-         await error_callback(episode, e)
-     else:
-         raise e
+     raise e

- async def add_episode_bulk(
-     self,
-     bulk_episodes: list[RawEpisode],
- ):
+ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None):
"""
Process multiple episodes in bulk and update the graph.

@@ -427,6 +423,8 @@
----------
bulk_episodes : list[RawEpisode]
A list of RawEpisode objects to be processed and added to the graph.
+ group_id : str | None
+     An id for the graph partition the episode is a part of.

Returns
-------
@@ -463,6 +461,7 @@ async def add_episode_bulk(
source=episode.source,
content=episode.content,
source_description=episode.source_description,
+ group_id=group_id,
created_at=now,
valid_at=episode.reference_time,
)
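A sketch of the bulk path with the new parameter; the RawEpisode field names here follow the constructor above and are otherwise assumed:

    from datetime import datetime, timezone

    raw_episodes = [
        RawEpisode(
            name=f'chunk-{i}',  # assumed field; not visible in this hunk
            content=text,
            source=EpisodeType.message,
            source_description='Podcast Transcript',
            reference_time=datetime.now(timezone.utc),
        )
        for i, text in enumerate(['Hello there.', 'General Kenobi.'])
    ]
    await graphiti.add_episode_bulk(raw_episodes, group_id='1')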
@@ -527,7 +526,13 @@ async def add_episode_bulk(
except Exception as e:
raise e

- async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
+ async def search(
+     self,
+     query: str,
+     center_node_uuid: str | None = None,
+     group_ids: list[str | None] | None = None,
+     num_results=10,
+ ):
"""
Perform a hybrid search on the knowledge graph.

@@ -540,6 +545,8 @@
The search query string.
center_node_uuid: str, optional
Facts will be reranked based on proximity to this node
+ group_ids : list[str | None] | None, optional
+     The graph partitions to return data from.
num_results : int, optional
The maximum number of results to return. Defaults to 10.

@@ -562,6 +569,7 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
num_episodes=0,
num_edges=num_results,
num_nodes=0,
+ group_ids=group_ids,
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
reranker=reranker,
)
@@ -590,7 +598,10 @@ async def _search(
)

async def get_nodes_by_query(
- self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
+ self,
+ query: str,
+ group_ids: list[str | None] | None = None,
+ limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
"""
Retrieve nodes from the graph database based on a text query.
@@ -602,6 +613,8 @@ async def get_nodes_by_query(
----------
query : str
The text query to search for in the graph.
+ group_ids : list[str | None] | None, optional
+     The graph partitions to return data from.
limit : int | None, optional
The maximum number of results to return per search method.
If None, a default limit will be applied.
@@ -626,5 +639,7 @@
"""
embedder = self.llm_client.get_embedder()
query_embedding = await generate_embedding(embedder, query)
- relevant_nodes = await hybrid_node_search([query], [query_embedding], self.driver, limit)
+ relevant_nodes = await hybrid_node_search(
+     [query], [query_embedding], self.driver, group_ids, limit
+ )
return relevant_nodes
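And a matching sketch for node retrieval scoped to a partition (client setup assumed):

    nodes = await graphiti.get_nodes_by_query(
        'podcast guests',  # hypothetical query string
        group_ids=['1'],
    )
    for node in nodes:
        print(node.name, node.uuid)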