Support text2gql search for GraphRAG #2227
Open
SonglinLyu wants to merge 1 commit into eosphoros-ai:main from SonglinLyu:graphrag_text2gql_dev
@@ -0,0 +1,10 @@
"""AWELIntentInterpreter class."""
import logging

from dbgpt.rag.transformer.base import TranslatorBase

logger = logging.getLogger(__name__)


class AWELIntentInterpreter(TranslatorBase):
    """AWELIntentInterpreter class."""
@@ -0,0 +1,123 @@
"""IntentInterpreter class."""
import json
import logging
import re
from typing import Dict, Optional

from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
from dbgpt.rag.transformer.llm_translator import LLMTranslator

INTENT_INTERPRET_PT = (
    "A question is provided below. Given the question, analyze and classify it into one of the following categories:\n"
    "1. Single Entity Search: search for the detail of the given entity.\n"
    "2. One Hop Entity Search: given one entity and one relation, "
    "search for all entities that have the relation with the given entity.\n"
    "3. One Hop Relation Search: given two entities, search for the relation between them.\n"
    "4. Two Hop Entity Search: given one entity and one relation, break that relation into two consecutive relations, "
    "then search for all entities that have the two hop relation with the given entity.\n"
    "5. Freestyle Question: questions that are not in the above four categories. "
    "Search all related entities and two-hop subgraphs centered on them.\n"
    "After classifying the given question, rewrite the question in a graph query language style, "
    "then return the category of the given question and the rewrited question in json format. "
    "Also return entities and relations that might be used for query generation in json format. "
    "Here are some examples to guide your classification:\n"
    "---------------------\n"
    "Example:\n"
    "Question: Introduce TuGraph.\n"
    "Return:\n{{\"category\": \"Single Entity Search\", \"rewrited_question\": \"Query the entity named TuGraph then return the entity.\", "
    "\"entities\": [\"TuGraph\"], \"relations\": []}}\n"
    "Question: Who commits code to TuGraph.\n"
    "Return:\n{{\"category\": \"One Hop Entity Search\", \"rewrited_question\": \"Query all one hop paths that have an entity named TuGraph and a relation named commit, then return them.\", "
    "\"entities\": [\"TuGraph\"], \"relations\": [\"commit\"]}}\n"
    "Question: What is the relation between Alex and TuGraph?\n"
    "Return:\n{{\"category\": \"One Hop Relation Search\", \"rewrited_question\": \"Query all one hop paths between the entity named Alex and the entity named TuGraph, then return them.\", "
    "\"entities\": [\"Alex\", \"TuGraph\"], \"relations\": []}}\n"
    "Question: Who is the colleague of Bob?\n"
    "Return:\n{{\"category\": \"Two Hop Entity Search\", \"rewrited_question\": \"Query all entities that have a two hop path between them and the entity named Bob, both entities should have a work for relation with the middle entity.\", "
    "\"entities\": [\"Bob\"], \"relations\": [\"work for\"]}}\n"
    "Question: Introduce TuGraph and DBGPT separately.\n"
    "Return:\n{{\"category\": \"Freestyle Question\", \"rewrited_question\": \"Query the entity named TuGraph and the entity named DBGPT, then return two-hop subgraphs centered on them.\", "
    "\"entities\": [\"TuGraph\", \"DBGPT\"], \"relations\": []}}\n"
    "---------------------\n"
    "Text: {text}\n"
    "Keywords:\n"
)

logger = logging.getLogger(__name__)


class IntentInterpreter(LLMTranslator):
    """IntentInterpreter class."""

    def __init__(self, llm_client: LLMClient, model_name: str):
        """Initialize the IntentInterpreter."""
        super().__init__(llm_client, model_name, INTENT_INTERPRET_PT)

    async def _translate(
        self,
        text: str,
        history: Optional[str] = None,
        limit: Optional[int] = None,
        type: Optional[str] = "PROMPT",
    ) -> Dict:
        """Inner translate by LLM.

        The returned dict should contain the following content:
        {
            "category": "Type of the given question.",
            "original_question": "The original question provided by the user.",
            "rewrited_question": "Question rewritten in graph query language style.",
            "entities": ["entities", "that", "might", "be", "used", "in", "query"],
            "relations": ["relations", "that", "might", "be", "used", "in", "query"]
        }
        """
        # Interpret intent with a single prompt only.
        template = HumanPromptTemplate.from_template(self._prompt_template)

        messages = (
            template.format_messages(text=text, history=history)
            if history is not None
            else template.format_messages(text=text)
        )

        # Use the default model if none is configured.
        if not self._model_name:
            models = await self._llm_client.models()
            if not models:
                raise Exception("No models available")
            self._model_name = models[0].model
            logger.info(f"Using model {self._model_name} to extract")

        model_messages = ModelMessage.from_base_messages(messages)
        request = ModelRequest(model=self._model_name, messages=model_messages)
        response = await self._llm_client.generate(request=request)

        if not response.success:
            code = str(response.error_code)
            reason = response.text
            logger.error(f"request llm failed ({code}) {reason}")
            return {}

        if limit and limit < 1:
            raise ValueError("optional argument limit >= 1")
        return self._parse_response(response.text, limit)

    def truncate(self):
        """Do nothing by default."""

    def drop(self):
        """Do nothing by default."""

    def _parse_response(self, text: str, limit: Optional[int] = None) -> Dict:
        """Parse llm response."""
        intention = text

        code_block_pattern = re.compile(r'```json(.*?)```', re.S)
        json_pattern = re.compile(r'{.*?}', re.S)

        result = re.findall(code_block_pattern, intention)
        if result:
            intention = result[0]
        result = re.findall(json_pattern, intention)
        if result:
            intention = result[0]
        else:
            intention = ""

        return json.loads(intention)
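A minimal sketch of how the interpreter might be exercised, assuming an available LLMClient; the OpenAILLMClient import, model name, and asyncio wrapper below are illustrative assumptions, not part of this diff:

# Hypothetical usage sketch; client construction and model name are assumptions, not code from this PR.
import asyncio

from dbgpt.model.proxy import OpenAILLMClient  # assumed LLM client; any LLMClient implementation works
from dbgpt.rag.transformer.intent_interpreter import IntentInterpreter


async def main():
    llm_client = OpenAILLMClient()  # assumption: API credentials come from the environment
    interpreter = IntentInterpreter(llm_client, "gpt-4o-mini")  # illustrative model name
    intention = await interpreter.translate("Who commits code to TuGraph?")
    # Expected shape, per the prompt template above:
    # {"category": "One Hop Entity Search", "rewrited_question": "...",
    #  "entities": ["TuGraph"], "relations": ["commit"]}
    print(intention)


asyncio.run(main())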
@@ -0,0 +1,41 @@
"""LLMTranslator class."""

import asyncio
import logging
from abc import ABC, abstractmethod
from typing import Dict, Optional

from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
from dbgpt.rag.transformer.base import TranslatorBase

logger = logging.getLogger(__name__)


class LLMTranslator(TranslatorBase, ABC):
    """LLMTranslator class."""

    def __init__(self, llm_client: LLMClient, model_name: str, prompt_template: str):
        """Initialize the LLMTranslator."""
        self._llm_client = llm_client
        self._model_name = model_name
        self._prompt_template = prompt_template

    async def translate(self, text: str, limit: Optional[int] = None) -> Dict:
        """Translate by LLM."""
        return await self._translate(text, None, limit)

    @abstractmethod
    async def _translate(
        self, text: str, history: Optional[str] = None, limit: Optional[int] = None
    ) -> Dict:
        """Inner translate by LLM."""

    def truncate(self):
        """Do nothing by default."""

    def drop(self):
        """Do nothing by default."""

    @abstractmethod
    def _parse_response(self, text: str, limit: Optional[int] = None) -> Dict:
        """Parse llm response."""
@@ -0,0 +1,10 @@
"""MASIntentInterpreter class."""
import logging

from dbgpt.rag.transformer.base import TranslatorBase

logger = logging.getLogger(__name__)


class MASIntentInterpreter(TranslatorBase):
    """MASIntentInterpreter class."""
@@ -1,10 +1,137 @@
 """Text2Cypher class."""
-import logging
+import json
+import logging
+import re
+from typing import Dict, Optional
 
-from dbgpt.rag.transformer.base import TranslatorBase
+from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
+from dbgpt.rag.transformer.llm_translator import LLMTranslator
+from dbgpt.rag.transformer.intent_interpreter import IntentInterpreter
+
+TEXT_TO_CYPHER_PT = (
+    "A question written in graph query language style is provided below. "
+    "The category of this question, "
+    "entities and relations that might be used in the cypher query are also provided. "
+    "Given the question, translate the question into a cypher query that "
+    "can be executed on the given knowledge graph. "
+    "Make sure the syntax of the translated cypher query is correct.\n"
+    "To help query generation, the schema of the knowledge graph is:\n"
+    "{schema}\n"
+    "---------------------\n"
+    "Example:\n"
+    "Question: Query the entity named TuGraph then return the entity.\n"
+    "Category: Single Entity Search\n"
+    "entities: [\"TuGraph\"]\n"
+    "relations: []\n"
+    "Query:\nMATCH (n) WHERE n.id=\"TuGraph\" RETURN n\n"
+    "Question: Query all one hop paths between the entity named Alex and the entity named TuGraph, then return them.\n"
+    "Category: One Hop Relation Search\n"
+    "entities: [\"Alex\", \"TuGraph\"]\n"
+    "relations: []\n"
+    "Query:\nMATCH p=(n)-[r]-(m) WHERE n.id=\"Alex\" AND m.id=\"TuGraph\" RETURN p\n"
+    "Question: Query all one hop paths that have an entity named TuGraph and a relation named commit, then return them.\n"
+    "Category: One Hop Entity Search\n"
+    "entities: [\"TuGraph\"]\n"
+    "relations: [\"commit\"]\n"
+    "Query:\nMATCH p=(n)-[r]-(m) WHERE n.id=\"TuGraph\" AND r.id=\"commit\" RETURN p\n"
+    "Question: Query all entities that have a two hop path between them and the entity named Bob, both entities should have a work for relation with the middle entity.\n"
+    "Category: Two Hop Entity Search\n"
+    "entities: [\"Bob\"]\n"
+    "relations: [\"work for\"]\n"
+    "Query:\nMATCH p=(n)-[r1]-(m)-[r2]-(l) WHERE n.id=\"Bob\" AND r1.id=\"work for\" AND r2.id=\"work for\" RETURN p\n"
+    "Question: Introduce TuGraph and DBGPT separately.\n"
+    "Category: Freestyle Question\n"
+    "entities: [\"TuGraph\", \"DBGPT\"]\n"
+    "relations: []\n"
+    "Query:\nMATCH p=(n)-[r:relation*2]-(m) WHERE n.id IN [\"TuGraph\", \"DBGPT\"] RETURN p\n"
+    "---------------------\n"
+    "Question: {question}\n"
+    "Category: {category}\n"
+    "entities: {entities}\n"
+    "relations: {relations}\n"
+    "Query:\n"
+)
 
 logger = logging.getLogger(__name__)
 
 
-class Text2Cypher(TranslatorBase):
+class Text2Cypher(LLMTranslator):
     """Text2Cypher class."""
+
+    def __init__(self, llm_client: LLMClient, model_name: str, schema: str):
+        """Initialize the Text2Cypher."""
+        super().__init__(llm_client, model_name, TEXT_TO_CYPHER_PT)
+        self._schema = json.dumps(json.loads(schema), indent=4)
+        self._intent_interpreter = IntentInterpreter(llm_client, model_name)
+
+    async def _translate(
+        self, text: str, history: Optional[str] = None, limit: Optional[int] = None
+    ) -> Dict:
+        """Inner translate by LLM."""
+        # Interpret the intent of the question.
+        intention = await self._intent_interpreter.translate(text)
+        question = intention["rewrited_question"]
+        category = intention["category"]
+        entities = intention["entities"]
+        relations = intention["relations"]
+
+        # Translate the query with the interpreted intention.
+        template = HumanPromptTemplate.from_template(self._prompt_template)
+
+        messages = (
+            template.format_messages(
+                schema=self._schema,
+                question=question,
+                category=category,
+                entities=entities,
+                relations=relations,
+                history=history,
+            )
+            if history is not None
+            else template.format_messages(
+                schema=self._schema,
+                question=question,
+                category=category,
+                entities=entities,
+                relations=relations,
+            )
+        )
+
+        # Use the default model if none is configured.
+        if not self._model_name:
+            models = await self._llm_client.models()
+            if not models:
+                raise Exception("No models available")
+            self._model_name = models[0].model
+            logger.info(f"Using model {self._model_name} to extract")
+
+        model_messages = ModelMessage.from_base_messages(messages)
+        request = ModelRequest(model=self._model_name, messages=model_messages)
+        response = await self._llm_client.generate(request=request)
+
+        if not response.success:
+            code = str(response.error_code)
+            reason = response.text
+            logger.error(f"request llm failed ({code}) {reason}")
+            return {}
+
+        if limit and limit < 1:
+            raise ValueError("optional argument limit >= 1")
+        return self._parse_response(response.text, limit)
+
+    def _parse_response(self, text: str, limit: Optional[int] = None) -> Dict:
+        """Parse llm response."""
+        interaction = {}
+        query = ""
+
+        code_block_pattern = re.compile(r'```cypher(.*?)```', re.S)
+
+        result = re.findall(code_block_pattern, text)
+        if result:
+            query = result[0]
+        else:
+            query = text
+
+        interaction["query"] = query.strip()
+
+        return interaction
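An end-to-end sketch of the intended text2gql path, assuming a graph schema already serialized as JSON; the schema literal, client, and model name below are illustrative, and the module path is inferred from the sibling imports in this diff:

# Hypothetical end-to-end usage; schema contents, client, and model name are assumptions.
import asyncio
import json

from dbgpt.model.proxy import OpenAILLMClient  # assumed LLM client
from dbgpt.rag.transformer.text2cypher import Text2Cypher  # module path inferred from the imports above

SCHEMA = json.dumps(
    {
        "vertices": [{"label": "entity", "properties": [{"name": "id"}]}],
        "edges": [{"label": "relation", "properties": [{"name": "id"}]}],
    }
)


async def main():
    translator = Text2Cypher(OpenAILLMClient(), "gpt-4o-mini", SCHEMA)
    result = await translator.translate("What is the relation between Alex and TuGraph?")
    # _parse_response returns a dict such as {"query": 'MATCH p=(n)-[r]-(m) WHERE ... RETURN p'}
    print(result["query"])


asyncio.run(main())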
Dict needs to be imported from typing.
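If the flagged module uses Dict without importing it (the target file is not shown in this excerpt), the usual fix is a one-line import:

from typing import Dict  # suggested fix for the reviewed line; exact target file not shown here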