From cca6f4bbe65ee767f58ab94609341ac39c6329a2 Mon Sep 17 00:00:00 2001 From: Paul Yu-Chun Chang Date: Sat, 7 Sep 2024 13:02:18 +0000 Subject: [PATCH] fix: exclude seed chunks and renew candidate chunks in graph traversal --- VirtualHavruta/vh.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/VirtualHavruta/vh.py b/VirtualHavruta/vh.py index ce32097..1173df5 100644 --- a/VirtualHavruta/vh.py +++ b/VirtualHavruta/vh.py @@ -982,7 +982,7 @@ def traverse(json_data): if 'primaryCategory' in sub_value and predicate(sub_value['primaryCategory']): # Add the page_rank to each document # PR score is initialized to 6.0 for Sefaria Linker API - sub_value['page_rank'] = 6.0 + sub_value['page_rank'] = 1e3 results.append(sub_value) elif isinstance(value, (dict, list)): # Continue search in deeper levels @@ -1209,18 +1209,21 @@ def graph_traversal_retriever(self, total_token_count += token_count n_accepted_chunks = 0 - while n_accepted_chunks < self.config["database"]["kg"]["max_depth"]: + seed_iteration = True + while n_accepted_chunks < self.config["database"]["kg"]["max_depth"]: top_chunk = candidate_chunks.pop(0) # Get the top chunk - collected_chunks.append(top_chunk) - local_top_score = candidate_rankings.pop(0) - ranking_scores_collected_chunks.append(local_top_score) - n_accepted_chunks += 1 + if not seed_iteration: + collected_chunks.append(top_chunk) + local_top_score = candidate_rankings.pop(0) + ranking_scores_collected_chunks.append(local_top_score) + n_accepted_chunks += 1 + seed_iteration = False # avoid final loop execution which does not add a chunk to collected_chunks anyways if n_accepted_chunks >= self.config["database"]["kg"]["max_depth"]: break - neighbor_nodes = [] + # neighbor_nodes = [] top_node = self.get_node_corresponding_to_chunk(top_chunk, msg_id=msg_id) - neighbor_nodes.append(top_node) + # neighbor_nodes.append(top_node) neighbor_nodes_scores: list[tuple[Document, int]] = 
self.get_retrieval_results_knowledge_graph( url=top_node.metadata["url"], direction=self.config["database"]["kg"]["direction"], @@ -1229,8 +1232,8 @@ def graph_traversal_retriever(self, score_central_node=6.0, msg_id=msg_id ) - neighbor_nodes += [node for node, _ in neighbor_nodes_scores] - candidate_chunks += self.get_chunks_corresponding_to_nodes(neighbor_nodes, msg_id=msg_id) + neighbor_nodes = [node for node, _ in neighbor_nodes_scores] + candidate_chunks = self.get_chunks_corresponding_to_nodes(neighbor_nodes, msg_id=msg_id) # avoid re-adding the top chunk candidate_chunks = [chunk for chunk in candidate_chunks if chunk not in collected_chunks] candidate_chunks, token_count = self.select_reference(enriched_query, candidate_chunks, msg_id=msg_id)