From 00fe87679e957d18e82bbaf9170b02a8a15df996 Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Tue, 17 Dec 2024 13:08:18 -0500 Subject: [PATCH] Bounded semaphore - limiting concurrency (#244) * WIP * add semaphore * remove unused imports * remove unused imports * lower concurrency limit --- .../msc_eval.py | 3 +- .../msc_runner.py | 3 +- .../cross_encoder/openai_reranker_client.py | 4 +- graphiti_core/graphiti.py | 39 +++++++++---------- graphiti_core/helpers.py | 19 +++++++++ graphiti_core/search/search.py | 10 ++--- graphiti_core/search/search_utils.py | 6 +-- graphiti_core/utils/bulk_utils.py | 26 +++++++------ .../utils/maintenance/community_operations.py | 10 +++-- .../utils/maintenance/edge_operations.py | 7 ++-- .../maintenance/graph_data_operations.py | 7 ++-- .../utils/maintenance/node_operations.py | 7 ++-- tests/test_graphiti_int.py | 10 ++--- 13 files changed, 87 insertions(+), 64 deletions(-) diff --git a/examples/multi_session_conversation_memory/msc_eval.py b/examples/multi_session_conversation_memory/msc_eval.py index 7c2ac0b7..db61482b 100644 --- a/examples/multi_session_conversation_memory/msc_eval.py +++ b/examples/multi_session_conversation_memory/msc_eval.py @@ -25,6 +25,7 @@ from examples.multi_session_conversation_memory.parse_msc_messages import conversation_q_and_a from graphiti_core import Graphiti +from graphiti_core.helpers import semaphore_gather from graphiti_core.prompts import prompt_library from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF @@ -122,7 +123,7 @@ async def main(): qa_chunk = qa[i : i + 20] group_ids = range(len(qa))[i : i + 20] results = list( - await asyncio.gather( + await semaphore_gather( *[ evaluate_qa(graphiti, str(group_id), query, answer) for group_id, (query, answer) in zip(group_ids, qa_chunk) diff --git a/examples/multi_session_conversation_memory/msc_runner.py b/examples/multi_session_conversation_memory/msc_runner.py index 81f82b06..2cef9c58 100644 --- a/examples/multi_session_conversation_memory/msc_runner.py +++ b/examples/multi_session_conversation_memory/msc_runner.py @@ -26,6 +26,7 @@ parse_msc_messages, ) from graphiti_core import Graphiti +from graphiti_core.helpers import semaphore_gather load_dotenv() @@ -75,7 +76,7 @@ async def main(): msc_message_slice = msc_messages[i : i + 10] group_ids = range(len(msc_messages))[i : i + 10] - await asyncio.gather( + await semaphore_gather( *[ add_conversation(graphiti, str(group_id), messages) for group_id, messages in zip(group_ids, msc_message_slice) diff --git a/graphiti_core/cross_encoder/openai_reranker_client.py b/graphiti_core/cross_encoder/openai_reranker_client.py index f4f2e95d..e41cb61e 100644 --- a/graphiti_core/cross_encoder/openai_reranker_client.py +++ b/graphiti_core/cross_encoder/openai_reranker_client.py @@ -14,7 +14,6 @@ limitations under the License. """ -import asyncio import logging from typing import Any @@ -22,6 +21,7 @@ from openai import AsyncOpenAI from pydantic import BaseModel +from ..helpers import semaphore_gather from ..llm_client import LLMConfig, RateLimitError from ..prompts import Message from .client import CrossEncoderClient @@ -75,7 +75,7 @@ async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]] for passage in passages ] try: - responses = await asyncio.gather( + responses = await semaphore_gather( *[ self.client.chat.completions.create( model=DEFAULT_MODEL, diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 574d6189..c2693b4d 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -14,7 +14,6 @@ limitations under the License. """ -import asyncio import logging from datetime import datetime from time import time @@ -27,7 +26,7 @@ from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder -from graphiti_core.helpers import DEFAULT_DATABASE +from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather from graphiti_core.llm_client import LLMClient, OpenAIClient from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode from graphiti_core.search.search import SearchConfig, search @@ -340,13 +339,13 @@ async def add_episode_endpoint(episode_data: EpisodeData): # Calculate Embeddings - await asyncio.gather( + await semaphore_gather( *[node.generate_name_embedding(self.embedder) for node in extracted_nodes] ) # Find relevant nodes already in the graph existing_nodes_lists: list[list[EntityNode]] = list( - await asyncio.gather( + await semaphore_gather( *[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes] ) ) @@ -354,7 +353,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): # Resolve extracted nodes with nodes already in the graph and extract facts logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') - (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather( + (mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather( resolve_extracted_nodes( self.llm_client, extracted_nodes, @@ -374,7 +373,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): ) # calculate embeddings - await asyncio.gather( + await semaphore_gather( *[ edge.generate_embedding(self.embedder) for edge in extracted_edges_with_resolved_pointers @@ -383,7 +382,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): # Resolve extracted edges with related edges already in the graph related_edges_list: list[list[EntityEdge]] = list( - await asyncio.gather( + await semaphore_gather( *[ get_relevant_edges( self.driver, @@ -404,7 +403,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): ) existing_source_edges_list: list[list[EntityEdge]] = list( - await asyncio.gather( + await semaphore_gather( *[ get_relevant_edges( self.driver, @@ -419,7 +418,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): ) existing_target_edges_list: list[list[EntityEdge]] = list( - await asyncio.gather( + await semaphore_gather( *[ get_relevant_edges( self.driver, @@ -468,7 +467,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): # Update any communities if update_communities: - await asyncio.gather( + await semaphore_gather( *[ update_community(self.driver, self.llm_client, self.embedder, node) for node in nodes @@ -538,7 +537,7 @@ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str ] # Save all the episodes - await asyncio.gather(*[episode.save(self.driver) for episode in episodes]) + await semaphore_gather(*[episode.save(self.driver) for episode in episodes]) # Get previous episode context for each episode episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes) @@ -551,19 +550,19 @@ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs) # Generate embeddings - await asyncio.gather( + await semaphore_gather( *[node.generate_name_embedding(self.embedder) for node in extracted_nodes], *[edge.generate_embedding(self.embedder) for edge in extracted_edges], ) # Dedupe extracted nodes, compress extracted edges - (nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather( + (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather( dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes), extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs), ) # save nodes to KG - await asyncio.gather(*[node.save(self.driver) for node in nodes]) + await semaphore_gather(*[node.save(self.driver) for node in nodes]) # re-map edge pointers so that they don't point to discard dupe nodes extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers( @@ -574,7 +573,7 @@ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str ) # save episodic edges to KG - await asyncio.gather( + await semaphore_gather( *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers] ) @@ -587,7 +586,7 @@ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str # invalidate edges # save edges to KG - await asyncio.gather(*[edge.save(self.driver) for edge in edges]) + await semaphore_gather(*[edge.save(self.driver) for edge in edges]) end = time() logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms') @@ -610,12 +609,12 @@ async def build_communities(self, group_ids: list[str] | None = None) -> list[Co self.driver, self.llm_client, group_ids ) - await asyncio.gather( + await semaphore_gather( *[node.generate_name_embedding(self.embedder) for node in community_nodes] ) - await asyncio.gather(*[node.save(self.driver) for node in community_nodes]) - await asyncio.gather(*[edge.save(self.driver) for edge in community_edges]) + await semaphore_gather(*[node.save(self.driver) for node in community_nodes]) + await semaphore_gather(*[edge.save(self.driver) for edge in community_edges]) return community_nodes @@ -698,7 +697,7 @@ async def _search( async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults: episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids) - edges_list = await asyncio.gather( + edges_list = await semaphore_gather( *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes] ) diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index 1e7bec50..fb330e67 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -14,7 +14,9 @@ limitations under the License. """ +import asyncio import os +from collections.abc import Coroutine from datetime import datetime import numpy as np @@ -25,6 +27,7 @@ DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None) USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False)) +SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20)) MAX_REFLEXION_ITERATIONS = 2 DEFAULT_PAGE_LIMIT = 20 @@ -80,3 +83,19 @@ def normalize_l2(embedding: list[float]) -> list[float]: else: norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True) return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist() + + +# Use this instead of asyncio.gather() to bound coroutines +async def semaphore_gather( + *coroutines: Coroutine, max_coroutines: int = SEMAPHORE_LIMIT, return_exceptions=True +): + semaphore = asyncio.Semaphore(max_coroutines) + + async def _wrap_coroutine(coroutine): + async with semaphore: + return await coroutine + + return await asyncio.gather( + *(_wrap_coroutine(coroutine) for coroutine in coroutines), + return_exceptions=return_exceptions, + ) diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 6d065e85..c1b134c3 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -14,7 +14,6 @@ limitations under the License. """ -import asyncio import logging from collections import defaultdict from time import time @@ -25,6 +24,7 @@ from graphiti_core.edges import EntityEdge from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import SearchRerankerError +from graphiti_core.helpers import semaphore_gather from graphiti_core.nodes import CommunityNode, EntityNode from graphiti_core.search.search_config import ( DEFAULT_SEARCH_LIMIT, @@ -78,7 +78,7 @@ async def search( # if group_ids is empty, set it to None group_ids = group_ids if group_ids else None - edges, nodes, communities = await asyncio.gather( + edges, nodes, communities = await semaphore_gather( edge_search( driver, cross_encoder, @@ -141,7 +141,7 @@ async def edge_search( return [] search_results: list[list[EntityEdge]] = list( - await asyncio.gather( + await semaphore_gather( *[ edge_fulltext_search(driver, query, group_ids, 2 * limit), edge_similarity_search( @@ -226,7 +226,7 @@ async def node_search( return [] search_results: list[list[EntityNode]] = list( - await asyncio.gather( + await semaphore_gather( *[ node_fulltext_search(driver, query, group_ids, 2 * limit), node_similarity_search( @@ -295,7 +295,7 @@ async def community_search( return [] search_results: list[list[CommunityNode]] = list( - await asyncio.gather( + await semaphore_gather( *[ community_fulltext_search(driver, query, group_ids, 2 * limit), community_similarity_search( diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index e271dc01..42c52ab7 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -14,7 +14,6 @@ limitations under the License. """ -import asyncio import logging from collections import defaultdict from time import time @@ -30,6 +29,7 @@ USE_PARALLEL_RUNTIME, lucene_sanitize, normalize_l2, + semaphore_gather, ) from graphiti_core.nodes import ( CommunityNode, @@ -549,7 +549,7 @@ async def hybrid_node_search( start = time() results: list[list[EntityNode]] = list( - await asyncio.gather( + await semaphore_gather( *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries], *[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings], ) @@ -619,7 +619,7 @@ async def get_relevant_edges( relevant_edges: list[EntityEdge] = [] relevant_edge_uuids = set() - results = await asyncio.gather( + results = await semaphore_gather( *[ edge_similarity_search( driver, diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 5deb224d..80f66029 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -14,7 +14,6 @@ limitations under the License. """ -import asyncio import logging import typing from collections import defaultdict @@ -26,6 +25,7 @@ from pydantic import BaseModel from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge +from graphiti_core.helpers import semaphore_gather from graphiti_core.llm_client import LLMClient from graphiti_core.models.edges.edge_db_queries import ( ENTITY_EDGE_SAVE_BULK, @@ -71,7 +71,7 @@ class RawEpisode(BaseModel): async def retrieve_previous_episodes_bulk( driver: AsyncDriver, episodes: list[EpisodicNode] ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]: - previous_episodes_list = await asyncio.gather( + previous_episodes_list = await semaphore_gather( *[ retrieve_episodes( driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id] @@ -118,7 +118,7 @@ async def add_nodes_and_edges_bulk_tx( async def extract_nodes_and_edges_bulk( llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]: - extracted_nodes_bulk = await asyncio.gather( + extracted_nodes_bulk = await semaphore_gather( *[ extract_nodes(llm_client, episode, previous_episodes) for episode, previous_episodes in episode_tuples @@ -130,7 +130,7 @@ async def extract_nodes_and_edges_bulk( [episode[1] for episode in episode_tuples], ) - extracted_edges_bulk = await asyncio.gather( + extracted_edges_bulk = await semaphore_gather( *[ extract_edges( llm_client, @@ -171,13 +171,13 @@ async def dedupe_nodes_bulk( node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)] existing_nodes_chunks: list[list[EntityNode]] = list( - await asyncio.gather( + await semaphore_gather( *[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks] ) ) results: list[tuple[list[EntityNode], dict[str, str]]] = list( - await asyncio.gather( + await semaphore_gather( *[ dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i]) for i, node_chunk in enumerate(node_chunks) @@ -205,13 +205,13 @@ async def dedupe_edges_bulk( ] relevant_edges_chunks: list[list[EntityEdge]] = list( - await asyncio.gather( + await semaphore_gather( *[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks] ) ) resolved_edge_chunks: list[list[EntityEdge]] = list( - await asyncio.gather( + await semaphore_gather( *[ dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i]) for i, edge_chunk in enumerate(edge_chunks) @@ -292,7 +292,9 @@ async def compress_nodes( # add both nodes to the shortest chunk node_chunks[-1].extend([n, m]) - results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]) + results = await semaphore_gather( + *[dedupe_node_list(llm_client, chunk) for chunk in node_chunks] + ) extended_map = dict(uuid_map) compressed_nodes: list[EntityNode] = [] @@ -315,7 +317,9 @@ async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list # We build a map of the edges based on their source and target nodes. edge_chunks = chunk_edges_by_nodes(edges) - results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]) + results = await semaphore_gather( + *[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks] + ) compressed_edges: list[EntityEdge] = [] for edge_chunk in results: @@ -368,7 +372,7 @@ async def extract_edge_dates_bulk( episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs } - results = await asyncio.gather( + results = await semaphore_gather( *[ extract_edge_dates( llm_client, diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py index a585c149..0db68aed 100644 --- a/graphiti_core/utils/maintenance/community_operations.py +++ b/graphiti_core/utils/maintenance/community_operations.py @@ -7,7 +7,7 @@ from graphiti_core.edges import CommunityEdge from graphiti_core.embedder import EmbedderClient -from graphiti_core.helpers import DEFAULT_DATABASE +from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather from graphiti_core.llm_client import LLMClient from graphiti_core.nodes import ( CommunityNode, @@ -71,7 +71,7 @@ async def get_community_clusters( community_clusters.extend( list( - await asyncio.gather( + await semaphore_gather( *[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids] ) ) @@ -164,7 +164,7 @@ async def build_community( odd_one_out = summaries.pop() length -= 1 new_summaries: list[str] = list( - await asyncio.gather( + await semaphore_gather( *[ summarize_pair(llm_client, (str(left_summary), str(right_summary))) for left_summary, right_summary in zip( @@ -207,7 +207,9 @@ async def limited_build_community(cluster): return await build_community(llm_client, cluster) communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list( - await asyncio.gather(*[limited_build_community(cluster) for cluster in community_clusters]) + await semaphore_gather( + *[limited_build_community(cluster) for cluster in community_clusters] + ) ) community_nodes: list[CommunityNode] = [] diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index e3fa4f7a..b55f294c 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -14,13 +14,12 @@ limitations under the License. """ -import asyncio import logging from datetime import datetime from time import time from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge -from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS +from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather from graphiti_core.llm_client import LLMClient from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode from graphiti_core.prompts import prompt_library @@ -199,7 +198,7 @@ async def resolve_extracted_edges( ) -> tuple[list[EntityEdge], list[EntityEdge]]: # resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates results: list[tuple[EntityEdge, list[EntityEdge]]] = list( - await asyncio.gather( + await semaphore_gather( *[ resolve_extracted_edge( llm_client, @@ -266,7 +265,7 @@ async def resolve_extracted_edge( current_episode: EpisodicNode, previous_episodes: list[EpisodicNode], ) -> tuple[EntityEdge, list[EntityEdge]]: - resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather( + resolved_edge, (valid_at, invalid_at), invalidation_candidates = await semaphore_gather( dedupe_extracted_edge(llm_client, extracted_edge, related_edges), extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes), get_edge_contradictions(llm_client, extracted_edge, existing_edges), diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index 89c9fdbd..e9ceb94f 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -14,14 +14,13 @@ limitations under the License. """ -import asyncio import logging from datetime import datetime, timezone from neo4j import AsyncDriver from typing_extensions import LiteralString -from graphiti_core.helpers import DEFAULT_DATABASE +from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather from graphiti_core.nodes import EpisodeType, EpisodicNode EPISODE_WINDOW_LEN = 3 @@ -38,7 +37,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo database_=DEFAULT_DATABASE, ) index_names = [record['name'] for record in records] - await asyncio.gather( + await semaphore_gather( *[ driver.execute_query( """DROP INDEX $name""", @@ -82,7 +81,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo index_queries: list[LiteralString] = range_indices + fulltext_indices - await asyncio.gather( + await semaphore_gather( *[ driver.execute_query( query, diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 22201e22..fe450163 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -14,11 +14,10 @@ limitations under the License. """ -import asyncio import logging from time import time -from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS +from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather from graphiti_core.llm_client import LLMClient from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode from graphiti_core.prompts import prompt_library @@ -223,7 +222,7 @@ async def resolve_extracted_nodes( uuid_map: dict[str, str] = {} resolved_nodes: list[EntityNode] = [] results: list[tuple[EntityNode, dict[str, str]]] = list( - await asyncio.gather( + await semaphore_gather( *[ resolve_extracted_node( llm_client, extracted_node, existing_nodes, episode, previous_episodes @@ -275,7 +274,7 @@ async def resolve_extracted_node( else [], } - llm_response, node_summary_response = await asyncio.gather( + llm_response, node_summary_response = await semaphore_gather( llm_client.generate_response( prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate ), diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index 7bc82550..a35b5206 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -14,7 +14,6 @@ limitations under the License. """ -import asyncio import logging import os import sys @@ -25,6 +24,7 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.graphiti import Graphiti +from graphiti_core.helpers import semaphore_gather from graphiti_core.nodes import EntityNode, EpisodicNode from graphiti_core.search.search_config_recipes import ( COMBINED_HYBRID_SEARCH_CROSS_ENCODER, @@ -137,8 +137,8 @@ async def test_graph_integration(): edges = [episodic_edge_1, episodic_edge_2, entity_edge] # test save - await asyncio.gather(*[node.save(driver) for node in nodes]) - await asyncio.gather(*[edge.save(driver) for edge in edges]) + await semaphore_gather(*[node.save(driver) for node in nodes]) + await semaphore_gather(*[edge.save(driver) for edge in edges]) # test get assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None @@ -147,5 +147,5 @@ async def test_graph_integration(): assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None # test delete - await asyncio.gather(*[node.delete(driver) for node in nodes]) - await asyncio.gather(*[edge.delete(driver) for edge in edges]) + await semaphore_gather(*[node.delete(driver) for node in nodes]) + await semaphore_gather(*[edge.delete(driver) for edge in edges])