Skip to content

Commit

Permalink
chore: Make deleting groups safer (#155)
Browse files Browse the repository at this point in the history
* chore: Make deleting groups safer

* chore: Use appropriate errors in delete group checks

* chore: Add GroupsEdgesNotFound error type
  • Loading branch information
paul-paliychuk authored Sep 25, 2024
1 parent bca838f commit b537cf5
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 19 deletions.
8 changes: 3 additions & 5 deletions graphiti_core/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from neo4j import AsyncDriver
from pydantic import BaseModel, Field

from graphiti_core.errors import EdgeNotFoundError
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.llm_client.config import EMBEDDING_DIM
from graphiti_core.nodes import Node
Expand Down Expand Up @@ -147,10 +147,9 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
)

edges = [get_episodic_edge_from_record(record) for record in records]
uuids = [edge.uuid for edge in edges]

if len(edges) == 0:
raise EdgeNotFoundError(uuids[0])
raise GroupsEdgesNotFoundError(group_ids)
return edges


Expand Down Expand Up @@ -293,10 +292,9 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
)

edges = [get_entity_edge_from_record(record) for record in records]
uuids = [edge.uuid for edge in edges]

if len(edges) == 0:
raise EdgeNotFoundError(uuids[0])
raise GroupsEdgesNotFoundError(group_ids)
return edges


Expand Down
8 changes: 8 additions & 0 deletions graphiti_core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ def __init__(self, uuid: str):
super().__init__(self.message)


class GroupsEdgesNotFoundError(GraphitiError):
"""Raised when no edges are found for a list of group ids."""

def __init__(self, group_ids: list[str]):
self.message = f'no edges found for group ids {group_ids}'
super().__init__(self.message)


class NodeNotFoundError(GraphitiError):
"""Raised when a node is not found."""

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.3.5"
version = "0.3.6"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <[email protected]>",
Expand Down
33 changes: 20 additions & 13 deletions server/graph_service/zep_graphiti.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import logging
from typing import Annotated

from fastapi import Depends, HTTPException
from graphiti_core import Graphiti # type: ignore
from graphiti_core.edges import EntityEdge # type: ignore
from graphiti_core.errors import EdgeNotFoundError, NodeNotFoundError # type: ignore
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError, NodeNotFoundError
from graphiti_core.llm_client import LLMClient # type: ignore
from graphiti_core.nodes import EntityNode, EpisodicNode # type: ignore

from graph_service.config import ZepEnvDep
from graph_service.dto import FactResult

logger = logging.getLogger(__name__)


class ZepGraphiti(Graphiti):
def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None):
Expand All @@ -36,18 +39,22 @@ async def get_entity_edge(self, uuid: str):
async def delete_group(self, group_id: str):
try:
edges = await EntityEdge.get_by_group_ids(self.driver, [group_id])
nodes = await EntityNode.get_by_group_ids(self.driver, [group_id])
episodes = await EpisodicNode.get_by_group_ids(self.driver, [group_id])
for edge in edges:
await edge.delete(self.driver)
for node in nodes:
await node.delete(self.driver)
for episode in episodes:
await episode.delete(self.driver)
except EdgeNotFoundError as e:
raise HTTPException(status_code=404, detail=e.message) from e
except NodeNotFoundError as e:
raise HTTPException(status_code=404, detail=e.message) from e
except GroupsEdgesNotFoundError:
logger.warning(f'No edges found for group {group_id}')
edges = []

nodes = await EntityNode.get_by_group_ids(self.driver, [group_id])

episodes = await EpisodicNode.get_by_group_ids(self.driver, [group_id])

for edge in edges:
await edge.delete(self.driver)

for node in nodes:
await node.delete(self.driver)

for episode in episodes:
await episode.delete(self.driver)

async def delete_entity_edge(self, uuid: str):
try:
Expand Down

0 comments on commit b537cf5

Please sign in to comment.