diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index b9c090fc..91936d6a 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -18,11 +18,12 @@ from abc import ABC, abstractmethod from datetime import datetime from time import time -from typing import Any, LiteralString +from typing import Any from uuid import uuid4 from neo4j import AsyncDriver from pydantic import BaseModel, Field +from typing_extensions import LiteralString from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index fe21b4dc..4c0e2c15 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -19,11 +19,12 @@ from datetime import datetime, timezone from enum import Enum from time import time -from typing import Any, LiteralString +from typing import Any from uuid import uuid4 from neo4j import AsyncDriver from pydantic import BaseModel, Field +from typing_extensions import LiteralString from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import NodeNotFoundError diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 32eef8bf..27e9a456 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -40,7 +40,7 @@ logger = logging.getLogger(__name__) -RELEVANT_SCHEMA_LIMIT = 3 +RELEVANT_SCHEMA_LIMIT = 10 DEFAULT_MIN_SCORE = 0.6 DEFAULT_MMR_LAMBDA = 0.5 MAX_SEARCH_DEPTH = 3 @@ -605,6 +605,28 @@ async def get_relevant_edges( relevant_edge_uuids.add(edge.uuid) relevant_edges.append(edge) + query: LiteralString = """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: $source_uuid})-[r:RELATES_TO {group_id: edge.group_id}]->(m:Entity {uuid: $target_uuid}) + WITH n, m, r, vector.similarity.cosine(r.fact_embedding, edge.fact_embedding) AS score + WHERE score > $min_score + RETURN + r.uuid AS uuid, + r.group_id AS group_id, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + r.created_at AS created_at, + r.name AS name, + r.fact AS fact, + r.fact_embedding AS fact_embedding, + r.episodes AS episodes, + r.expired_at AS expired_at, + r.valid_at AS valid_at, + r.invalid_at AS invalid_at + ORDER BY score DESC + LIMIT $limit + """ + end = time() logger.debug(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')