
Cross encoder reranker in search query (#202)
* cross encoder reranker

* update reranker

* add openai reranker

* format

* mypy

* update

* updates

* MyPy typing

* bump version
prasmussen15 authored Oct 25, 2024
1 parent 544f9e3 commit ceb60a3
Showing 8 changed files with 446 additions and 191 deletions.
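The headline change is a new cross_encoder argument on the Graphiti constructor, defaulting to the OpenAI-backed reranker added below. A minimal usage sketch; the connection details and API key are placeholders, and the import paths are assumed from the file layout in this diff:

from graphiti_core.graphiti import Graphiti
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.llm_client import LLMConfig

# Placeholder Neo4j credentials; any CrossEncoderClient works here.
# Omitting cross_encoder falls back to OpenAIRerankerClient().
graphiti = Graphiti(
    'bolt://localhost:7687',
    'neo4j',
    'password',
    cross_encoder=OpenAIRerankerClient(config=LLMConfig(api_key='sk-...')),
)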
113 changes: 113 additions & 0 deletions graphiti_core/cross_encoder/openai_reranker_client.py
@@ -0,0 +1,113 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import logging
import math
from typing import Any

import openai
from openai import AsyncOpenAI
from pydantic import BaseModel

from ..llm_client import LLMConfig, RateLimitError
from ..prompts import Message
from .client import CrossEncoderClient

logger = logging.getLogger(__name__)

DEFAULT_MODEL = 'gpt-4o-mini'


class BooleanClassifier(BaseModel):
    isTrue: bool


class OpenAIRerankerClient(CrossEncoderClient):
    def __init__(self, config: LLMConfig | None = None):
        """
        Initialize the OpenAIRerankerClient with the provided configuration.

        Args:
            config (LLMConfig | None): The configuration for the LLM client,
                including API key, model, base URL, temperature, and max tokens.
                Defaults to a fresh LLMConfig when omitted.
        """
        if config is None:
            config = LLMConfig()

        self.config = config
        self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)

    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
        openai_messages_list: Any = [
            [
                Message(
                    role='system',
                    content='You are an expert tasked with determining whether the passage is relevant to the query',
                ),
                Message(
                    role='user',
                    content=f"""
                    Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
                    <PASSAGE>
                    {passage}
                    </PASSAGE>
                    <QUERY>
                    {query}
                    </QUERY>
                    """,
                ),
            ]
            for passage in passages
        ]
        try:
            responses = await asyncio.gather(
                *[
                    self.client.chat.completions.create(
                        model=DEFAULT_MODEL,
                        messages=openai_messages,
                        temperature=0,
                        max_tokens=1,
                        # bias the single output token toward the "True"/"False" token ids
                        logit_bias={'6432': 1, '7983': 1},
                        logprobs=True,
                        top_logprobs=2,
                    )
                    for openai_messages in openai_messages_list
                ]
            )

            responses_top_logprobs = [
                response.choices[0].logprobs.content[0].top_logprobs
                if response.choices[0].logprobs is not None
                and response.choices[0].logprobs.content is not None
                else []
                for response in responses
            ]
            scores: list[float] = []
            for top_logprobs in responses_top_logprobs:
                # Exactly one score per passage: the probability assigned to the
                # "True" token (0.0 if absent), so zip() below stays aligned.
                score = 0.0
                for logprob in top_logprobs:
                    if logprob.token.strip().lower() == 'true':
                        score = math.exp(logprob.logprob)
                        break
                scores.append(score)

            results = list(zip(passages, scores))
            results.sort(reverse=True, key=lambda x: x[1])
            return results
        except openai.RateLimitError as e:
            raise RateLimitError from e
        except Exception as e:
            logger.error(f'Error in generating LLM response: {e}')
            raise
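
Used standalone, the client scores each passage's relevance to the query and returns them best-first. A quick sketch, assuming OPENAI_API_KEY is set in the environment (the default LLMConfig leaves api_key unset, and AsyncOpenAI then falls back to the environment variable):

import asyncio

from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient

async def main() -> None:
    reranker = OpenAIRerankerClient()
    ranked = await reranker.rank(
        'What does the cross encoder do?',
        ['It reranks search results.', 'Neo4j stores the graph.'],
    )
    for passage, score in ranked:  # sorted best-first by relevance score
        print(f'{score:.3f}  {passage}')

asyncio.run(main())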
31 changes: 28 additions & 3 deletions graphiti_core/graphiti.py
@@ -23,8 +23,11 @@
from neo4j import AsyncGraphDatabase
from pydantic import BaseModel

from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.helpers import DEFAULT_DATABASE
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import SearchConfig, search
@@ -92,6 +95,7 @@ def __init__(
password: str,
llm_client: LLMClient | None = None,
embedder: EmbedderClient | None = None,
cross_encoder: CrossEncoderClient | None = None,
store_raw_episode_content: bool = True,
):
"""
@@ -131,7 +135,7 @@ def __init__(
Graphiti if you're using the default OpenAIClient.
"""
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
self.database = 'neo4j'
self.database = DEFAULT_DATABASE
self.store_raw_episode_content = store_raw_episode_content
if llm_client:
self.llm_client = llm_client
@@ -141,6 +145,10 @@
self.embedder = embedder
else:
self.embedder = OpenAIEmbedder()
if cross_encoder:
self.cross_encoder = cross_encoder
else:
self.cross_encoder = OpenAIRerankerClient()

async def close(self):
"""
@@ -648,6 +656,7 @@ async def search(
await search(
self.driver,
self.embedder,
self.cross_encoder,
query,
group_ids,
search_config,
@@ -663,8 +672,18 @@ async def _search(
config: SearchConfig,
group_ids: list[str] | None = None,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
) -> SearchResults:
return await search(self.driver, self.embedder, query, group_ids, config, center_node_uuid)
return await search(
self.driver,
self.embedder,
self.cross_encoder,
query,
group_ids,
config,
center_node_uuid,
bfs_origin_node_uuids,
)

async def get_nodes_by_query(
self,
@@ -716,7 +735,13 @@ async def get_nodes_by_query(

nodes = (
await search(
self.driver, self.embedder, query, group_ids, search_config, center_node_uuid
self.driver,
self.embedder,
self.cross_encoder,
query,
group_ids,
search_config,
center_node_uuid,
)
).nodes
return nodes
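
Because the constructor accepts any CrossEncoderClient, a local reranker can be swapped in without touching the search path. A purely illustrative sketch, assuming the interface only requires the async rank method used throughout this commit:

from graphiti_core.cross_encoder.client import CrossEncoderClient

class TermOverlapReranker(CrossEncoderClient):
    """Toy reranker: scores passages by the fraction of query terms they contain."""

    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
        terms = set(query.lower().split())
        scored = [
            (p, len(terms & set(p.lower().split())) / max(len(terms), 1))
            for p in passages
        ]
        scored.sort(reverse=True, key=lambda pair: pair[1])
        return scored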
33 changes: 33 additions & 0 deletions graphiti_core/search/search.py
@@ -21,6 +21,7 @@

from neo4j import AsyncDriver

from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import SearchRerankerError
@@ -39,6 +40,7 @@
from graphiti_core.search.search_utils import (
community_fulltext_search,
community_similarity_search,
edge_bfs_search,
edge_fulltext_search,
edge_similarity_search,
episode_mentions_reranker,
@@ -55,10 +57,12 @@
async def search(
driver: AsyncDriver,
embedder: EmbedderClient,
cross_encoder: CrossEncoderClient,
query: str,
group_ids: list[str] | None,
config: SearchConfig,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
) -> SearchResults:
start = time()
query_vector = await embedder.create(input=[query.replace('\n', ' ')])
@@ -68,28 +72,34 @@ async def search(
edges, nodes, communities = await asyncio.gather(
edge_search(
driver,
cross_encoder,
query,
query_vector,
group_ids,
config.edge_config,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
),
node_search(
driver,
cross_encoder,
query,
query_vector,
group_ids,
config.node_config,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
),
community_search(
driver,
cross_encoder,
query,
query_vector,
group_ids,
config.community_config,
bfs_origin_node_uuids,
config.limit,
),
)
@@ -109,11 +119,13 @@

async def edge_search(
driver: AsyncDriver,
cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: EdgeSearchConfig | None,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]:
if config is None:
@@ -126,6 +138,7 @@ async def edge_search(
edge_similarity_search(
driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
),
edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth),
]
)
)
@@ -146,6 +159,10 @@
reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
elif config.reranker == EdgeReranker.cross_encoder:
fact_to_uuid_map = {edge.fact: edge.uuid for result in search_results for edge in result}
reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
elif config.reranker == EdgeReranker.node_distance:
if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker')
@@ -176,11 +193,13 @@

async def node_search(
driver: AsyncDriver,
cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: NodeSearchConfig | None,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
if config is None:
@@ -212,6 +231,12 @@
reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
elif config.reranker == NodeReranker.cross_encoder:
summary_to_uuid_map = {
node.summary: node.uuid for result in search_results for node in result
}
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
reranked_uuids = [summary_to_uuid_map[summary] for summary, _ in reranked_summaries]
elif config.reranker == NodeReranker.episode_mentions:
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
elif config.reranker == NodeReranker.node_distance:
@@ -228,10 +253,12 @@

async def community_search(
driver: AsyncDriver,
cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: CommunitySearchConfig | None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]:
if config is None:
@@ -268,6 +295,12 @@
reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
elif config.reranker == CommunityReranker.cross_encoder:
summary_to_uuid_map = {
node.summary: node.uuid for result in search_results for node in result
}
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
reranked_uuids = [summary_to_uuid_map[summary] for summary, _ in reranked_summaries]

reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]

