Skip to content

Commit

Permalink
refactor(ChatData):update rdbms db summary (#885)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aries-ckt authored Dec 4, 2023
1 parent b12a858 commit a2f087c
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 333 deletions.
67 changes: 0 additions & 67 deletions pilot/connections/manages/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,73 +60,6 @@ def __init__(self, system_app: SystemApp):
# self.storage = DuckdbConnectConfig()
self.storage = ConnectConfigDao()
self.db_summary_client = DBSummaryClient(system_app)
# self.__load_config_db()

# def __load_config_db(self):
# if CFG.LOCAL_DB_HOST:
# # default mysql
# if CFG.LOCAL_DB_NAME:
# self.storage.add_url_db(
# CFG.LOCAL_DB_NAME,
# DBType.Mysql.value(),
# CFG.LOCAL_DB_HOST,
# CFG.LOCAL_DB_PORT,
# CFG.LOCAL_DB_USER,
# CFG.LOCAL_DB_PASSWORD,
# "",
# )
# else:
# # get all default mysql database
# default_mysql = Database.from_uri(
# "mysql+pymysql://"
# + CFG.LOCAL_DB_USER
# + ":"
# + CFG.LOCAL_DB_PASSWORD
# + "@"
# + CFG.LOCAL_DB_HOST
# + ":"
# + str(CFG.LOCAL_DB_PORT),
# engine_args={
# "pool_size": CFG.LOCAL_DB_POOL_SIZE,
# "pool_recycle": 3600,
# "echo": True,
# },
# )
# dbs = default_mysql.get_database_list()
# for name in dbs:
# self.storage.add_url_db(
# name,
# DBType.Mysql.value(),
# CFG.LOCAL_DB_HOST,
# CFG.LOCAL_DB_PORT,
# CFG.LOCAL_DB_USER,
# CFG.LOCAL_DB_PASSWORD,
# "",
# )
# db_type = DBType.of_db_type(CFG.LOCAL_DB_TYPE)
# if db_type.is_file_db():
# db_name = CFG.LOCAL_DB_NAME
# db_type = CFG.LOCAL_DB_TYPE
# db_path = CFG.LOCAL_DB_PATH
# if not db_type:
# # Default file database type
# db_type = DBType.DuckDb.value()
# if not db_name:
# db_type, db_name = self._parse_file_db_info(db_type, db_path)
# if db_name:
# print(
# f"Add file db, db_name: {db_name}, db_type: {db_type}, db_path: {db_path}"
# )
# self.storage.add_file_db(db_name, db_type, db_path)

# def _parse_file_db_info(self, db_type: str, db_path: str):
# if db_type is None or db_type == DBType.DuckDb.value():
# # file db is duckdb
# db_name = self.storage.get_file_db_name(db_path)
# db_type = DBType.DuckDb.value()
# else:
# db_name = DBType.parse_file_db_name_from_path(db_type, db_path)
# return db_type, db_name

def get_connect(self, db_name):
db_config = self.storage.get_db_config(db_name)
Expand Down
118 changes: 8 additions & 110 deletions pilot/summary/db_summary_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@
import uuid
import logging

from pilot.common.schema import DBType
from pilot.component import SystemApp
from pilot.configs.config import Config
from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
EMBEDDING_MODEL_CONFIG,
)

from pilot.scene.base import ChatScene
from pilot.scene.base_chat import BaseChat
from pilot.scene.chat_factory import ChatFactory
from pilot.summary.rdbms_db_summary import RdbmsSummary

Expand All @@ -33,7 +29,6 @@ def __init__(self, system_app: SystemApp):

def db_summary_embedding(self, dbname, db_type):
"""put db profile and table profile summary into vector store"""
from pilot.embedding_engine.string_embedding import StringEmbedding
from pilot.embedding_engine.embedding_factory import EmbeddingFactory

