Skip to content

Commit

Permalink
Merge pull request #19 from Sefaria/feat/metadata-filtering
Browse files Browse the repository at this point in the history
feat/metadata filtering
  • Loading branch information
Paul-Yu-Chun-Chang authored Aug 12, 2024
2 parents 601e4fa + 14ca816 commit 6a73338
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 4 deletions.
138 changes: 137 additions & 1 deletion VirtualHavruta/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
import logging
from typing import Iterable
import json, re

from VirtualHavruta.document import ChunkDocument

Expand Down Expand Up @@ -127,4 +128,139 @@ def convert_vector_db_record_to_doc(record) -> ChunkDocument:
if isinstance(page_content, dict)
else page_content,
metadata=record
)
)

def load_selected_keys(file_path: str, selected_keys: list) -> dict:
    """
    Load specific keys from a JSON file and return them as a dictionary.

    Reads the JSON file at ``file_path`` and keeps only the key-value pairs
    whose keys appear in ``selected_keys``. Keys absent from the JSON data
    are silently skipped.

    Parameters:
        file_path (str): Path to the JSON file to be loaded.
        selected_keys (list): Keys to extract from the JSON file.

    Returns:
        dict: The key-value pairs for the requested keys that exist in the file.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.

    Example:
        If the JSON file contains:
        {
            "name": "John",
            "age": 30,
            "city": "New York"
        }
        And `selected_keys` is ["name", "city"], the function will return:
        {
            "name": "John",
            "city": "New York"
        }
    """
    # Explicit utf-8: JSON is UTF-8 by specification; the platform default
    # locale encoding (e.g. cp1252 on Windows) can corrupt non-ASCII content.
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Keep only the requested keys; missing keys are ignored rather than raising.
    return {key: data[key] for key in selected_keys if key in data}

def find_matched_filters(query: str, metadata_ranges: dict) -> dict:
    """
    Scan the input query for occurrences of the candidate strings in
    ``metadata_ranges``, returning only the longest, non-overlapping,
    whole-word matches for each key. Matching is case-insensitive; the
    returned strings keep their original casing from ``metadata_ranges``.

    Parameters:
        query (str): The input string to be searched.
        metadata_ranges (dict): Maps each key to a list of candidate strings.
            Empty, whitespace-only, and None entries are ignored.

    Returns:
        dict: Keys from ``metadata_ranges`` mapped to the list of candidate
            strings found in the query. Keys with no matches are omitted.

    Example:
        metadata_ranges = {
            "key1": ["foo", "bar", "multi word", "multi", ""],
            "key2": ["another phrase", "phrase", "baz", None, "ample"]
        }
        query = "This is an example with multi word and foo inside another phrase."
        The result will be:
        {
            "key1": ["multi word", "foo"],
            "key2": ["another phrase"]
        }
    """
    matched_filters = {}

    for key, string_list in metadata_ranges.items():
        matched_strings = []
        # Longest-first so "multi word" is claimed before "multi" can match
        # inside it; None/empty/whitespace candidates are dropped up front.
        candidates = sorted((s for s in string_list if s and s.strip()),
                            key=len, reverse=True)
        # Working copy of the query; matched spans are blanked out so shorter
        # candidates cannot re-match inside an already-claimed region.
        remaining = query

        for s in candidates:
            # \b...\b restricts to whole-word matches ("ample" does not match
            # inside "example"). IGNORECASE fixes the original bug where a
            # lowercased pattern was searched against the unmodified query,
            # silently missing mixed-case values such as "Rashi".
            pattern = r'\b' + re.escape(s) + r'\b'
            if re.search(pattern, remaining, flags=re.IGNORECASE):
                matched_strings.append(s)
                # Blank only the first occurrence to prevent overlapping matches.
                remaining = re.sub(pattern, '', remaining, count=1,
                                   flags=re.IGNORECASE)

        if matched_strings:
            matched_filters[key] = matched_strings

    return matched_filters

