From cca6f4bbe65ee767f58ab94609341ac39c6329a2 Mon Sep 17 00:00:00 2001
From: Paul Yu-Chun Chang
Date: Sat, 7 Sep 2024 13:02:18 +0000
Subject: [PATCH] fix: exclude seed chunks and renew candidate chunks in graph
traversal
---
VirtualHavruta/vh.py | 23 +++++++++++++----------
1 file changed, 13 insertions(+), 10 deletions(-)
diff --git a/VirtualHavruta/vh.py b/VirtualHavruta/vh.py
index ce32097..1173df5 100644
--- a/VirtualHavruta/vh.py
+++ b/VirtualHavruta/vh.py
@@ -982,7 +982,7 @@ def traverse(json_data):
if 'primaryCategory' in sub_value and predicate(sub_value['primaryCategory']):
# Add the page_rank to each document
# PR score is initialized to 6.0 for Sefaria Linker API
- sub_value['page_rank'] = 6.0
+ sub_value['page_rank'] = 1e3
results.append(sub_value)
elif isinstance(value, (dict, list)):
# Continue search in deeper levels
@@ -1209,18 +1209,21 @@ def graph_traversal_retriever(self,
total_token_count += token_count
n_accepted_chunks = 0
- while n_accepted_chunks < self.config["database"]["kg"]["max_depth"]:
+ seed_iteration = True
+ while n_accepted_chunks < self.config["database"]["kg"]["max_depth"]:
top_chunk = candidate_chunks.pop(0) # Get the top chunk
- collected_chunks.append(top_chunk)
- local_top_score = candidate_rankings.pop(0)
- ranking_scores_collected_chunks.append(local_top_score)
- n_accepted_chunks += 1
+ if not seed_iteration:
+ collected_chunks.append(top_chunk)
+ local_top_score = candidate_rankings.pop(0)
+ ranking_scores_collected_chunks.append(local_top_score)
+ n_accepted_chunks += 1
+ seed_iteration = False
# avoid final loop execution which does not add a chunk to collected_chunks anyways
if n_accepted_chunks >= self.config["database"]["kg"]["max_depth"]:
break
- neighbor_nodes = []
+ # neighbor_nodes = []
top_node = self.get_node_corresponding_to_chunk(top_chunk, msg_id=msg_id)
- neighbor_nodes.append(top_node)
+ # neighbor_nodes.append(top_node)
neighbor_nodes_scores: list[tuple[Document, int]] = self.get_retrieval_results_knowledge_graph(
url=top_node.metadata["url"],
direction=self.config["database"]["kg"]["direction"],
@@ -1229,8 +1232,8 @@ def graph_traversal_retriever(self,
score_central_node=6.0,
msg_id=msg_id
)
- neighbor_nodes += [node for node, _ in neighbor_nodes_scores]
- candidate_chunks += self.get_chunks_corresponding_to_nodes(neighbor_nodes, msg_id=msg_id)
+ neighbor_nodes = [node for node, _ in neighbor_nodes_scores]
+ candidate_chunks = self.get_chunks_corresponding_to_nodes(neighbor_nodes, msg_id=msg_id)
# avoid re-adding the top chunk
candidate_chunks = [chunk for chunk in candidate_chunks if chunk not in collected_chunks]
candidate_chunks, token_count = self.select_reference(enriched_query, candidate_chunks, msg_id=msg_id)