Skip to content

Commit

Permalink
Merge pull request #10 from Sefaria/feat/semantic-search-local-cluster
Browse files Browse the repository at this point in the history
Feat/semantic search local cluster
  • Loading branch information
Paul-Yu-Chun-Chang authored Jul 11, 2024
2 parents d966ef5 + 4354a8a commit 9e726d1
Show file tree
Hide file tree
Showing 4 changed files with 789 additions and 69 deletions.
12 changes: 12 additions & 0 deletions VirtualHavruta/document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from langchain_core.documents import Document


class ChunkDocument(Document):
def __eq__(self, other):
"""Overrides the default implementation of equal check."""
if isinstance(other, Document):
return (
self.page_content == other.page_content
and self.metadata == other.metadata
)
return False
152 changes: 152 additions & 0 deletions VirtualHavruta/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import logging
from logging import handlers
from typing import Iterable

from VirtualHavruta.document import ChunkDocument

from langchain_core.documents import Document

def create_logger(f='virtual-havruta.log', name='virtual-havruta', mb=1*1024*1024, bk=0):
logger = logging.getLogger(name)
Expand All @@ -16,3 +21,150 @@ def part_res(input_res, sep=''):
if sep:
return input_res.partition(sep)[2].strip()
return input_res.strip()

def min_max_scaling(data: Iterable, offset: float = 1e-09) -> list:
"""
Perform min-max scaling on a list or numpy array of numerical data.
Parameters:
-----------
data
The input data to be scaled.
offset
to avoid returning zero for minimum value.
Returns:
--------
The scaled data.
"""
data = list(data)
if not data:
return data

min_val = min(data)
max_val = max(data)

if min_val == max_val:
return [0.5] * len(data) # All values are the same, return 0.5

scaled_data = [(x - min_val + offset) / (max_val - min_val) for x in data]

return scaled_data


def dict_to_yaml_str(input_dict: dict, indent: int = 0) -> str:
"""
Convert a dictionary to a YAML-like string without using external libraries.
Parameters:
input_dict: The dictionary to convert.
indent: The current indentation level.
Returns:
The YAML-like string representation of the input dictionary.
"""
yaml_str = ""
for key, value in input_dict.items():
padding = " " * indent
if isinstance(value, dict):
yaml_str += f"{padding}{key}:\n{dict_to_yaml_str(value, indent + 1)}"
elif isinstance(value, list):
yaml_str += f"{padding}{key}:\n"
for item in value:
yaml_str += f"{padding}- {item}\n"
else:
yaml_str += f"{padding}{key}: {value}\n"
return yaml_str


def get_node_data(node: "Node") -> dict:
"""Given a node from the graph database, return the data of the node.
Parameters
----------
node
from the graph database
Returns
-------
data of the node
"""
try:
record = node.data()
except AttributeError:
record = node._properties
else:
assert len(record) == 1
record: dict = next(iter(record.values()))
return record


def get_id_vectordb_format(node_id: str) -> int:
"""Given a node id from the graph database, return the corresponding document seq_num in the vectordb.
Parameters
----------
node_id
id of node
Returns
-------
seq_num of the document in the vectordb
Raises
------
"""
return int(node_id) + 1


def convert_node_to_doc(node: "Node", base_url: str= "https://www.sefaria.org/") -> Document:
"""
Convert a node from the graph database to a Document object.
Parameters:
node (Node): The node from the graph database.
Returns:
Document: The Document object created from the node.
"""
node_data: dict = get_node_data(node)
metadata = {k.replace("metadata.", ""):v for k, v in node_data.items() if k.startswith("metadata.")}
metadata['URL'] = metadata['url']
del metadata['url']
metadata["seq_num"] = get_id_vectordb_format(node_data["id"])
new_reference_part = metadata["URL"].replace(base_url, "")
new_category = metadata["docCategory"]
metadata["source"] = f"Reference: {new_reference_part}. Version Title: -, Document Category: {new_category}, URL: {metadata['URL']}"

page_content = dict_to_yaml_str(node_data.get("text")) if isinstance(node_data.get("text"), dict) else node_data.get("text", "")
return ChunkDocument(
page_content=page_content,
metadata=metadata
)


def convert_vector_db_record_to_doc(record) -> ChunkDocument:
assert len(record) == 1
record: dict = next(iter(record.values()))
page_content = record.pop("text", None)
return ChunkDocument(
page_content=dict_to_yaml_str(page_content)
if isinstance(page_content, dict)
else page_content,
metadata=record
)


def get_id_graph_format(document_seq_num: int) -> str:
"""Given a document seq_num from the vectordb, return the id of the corresponding node in the graph database.
Parameters
----------
document_seq_num
from the vectordb document
Returns
-------
id of the document in the graph database
"""
return str(document_seq_num -1)
Loading

0 comments on commit 9e726d1

Please sign in to comment.