Skip to content

Commit

Permalink
Merge pull request #31 from Sefaria/feat/improve-graph-traversal
Browse files Browse the repository at this point in the history
fix: exclude seed chunks and renew candidate chunks in graph traversal
  • Loading branch information
Paul-Yu-Chun-Chang authored Sep 7, 2024
2 parents 545c991 + cca6f4b commit 81e9d54
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions VirtualHavruta/vh.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,7 @@ def traverse(json_data):
if 'primaryCategory' in sub_value and predicate(sub_value['primaryCategory']):
# Add the page_rank to each document
# PR score is initialized to 6.0 for Sefaria Linker API
sub_value['page_rank'] = 6.0
sub_value['page_rank'] = 1e3
results.append(sub_value)
elif isinstance(value, (dict, list)):
# Continue search in deeper levels
Expand Down Expand Up @@ -1209,18 +1209,21 @@ def graph_traversal_retriever(self,
total_token_count += token_count

n_accepted_chunks = 0
while n_accepted_chunks < self.config["database"]["kg"]["max_depth"]:
seed_iteration = True
while n_accepted_chunks < self.config["database"]["kg"]["max_depth"]:
top_chunk = candidate_chunks.pop(0) # Get the top chunk
collected_chunks.append(top_chunk)
local_top_score = candidate_rankings.pop(0)
ranking_scores_collected_chunks.append(local_top_score)
n_accepted_chunks += 1
if not seed_iteration:
collected_chunks.append(top_chunk)
local_top_score = candidate_rankings.pop(0)
ranking_scores_collected_chunks.append(local_top_score)
n_accepted_chunks += 1
seed_iteration = False
# avoid final loop execution which does not add a chunk to collected_chunks anyways
if n_accepted_chunks >= self.config["database"]["kg"]["max_depth"]:
break
neighbor_nodes = []
# neighbor_nodes = []
top_node = self.get_node_corresponding_to_chunk(top_chunk, msg_id=msg_id)
neighbor_nodes.append(top_node)
# neighbor_nodes.append(top_node)
neighbor_nodes_scores: list[tuple[Document, int]] = self.get_retrieval_results_knowledge_graph(
url=top_node.metadata["url"],
direction=self.config["database"]["kg"]["direction"],
Expand All @@ -1229,8 +1232,8 @@ def graph_traversal_retriever(self,
score_central_node=6.0,
msg_id=msg_id
)
neighbor_nodes += [node for node, _ in neighbor_nodes_scores]
candidate_chunks += self.get_chunks_corresponding_to_nodes(neighbor_nodes, msg_id=msg_id)
neighbor_nodes = [node for node, _ in neighbor_nodes_scores]
candidate_chunks = self.get_chunks_corresponding_to_nodes(neighbor_nodes, msg_id=msg_id)
# avoid re-adding the top chunk
candidate_chunks = [chunk for chunk in candidate_chunks if chunk not in collected_chunks]
candidate_chunks, token_count = self.select_reference(enriched_query, candidate_chunks, msg_id=msg_id)
Expand Down

0 comments on commit 81e9d54

Please sign in to comment.