Skip to content

Commit

Permalink
Merge pull request #43 from Sefaria/feat/config-for-ref-categories
Browse files Browse the repository at this point in the history
feat: config for ref categories
  • Loading branch information
Paul-Yu-Chun-Chang authored Sep 28, 2024
2 parents 9c2b3c1 + a9c0d83 commit 04c3f59
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 37 deletions.
72 changes: 37 additions & 35 deletions VirtualHavruta/vh.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,21 @@ def __init__(self, prompts_file: str, config_file: str, logger):
with open(config_file, 'r') as f:
self.config = yaml.safe_load(f)

# Initialize Neo4j vector index and retrieve DB configs
model_api = self.config['openai_model_api']
config_emb_db = self.config['database']['embed']
# Retrieve model and DB configs
self.model_api = self.config['openai_model_api']
self.chain_setups = self.config['llm_chain_setups']
self.config_emb_db = self.config['database']['embed']
self.config_kg_db = self.config['database']['kg']

# Initialize Neo4j vector index
self.neo4j_vector = Neo4jVector.from_existing_index(
OpenAIEmbeddings(model=model_api['embedding_model']),
OpenAIEmbeddings(model=self.model_api['embedding_model']),
index_name="index",
url=config_emb_db['url'],
username=config_emb_db['username'],
password=config_emb_db['password'],
url=self.config_emb_db['url'],
username=self.config_emb_db['username'],
password=self.config_emb_db['password'],
)
self.top_k = config_emb_db['top_k']
self.neo4j_deeplink = self.config['database']['kg']['neo4j_deeplink']
self.top_k = self.config_emb_db['top_k']

# Initiate logger
self.logger = logger
Expand All @@ -90,6 +93,7 @@ def __init__(self, prompts_file: str, config_file: str, logger):
self.num_primary_citations_linker = linker_references['num_primary_citations']
self.num_secondary_citations_linker = linker_references['num_secondary_citations']
self.linker_primary_source_filter = linker_references['primary_source_filter']
self.neo4j_deeplink = self.config_kg_db['neo4j_deeplink']

# Initialize prompt templates and LLM instances
self.initialize_prompt_templates()
Expand All @@ -104,8 +108,8 @@ def initialize_prompt_templates(self):
Additionally, it creates a separate prompt template for the QA (question-answering) category, including reference data.
'''
no_ref_categories = ['anti_attack', 'adaptor', 'editor', 'optimization']
ref_categories = ['classification', 'qa', 'selector']
no_ref_categories = self.chain_setups['no_ref_chains']
ref_categories = self.chain_setups['ref_chains']
no_ref_prompts = {'prompt_'+cat: self.create_prompt_template('system', cat) for cat in no_ref_categories}
ref_prompts = {'prompt_'+cat: self.create_prompt_template('system', cat, True) for cat in ref_categories}
self.__dict__.update(no_ref_prompts)
Expand Down Expand Up @@ -150,19 +154,17 @@ def initialize_llm_instances(self):
Returns: None
'''
model_api = self.config['openai_model_api']
chain_setups = self.config['llm_chain_setups']

# Adding a condition to include json kwargs for models ending with '_json'
for model_name, suffixes in chain_setups.items():
model_kwargs = {"response_format": {"type": "json_object"}} if model_name.endswith('_json') else {}
model_key = model_name.replace('_json', '') # Removes the '_json' suffix for lookup in model_api
setattr(self, model_name, ChatOpenAI(
temperature=model_api.get(f"{model_key}_temperature", None),
model=model_api.get(model_key, None),
model_kwargs=model_kwargs
))
self.initialize_llm_chains(getattr(self, model_name), suffixes)
for model_name, suffixes in self.chain_setups.items():
if model_name.startswith(('main', 'support')):
model_kwargs = {"response_format": {"type": "json_object"}} if model_name.endswith('_json') else {}
model_key = model_name.replace('_json', '') # Removes the '_json' suffix for lookup in model_api
setattr(self, model_name, ChatOpenAI(
temperature=self.model_api.get(f"{model_key}_temperature", None),
model=self.model_api.get(model_key, None),
model_kwargs=model_kwargs
))
self.initialize_llm_chains(getattr(self, model_name), suffixes)

