Bounded semaphore - limiting concurrency (#244)
* WIP

* add semaphore

* remove unused imports

* remove unused imports

* lower concurrency limit
prasmussen15 authored Dec 17, 2024
1 parent 0186ac9 commit 00fe876
Showing 13 changed files with 87 additions and 64 deletions.
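
In summary, this commit replaces direct asyncio.gather(...) calls across the library with a new semaphore_gather(...) helper added in graphiti_core/helpers.py. The helper acquires an asyncio.Semaphore before awaiting each coroutine, capping concurrency at SEMAPHORE_LIMIT (read from the environment variable of the same name, default 20) unless a call site passes its own max_coroutines.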
3 changes: 2 additions & 1 deletion examples/multi_session_conversation_memory/msc_eval.py
@@ -25,6 +25,7 @@

 from examples.multi_session_conversation_memory.parse_msc_messages import conversation_q_and_a
 from graphiti_core import Graphiti
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.prompts import prompt_library
 from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF

@@ -122,7 +123,7 @@ async def main():
         qa_chunk = qa[i : i + 20]
         group_ids = range(len(qa))[i : i + 20]
         results = list(
-            await asyncio.gather(
+            await semaphore_gather(
                 *[
                     evaluate_qa(graphiti, str(group_id), query, answer)
                     for group_id, (query, answer) in zip(group_ids, qa_chunk)
3 changes: 2 additions & 1 deletion examples/multi_session_conversation_memory/msc_runner.py
@@ -26,6 +26,7 @@
     parse_msc_messages,
 )
 from graphiti_core import Graphiti
+from graphiti_core.helpers import semaphore_gather

 load_dotenv()

@@ -75,7 +76,7 @@ async def main():
         msc_message_slice = msc_messages[i : i + 10]
         group_ids = range(len(msc_messages))[i : i + 10]

-        await asyncio.gather(
+        await semaphore_gather(
            *[
                add_conversation(graphiti, str(group_id), messages)
                for group_id, messages in zip(group_ids, msc_message_slice)
4 changes: 2 additions & 2 deletions graphiti_core/cross_encoder/openai_reranker_client.py
@@ -14,14 +14,14 @@
 limitations under the License.
 """

-import asyncio
 import logging
 from typing import Any

 import openai
 from openai import AsyncOpenAI
 from pydantic import BaseModel

+from ..helpers import semaphore_gather
 from ..llm_client import LLMConfig, RateLimitError
 from ..prompts import Message
 from .client import CrossEncoderClient
@@ -75,7 +75,7 @@ async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
             for passage in passages
         ]
         try:
-            responses = await asyncio.gather(
+            responses = await semaphore_gather(
                 *[
                     self.client.chat.completions.create(
                         model=DEFAULT_MODEL,
39 changes: 19 additions & 20 deletions graphiti_core/graphiti.py
@@ -14,7 +14,6 @@
 limitations under the License.
 """

-import asyncio
 import logging
 from datetime import datetime
 from time import time
@@ -27,7 +26,7 @@
 from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
-from graphiti_core.helpers import DEFAULT_DATABASE
+from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.llm_client import LLMClient, OpenAIClient
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search import SearchConfig, search
@@ -340,21 +339,21 @@ async def add_episode_endpoint(episode_data: EpisodeData):

             # Calculate Embeddings

-            await asyncio.gather(
+            await semaphore_gather(
                 *[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
             )

             # Find relevant nodes already in the graph
             existing_nodes_lists: list[list[EntityNode]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
                 )
             )

             # Resolve extracted nodes with nodes already in the graph and extract facts
             logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

-            (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
+            (mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
                 resolve_extracted_nodes(
                     self.llm_client,
                     extracted_nodes,
@@ -374,7 +373,7 @@ async def add_episode_endpoint(episode_data: EpisodeData):
             )

             # calculate embeddings
-            await asyncio.gather(
+            await semaphore_gather(
                 *[
                     edge.generate_embedding(self.embedder)
                     for edge in extracted_edges_with_resolved_pointers
@@ -383,7 +382,7 @@ async def add_episode_endpoint(episode_data: EpisodeData):

             # Resolve extracted edges with related edges already in the graph
             related_edges_list: list[list[EntityEdge]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         get_relevant_edges(
                             self.driver,
@@ -404,7 +403,7 @@ async def add_episode_endpoint(episode_data: EpisodeData):
             )

             existing_source_edges_list: list[list[EntityEdge]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         get_relevant_edges(
                             self.driver,
@@ -419,7 +418,7 @@ async def add_episode_endpoint(episode_data: EpisodeData):
             )

             existing_target_edges_list: list[list[EntityEdge]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         get_relevant_edges(
                             self.driver,
@@ -468,7 +467,7 @@ async def add_episode_endpoint(episode_data: EpisodeData):

             # Update any communities
             if update_communities:
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         update_community(self.driver, self.llm_client, self.embedder, node)
                         for node in nodes
@@ -538,7 +537,7 @@ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str
             ]

             # Save all the episodes
-            await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
+            await semaphore_gather(*[episode.save(self.driver) for episode in episodes])

             # Get previous episode context for each episode
             episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
@@ -551,19 +550,19 @@ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str
             ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)

             # Generate embeddings
-            await asyncio.gather(
+            await semaphore_gather(
                 *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
                 *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
             )

             # Dedupe extracted nodes, compress extracted edges
-            (nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather(
+            (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
                 dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
                 extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
             )

             # save nodes to KG
-            await asyncio.gather(*[node.save(self.driver) for node in nodes])
+            await semaphore_gather(*[node.save(self.driver) for node in nodes])

             # re-map edge pointers so that they don't point to discard dupe nodes
             extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
@@ -574,7 +573,7 @@ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str
             )

             # save episodic edges to KG
-            await asyncio.gather(
+            await semaphore_gather(
                 *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
             )

@@ -587,7 +586,7 @@ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str
             # invalidate edges

             # save edges to KG
-            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
+            await semaphore_gather(*[edge.save(self.driver) for edge in edges])

             end = time()
             logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
@@ -610,12 +609,12 @@ async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
             self.driver, self.llm_client, group_ids
         )

-        await asyncio.gather(
+        await semaphore_gather(
             *[node.generate_name_embedding(self.embedder) for node in community_nodes]
         )

-        await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
-        await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
+        await semaphore_gather(*[node.save(self.driver) for node in community_nodes])
+        await semaphore_gather(*[edge.save(self.driver) for edge in community_edges])

         return community_nodes

@@ -698,7 +697,7 @@ async def _search(
     async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
         episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)

-        edges_list = await asyncio.gather(
+        edges_list = await semaphore_gather(
            *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
        )
19 changes: 19 additions & 0 deletions graphiti_core/helpers.py
@@ -14,7 +14,9 @@
 limitations under the License.
 """

+import asyncio
 import os
+from collections.abc import Coroutine
 from datetime import datetime

 import numpy as np
@@ -25,6 +27,7 @@

 DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
 USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
+SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
 MAX_REFLEXION_ITERATIONS = 2
 DEFAULT_PAGE_LIMIT = 20

@@ -80,3 +83,19 @@ def normalize_l2(embedding: list[float]) -> list[float]:
     else:
         norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
         return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
+
+
+# Use this instead of asyncio.gather() to bound coroutines
+async def semaphore_gather(
+    *coroutines: Coroutine, max_coroutines: int = SEMAPHORE_LIMIT, return_exceptions=True
+):
+    semaphore = asyncio.Semaphore(max_coroutines)
+
+    async def _wrap_coroutine(coroutine):
+        async with semaphore:
+            return await coroutine
+
+    return await asyncio.gather(
+        *(_wrap_coroutine(coroutine) for coroutine in coroutines),
+        return_exceptions=return_exceptions,
+    )
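
Since semaphore_gather mirrors asyncio.gather's variadic call shape, the call sites in this commit only change the function name. A minimal usage sketch of the new helper (the fetch_one coroutine and the limit of 5 are illustrative, not part of this commit):

import asyncio

from graphiti_core.helpers import semaphore_gather


async def fetch_one(i: int) -> int:
    # Hypothetical stand-in for an LLM, embedder, or database call.
    await asyncio.sleep(0.1)
    return i * 2


async def main() -> None:
    # At most 5 of the 100 coroutines are awaited concurrently; the rest
    # block on the internal asyncio.Semaphore before they start running.
    results = await semaphore_gather(
        *[fetch_one(i) for i in range(100)],
        max_coroutines=5,
    )
    # return_exceptions defaults to True, so a failing coroutine shows up as
    # an exception object in results instead of cancelling the whole batch.
    print(results[:5])


asyncio.run(main())

Setting the SEMAPHORE_LIMIT environment variable adjusts the default cap for call sites that do not pass max_coroutines.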
10 changes: 5 additions & 5 deletions graphiti_core/search/search.py
@@ -14,7 +14,6 @@
 limitations under the License.
 """

-import asyncio
 import logging
 from collections import defaultdict
 from time import time
@@ -25,6 +24,7 @@
 from graphiti_core.edges import EntityEdge
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import SearchRerankerError
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.nodes import CommunityNode, EntityNode
 from graphiti_core.search.search_config import (
     DEFAULT_SEARCH_LIMIT,
@@ -78,7 +78,7 @@ async def search(

     # if group_ids is empty, set it to None
     group_ids = group_ids if group_ids else None
-    edges, nodes, communities = await asyncio.gather(
+    edges, nodes, communities = await semaphore_gather(
         edge_search(
             driver,
             cross_encoder,
@@ -141,7 +141,7 @@ async def edge_search(
         return []

     search_results: list[list[EntityEdge]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 edge_fulltext_search(driver, query, group_ids, 2 * limit),
                 edge_similarity_search(
@@ -226,7 +226,7 @@ async def node_search(
         return []

     search_results: list[list[EntityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 node_fulltext_search(driver, query, group_ids, 2 * limit),
                 node_similarity_search(
@@ -295,7 +295,7 @@ async def community_search(
         return []

     search_results: list[list[CommunityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 community_fulltext_search(driver, query, group_ids, 2 * limit),
                 community_similarity_search(
6 changes: 3 additions & 3 deletions graphiti_core/search/search_utils.py
@@ -14,7 +14,6 @@
 limitations under the License.
 """

-import asyncio
 import logging
 from collections import defaultdict
 from time import time
@@ -30,6 +29,7 @@
     USE_PARALLEL_RUNTIME,
     lucene_sanitize,
     normalize_l2,
+    semaphore_gather,
 )
 from graphiti_core.nodes import (
     CommunityNode,
@@ -549,7 +549,7 @@ async def hybrid_node_search(

     start = time()
     results: list[list[EntityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
             *[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
         )
@@ -619,7 +619,7 @@ async def get_relevant_edges(
     relevant_edges: list[EntityEdge] = []
     relevant_edge_uuids = set()

-    results = await asyncio.gather(
+    results = await semaphore_gather(
         *[
             edge_similarity_search(
                 driver,