db_summary_client = RdbmsSummary(dbname, db_type)
Expand All @@ -43,48 +38,12 @@ def db_summary_embedding(self, dbname, db_type):
embeddings = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
vector_store_config = {
"vector_store_name": dbname + "_summary",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"embeddings": embeddings,
}
embedding = StringEmbedding(
file_path=db_summary_client.get_summary(),
vector_store_config=vector_store_config,
)
self.init_db_profile(db_summary_client, dbname, embeddings)
if not embedding.vector_name_exist():
if CFG.SUMMARY_CONFIG == "FAST":
for vector_table_info in db_summary_client.get_summary():
embedding = StringEmbedding(
vector_table_info,
vector_store_config,
)
embedding.source_embedding()
else:
embedding = StringEmbedding(
file_path=db_summary_client.get_summary(),
vector_store_config=vector_store_config,
)
embedding.source_embedding()
for (
table_name,
table_summary,
) in db_summary_client.get_table_summary().items():
table_vector_store_config = {
"vector_store_name": dbname + "_" + table_name + "_ts",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"embeddings": embeddings,
}
embedding = StringEmbedding(
table_summary,
table_vector_store_config,
)
embedding.source_embedding()

logger.info("db summary embedding success")

def get_db_summary(self, dbname, query, topk):
"""get user query related tables info"""
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import EmbeddingFactory

Expand All @@ -104,53 +63,8 @@ def get_db_summary(self, dbname, query, topk):
ans = [d.page_content for d in table_docs]
return ans

def get_similar_tables(self, dbname, query, topk):
"""get user query related tables info"""
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import EmbeddingFactory

vector_store_config = {
"vector_store_name": dbname + "_summary",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
}
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
)
if CFG.SUMMARY_CONFIG == "FAST":
table_docs = knowledge_embedding_client.similar_search(query, topk)
related_tables = [
json.loads(table_doc.page_content)["table_name"]
for table_doc in table_docs
]
else:
table_docs = knowledge_embedding_client.similar_search(query, 1)
# prompt = KnownLedgeBaseQA.build_db_summary_prompt(
# query, table_docs[0].page_content
# )
related_tables = _get_llm_response(
query, dbname, table_docs[0].page_content
)
related_table_summaries = []
for table in related_tables:
vector_store_config = {
"vector_store_name": dbname + "_" + table + "_ts",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
}
knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
)
table_summary = knowledge_embedding_client.similar_search(query, 1)
related_table_summaries.append(table_summary[0].page_content)
return related_table_summaries

def init_db_summary(self):
"""init db summary"""
db_mange = CFG.LOCAL_DB_MANAGE
dbs = db_mange.get_db_list()
for item in dbs:
Expand All @@ -177,41 +91,25 @@ def init_db_profile(self, db_summary_client, dbname, embeddings):
"embeddings": embeddings,
}
embedding = StringEmbedding(
file_path=db_summary_client.get_db_summary(),
file_path=None,
vector_store_config=profile_store_config,
)
if not embedding.vector_name_exist():
docs = []
docs.extend(embedding.read_batch())
for table_summary in db_summary_client.table_info_json():
for table_summary in db_summary_client.table_summaries():
from langchain.text_splitter import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(
chunk_size=len(table_summary), chunk_overlap=100
chunk_size=len(table_summary), chunk_overlap=0
)
embedding = StringEmbedding(
file_path=table_summary,
vector_store_config=profile_store_config,
text_splitter=text_splitter,
)
docs.extend(embedding.read_batch())
embedding.index_to_store(docs)
if len(docs) > 0:
embedding.index_to_store(docs)
else:
logger.info(f"Vector store name {vector_store_name} exist")
logger.info("init db profile success...")


def _get_llm_response(query, db_input, dbsummary):
chat_param = {
"temperature": 0.7,
"max_new_tokens": 512,
"chat_session_id": uuid.uuid1(),
"user_input": query,
"db_select": db_input,
"db_summary": dbsummary,
}
chat: BaseChat = chat_factory.get_implementation(
ChatScene.InnerChatDBSummary.value, **chat_param
)
res = chat._blocking_nostream_call()
return json.loads(res)["table"]
logger.info("initialize db summary profile success...")
Loading

0 comments on commit a2f087c

Please sign in to comment.