def construct_db_filter(matched_filters: dict) -> dict:
    """
    Build a vector-database filter dictionary from matched metadata filters.

    Recognizes the optional keys 'primaryDocCategory' and 'authorNames' in
    ``matched_filters``; each non-empty one contributes an ``{"$in": [...]}``
    condition. A single condition is returned bare; multiple conditions are
    combined under "$or"; no conditions yields an empty dict.

    Parameters:
        matched_filters (dict): May contain 'primaryDocCategory' and/or
            'authorNames', each mapped to a list of strings.

    Returns:
        dict: The DB filter for querying.

    Example:
        If matched_filters is {'authorNames': ['Rashi']}, the function will return:
        {"authorNames": {"$in": ['Rashi']}}
        If matched_filters is {'primaryDocCategory': ['String A', 'String B'], 'authorNames': ['String C']},
        the function will return:
        {"$or": [{"primaryDocCategory": {"$in": ['String A', 'String B']}},
                 {"authorNames": {"$in": ['String C']}}]}
    """
    # One "$in" condition per recognized, non-empty field (missing or falsy
    # values are skipped).
    conditions = [
        {field: {"$in": values}}
        for field in ("primaryDocCategory", "authorNames")
        if (values := matched_filters.get(field))
    ]

    if not conditions:
        return {}
    if len(conditions) == 1:
        # A single condition needs no "$or" wrapper.
        return conditions[0]
    return {"$or": conditions}
29 changes: 26 additions & 3 deletions VirtualHavruta/vh.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import neo4j

from VirtualHavruta.util import convert_node_to_doc, convert_vector_db_record_to_doc, \
get_node_data, min_max_scaling
min_max_scaling


# Main Virtual Havruta functionalities
Expand Down Expand Up @@ -344,10 +344,10 @@ def retrieve_docs(self, query: str, msg_id: str = '', filter_mode: str='primary'
Raises:
ValueError: If an invalid filter_mode is provided, an exception is raised to indicate the error.
'''
self.logger.info(f"MsgID={msg_id}. [RETRIEVAL] Retrieving {filter_mode} references using this query: {query}")
self.logger.info(f"MsgID={msg_id}. [RETRIEVAL] Simple semantic search at work. Retrieving {filter_mode} references using this query: {query}")
# Convert primary_source_filter to a set for efficient lookup
retrieved_docs = self.neo4j_vector.similarity_search_with_relevance_scores(
query.lower(), self.top_k,
query, self.top_k,
)
# Filter the documents based on whether we're looking for primary or secondary sources
if filter_mode == 'primary':
Expand All @@ -358,6 +358,29 @@ def retrieve_docs(self, query: str, msg_id: str = '', filter_mode: str='primary'
raise ValueError(f"MsgID={msg_id}. Invalid filter_mode: {filter_mode}")
retrieval_res = list(filter(predicate, retrieved_docs))
return retrieval_res

def retrieve_docs_metadata_filtering(self, query: str, msg_id: str = '', metadata_fiter: dict|None=None):
    '''
    Retrieve documents via semantic similarity search, constrained by a metadata filter.

    Runs a similarity search against the vector store with the metadata filter
    passed through to the store, so filtering happens during the search itself
    rather than as a post-processing step. The process is logged for traceability.

    NOTE(review): the parameter name ``metadata_fiter`` is a typo for
    ``metadata_filter``; it is kept as-is because renaming would break callers
    that pass it by keyword.

    Parameters:
        query (str): The query string used to search for relevant documents.
        msg_id (str, optional): Message identifier used for log correlation; defaults to ''.
        metadata_fiter (dict | None): Metadata filter applied during the semantic
            search; None applies no filtering.

    Returns:
        list: Results from the vector store's similarity search that satisfy
            the metadata filter (at most ``self.top_k`` entries).
    '''
    self.logger.info(f"MsgID={msg_id}. [RETRIEVAL] Metadata filtering at work. Retrieving references using this query: {query} and this metadata filter {metadata_fiter}")
    # Delegate filtering to the vector store: the filter narrows candidates
    # during the similarity search instead of trimming results afterwards.
    return self.neo4j_vector.similarity_search_with_relevance_scores(
        query, self.top_k, filter=metadata_fiter
    )

def retrieve_nodes_matching_linker_results(self, linker_results: list[dict], msg_id: str = '', filter_mode: str = 'primary',
url_prefix: str = "https://www.sefaria.org/") -> list[Document]:
Expand Down

0 comments on commit 6a73338

Please sign in to comment.