diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index f541fafa..726ff46b 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -631,7 +631,7 @@ async def node_distance_reranker( ) -> list[str]: # filter out node_uuid center node node uuid filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids)) - scores: dict[str, float] = {} + scores: dict[str, float] = {center_node_uuid: 0.0} # Find the shortest path to center node query = Query(""" @@ -649,9 +649,13 @@ async def node_distance_reranker( for result in path_results: uuid = result['uuid'] - score = result['score'] if 'score' in result else float('inf') + score = result['score'] scores[uuid] = score + for uuid in filtered_uuids: + if uuid not in scores: + scores[uuid] = float('inf') + # rerank on shortest distance filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index e42a24d4..7bc82550 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -72,6 +72,7 @@ async def test_graphiti_init(): COMBINED_HYBRID_SEARCH_CROSS_ENCODER, group_ids=['test'], ) + pretty_results = { 'edges': [edge.fact for edge in results.edges], 'nodes': [node.name for node in results.nodes],