def initialize_llm_chains(self, model, suffixes):
'''
Expand Down Expand Up @@ -474,8 +476,8 @@ def get_retrieval_results_knowledge_graph(self, url: str, direction: str, order:
Example:
res = vh.get_retrieval_results_knowledge_graph(
url=top_node.metadata["url"],
direction=self.config["database"]["kg"]["direction"],
order=self.config["database"]["kg"]["order"],
direction=self.config_kg_db["direction"],
order=self.config_kg_db["order"],
filter_mode_nodes=filter_mode_nodes,
score_central_node=6.0,
msg_id=msg_id
Expand Down Expand Up @@ -557,11 +559,11 @@ def get_graph_neighbors_by_url(self, url: str, relationship: str, depth: int, fi
{source_filter}
RETURN DISTINCT neighbor, {i} AS depth
"""
with neo4j.GraphDatabase.driver(self.config["database"]["kg"]["url"], auth=(self.config["database"]["kg"]["username"], self.config["database"]["kg"]["password"])) as driver:
with neo4j.GraphDatabase.driver(self.config_kg_db["url"], auth=(self.config_kg_db["username"], self.config_kg_db["password"])) as driver:
neighbor_nodes, _, _ = driver.execute_query(
query,
parameters_=query_params,
database_=self.config["database"]["kg"]["name"],)
database_=self.config_kg_db["name"],)
nodes.extend(neighbor_nodes)
self.logger.info(f"MsgID={msg_id}. [GRAPH NEIGHBOR RETRIEVAL] Retrieved {len(nodes)} graph neighbors.")
return nodes
Expand Down Expand Up @@ -592,11 +594,11 @@ def query_graph_db_by_url(self, urls: list[str]) -> list[Document]:
WHERE any(substring IN $urls WHERE n.url = substring)
RETURN n
"""
with neo4j.GraphDatabase.driver(self.config["database"]["kg"]["url"], auth=(self.config["database"]["kg"]["username"], self.config["database"]["kg"]["password"])) as driver:
with neo4j.GraphDatabase.driver(self.config_kg_db["url"], auth=(self.config_kg_db["username"], self.config_kg_db["password"])) as driver:
nodes, _, _ = driver.execute_query(
query_string,
parameters_=query_parameters,
database_=self.config["database"]["kg"]["name"],)
database_=self.config_kg_db["name"],)
return [convert_node_to_doc(node) for node in nodes]

def select_reference(self, query: str, retrieval_res, msg_id: str = ''):
Expand Down Expand Up @@ -1358,7 +1360,7 @@ def graph_traversal_retriever(self,
n_accepted_chunks = 0
n_iter = 0
seed_iteration = True
while n_accepted_chunks < self.config["database"]["kg"]["max_depth"]:
while n_accepted_chunks < self.config_kg_db["max_depth"]:
if len(candidate_chunks) == 0:
break
# Get the top chunk
Expand All @@ -1369,7 +1371,7 @@ def graph_traversal_retriever(self,
ranking_scores_collected_chunks.append(local_top_score)
n_accepted_chunks += 1
# avoid final loop execution which does not add a chunk to collected_chunks anyways
if n_accepted_chunks >= self.config["database"]["kg"]["max_depth"]:
if n_accepted_chunks >= self.config_kg_db["max_depth"]:
break
else:
n_iter +=1
Expand All @@ -1378,8 +1380,8 @@ def graph_traversal_retriever(self,
top_node = self.get_node_corresponding_to_chunk(top_chunk, msg_id=msg_id)
neighbor_nodes_scores: list[tuple[Document, int]] = self.get_retrieval_results_knowledge_graph(
url=top_node.metadata["url"],
direction=self.config["database"]["kg"]["direction"],
order=self.config["database"]["kg"]["order"],
direction=self.config_kg_db["direction"],
order=self.config_kg_db["order"],
filter_mode_nodes=filter_mode_nodes,
score_central_node=6.0,
msg_id=msg_id
Expand Down Expand Up @@ -1697,11 +1699,11 @@ def get_node_corresponding_to_chunk(self, chunk: Document, msg_id: str = '') ->
AND n.versionTitle=$versionTitle
RETURN n
"""
with neo4j.GraphDatabase.driver(self.config["database"]["kg"]["url"], auth=(self.config["database"]["kg"]["username"], self.config["database"]["kg"]["password"])) as driver:
with neo4j.GraphDatabase.driver(self.config_kg_db["url"], auth=(self.config_kg_db["username"], self.config_kg_db["password"])) as driver:
nodes, _, _ = driver.execute_query(
query_string,
parameters_=query_parameters,
database_=self.config["database"]["kg"] ["name"],)
database_=self.config_kg_db["name"],)
assert len(nodes) == 1
node = nodes[0]
self.logger.info(f"MsgID={msg_id}. [CHUNK2NODE] Found chunk-corresponding node for {query_parameters}")
Expand Down
6 changes: 4 additions & 2 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ openai_model_api:
# LLM Chain Setups
llm_chain_setups:
main_model: ['chain1', 'chain2']
main_model_json: ['chain3']
main_model_json: ['chain3'] #LLM chains that require a json output format
support_model: ['chain4', 'chain5', 'chain6']
support_model_json: []
support_model_json: [] #LLM chains that require a json output format
ref_chains: ['chain1', 'chain3', 'chain4', 'chain5'] #LLM chains that require reference data as part of the input. These should correspond to those chains in main_model(_json) and support_model(_json).
no_ref_chains: ['chain2', 'chain6'] #LLM chains that don't require reference data as part of the input. These should correspond to those chains in main_model(_json) and support_model(_json).

# Reference Settings
references:
Expand Down

0 comments on commit 04c3f59

Please sign in to comment.