Skip to content

Commit

Permalink
feat(core): More trace records for DB-GPT (#775)
Browse files Browse the repository at this point in the history
- More trace records for DB-GPT
- Support pass span id to threadpool
  • Loading branch information
Aries-ckt authored Nov 4, 2023
2 parents a7dd328 + 59ac4ee commit 1d2b054
Show file tree
Hide file tree
Showing 17 changed files with 195 additions and 57 deletions.
1 change: 1 addition & 0 deletions assets/schema/knowledge_management.sql
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ CREATE TABLE `knowledge_document` (
`content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result',
`result` TEXT NULL COMMENT 'knowledge content',
`vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
`summary` LONGTEXT NULL COMMENT 'knowledge summary',
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),
Expand Down
69 changes: 62 additions & 7 deletions pilot/scene/base_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pilot.scene.message import OnceConversation
from pilot.utils import get_or_create_event_loop
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
from pydantic import Extra
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory

Expand All @@ -38,6 +39,7 @@ class Config:

arbitrary_types_allowed = True

@trace("BaseChat.__init__")
def __init__(self, chat_param: Dict):
"""Chat Module Initialization
Args:
Expand Down Expand Up @@ -143,7 +145,14 @@ async def __call_base(self):
)
self.current_message.tokens = 0
if self.prompt_template.template:
current_prompt = self.prompt_template.format(**input_values)
metadata = {
"template_scene": self.prompt_template.template_scene,
"input_values": input_values,
}
with root_tracer.start_span(
"BaseChat.__call_base.prompt_template.format", metadata=metadata
):
current_prompt = self.prompt_template.format(**input_values)
self.current_message.add_system_message(current_prompt)

llm_messages = self.generate_llm_messages()
Expand Down Expand Up @@ -175,13 +184,25 @@ async def check_iterator_end(iterator):
except StopAsyncIteration:
return True # 迭代器已经执行结束

def _get_span_metadata(self, payload: Dict) -> Dict:
metadata = {k: v for k, v in payload.items()}
del metadata["prompt"]
metadata["messages"] = list(
map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
)
return metadata

async def stream_call(self):
# TODO Retry when server connection error
payload = await self.__call_base()

self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
span = root_tracer.start_span(
"BaseChat.stream_call", metadata=self._get_span_metadata(payload)
)
payload["span_id"] = span.span_id
try:
from pilot.model.cluster import WorkerManagerFactory

Expand All @@ -199,27 +220,34 @@ async def stream_call(self):
self.current_message.add_ai_message(msg)
view_msg = self.knowledge_reference_call(msg)
self.current_message.add_view_message(view_msg)
span.end()
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild!" + str(e))
self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
)
### store current conversation
span.end(metadata={"error": str(e)})
self.memory.append(self.current_message)

async def nostream_call(self):
payload = await self.__call_base()
logger.info(f"Request: \n{payload}")
ai_response_text = ""
span = root_tracer.start_span(
"BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
)
payload["span_id"] = span.span_id
try:
from pilot.model.cluster import WorkerManagerFactory

worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()

model_output = await worker_manager.generate(payload)
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
model_output = await worker_manager.generate(payload)

### output parse
ai_response_text = (
Expand All @@ -234,11 +262,18 @@ async def nostream_call(self):
ai_response_text
)
)
### run
# result = self.do_action(prompt_define_response)
result = await blocking_func_to_async(
self._executor, self.do_action, prompt_define_response
)
metadata = {
"model_output": model_output.to_dict(),
"ai_response_text": ai_response_text,
"prompt_define_response": self._parse_prompt_define_response(
prompt_define_response
),
}
with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
### run
result = await blocking_func_to_async(
self._executor, self.do_action, prompt_define_response
)

### llm speaker
speak_to_user = self.get_llm_speak(prompt_define_response)
Expand All @@ -255,12 +290,14 @@ async def nostream_call(self):

view_message = view_message.replace("\n", "\\n")
self.current_message.add_view_message(view_message)
span.end()
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild!" + str(e))
self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
)
span.end(metadata={"error": str(e)})
### store dialogue
self.memory.append(self.current_message)
return self.current_ai_response()
Expand Down Expand Up @@ -513,3 +550,21 @@ def generate(self, p) -> str:
"""
pass

def _parse_prompt_define_response(self, prompt_define_response: Any) -> Any:
if not prompt_define_response:
return ""
if isinstance(prompt_define_response, str) or isinstance(
prompt_define_response, dict
):
return prompt_define_response
if isinstance(prompt_define_response, tuple):
if hasattr(prompt_define_response, "_asdict"):
# namedtuple
return prompt_define_response._asdict()
else:
return dict(
zip(range(len(prompt_define_response)), prompt_define_response)
)
else:
return prompt_define_response
7 changes: 6 additions & 1 deletion pilot/scene/chat_agent/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .prompt import prompt
from pilot.component import ComponentType
from pilot.base_modules.agent.controller import ModuleAgent
from pilot.utils.tracer import root_tracer, trace

CFG = Config()

Expand Down Expand Up @@ -51,6 +52,7 @@ def __init__(self, chat_param: Dict):

self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)

@trace()
async def generate_input_values(self) -> Dict[str, str]:
input_values = {
"user_goal": self.current_user_input,
Expand All @@ -63,7 +65,10 @@ async def generate_input_values(self) -> Dict[str, str]:

def stream_plugin_call(self, text):
    """Run the plugin API call on *text* inside a dedicated trace span.

    Newlines are flattened to spaces before the call.
    """
    flattened = text.replace("\n", " ")
    span_metadata = {"text": flattened}
    with root_tracer.start_span(
        "ChatAgent.stream_plugin_call.api_call", metadata=span_metadata
    ):
        return self.api_call.run(flattened)

def __list_to_prompt_str(self, list: List) -> str:
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
2 changes: 2 additions & 0 deletions pilot/scene/chat_dashboard/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace

CFG = Config()

Expand Down Expand Up @@ -53,6 +54,7 @@ def __load_dashboard_template(self, template_name):
data = f.read()
return json.loads(data)

@trace()
async def generate_input_values(self) -> Dict:
try:
from pilot.summary.db_summary_client import DBSummaryClient
Expand Down
9 changes: 8 additions & 1 deletion pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pilot.common.path_utils import has_path
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.base_modules.agent.common.schema import Status
from pilot.utils.tracer import root_tracer, trace

CFG = Config()

Expand Down Expand Up @@ -62,6 +63,7 @@ def _generate_numbered_list(self) -> str:
# ]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))

@trace()
async def generate_input_values(self) -> Dict:
input_values = {
"user_input": self.current_user_input,
Expand All @@ -88,4 +90,9 @@ async def prepare(self):

def stream_plugin_call(self, text):
    """Run the display-SQL plugin call on *text*, wrapped in a trace span.

    Newlines are flattened to spaces; the excel reader's
    ``get_df_by_sql_ex`` is passed as the SQL-execution callback.
    """
    single_line = text.replace("\n", " ")
    with root_tracer.start_span(
        "ChatExcel.stream_plugin_call.run_display_sql",
        metadata={"text": single_line},
    ):
        result = self.api_call.run_display_sql(
            single_line, self.excel_reader.get_df_by_sql_ex
        )
        return result
2 changes: 2 additions & 0 deletions pilot/scene/chat_data/chat_excel/excel_learning/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
from pilot.json_utils.utilities import DateTimeEncoder
from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace

CFG = Config()

Expand Down Expand Up @@ -44,6 +45,7 @@ def __init__(
if parent_mode:
self.current_message.chat_mode = parent_mode.value()

@trace()
async def generate_input_values(self) -> Dict:
# colunms, datas = self.excel_reader.get_sample_data()
colunms, datas = await blocking_func_to_async(
Expand Down
29 changes: 19 additions & 10 deletions pilot/scene/chat_db/auto_execute/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pilot.configs.config import Config
from pilot.scene.chat_db.auto_execute.prompt import prompt
from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace

CFG = Config()

Expand Down Expand Up @@ -35,10 +36,13 @@ def __init__(self, chat_param: Dict):
raise ValueError(
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
)

self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
with root_tracer.start_span(
"ChatWithDbAutoExecute.get_connect", metadata={"db_name": self.db_name}
):
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.top_k: int = 200

@trace()
async def generate_input_values(self) -> Dict:
"""
generate input values
Expand All @@ -55,13 +59,14 @@ async def generate_input_values(self) -> Dict:
# query=self.current_user_input,
# topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
# )
table_infos = await blocking_func_to_async(
self._executor,
client.get_db_summary,
self.db_name,
self.current_user_input,
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
)
with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"):
table_infos = await blocking_func_to_async(
self._executor,
client.get_db_summary,
self.db_name,
self.current_user_input,
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
)
except Exception as e:
print("db summary find error!" + str(e))
if not table_infos:
Expand All @@ -80,4 +85,8 @@ async def generate_input_values(self) -> Dict:

def do_action(self, prompt_response):
    """Execute the SQL carried by *prompt_response* on the connected database.

    The whole response (sql + thoughts) is attached to the trace span as
    metadata via ``to_dict()``.
    """
    print(f"do_action:{prompt_response}")
    span_metadata = prompt_response.to_dict()
    with root_tracer.start_span(
        "ChatWithDbAutoExecute.do_action.run_sql",
        metadata=span_metadata,
    ):
        return self.database.run(prompt_response.sql)
3 changes: 3 additions & 0 deletions pilot/scene/chat_db/auto_execute/out_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ class SqlAction(NamedTuple):
sql: str
thoughts: Dict

def to_dict(self) -> Dict[str, Dict]:
    """Return this action's fields as a plain dict (used as trace-span metadata).

    NOTE(review): the value types are actually ``str`` (sql) and ``Dict``
    (thoughts); the annotation ``Dict[str, Dict]`` is kept as-is because
    ``Any`` may not be imported in this module — tighten in a typing pass.
    """
    return {"sql": self.sql, "thoughts": self.thoughts}


logger = logging.getLogger(__name__)

Expand Down
2 changes: 2 additions & 0 deletions pilot/scene/chat_db/professional_qa/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pilot.configs.config import Config
from pilot.scene.chat_db.professional_qa.prompt import prompt
from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace

CFG = Config()

Expand Down Expand Up @@ -39,6 +40,7 @@ def __init__(self, chat_param: Dict):
else len(self.tables)
)

@trace()
async def generate_input_values(self) -> Dict:
table_info = ""
dialect = "mysql"
Expand Down
2 changes: 2 additions & 0 deletions pilot/scene/chat_execution/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pilot.base_modules.agent.commands.command import execute_command
from pilot.base_modules.agent import PluginPromptGenerator
from .prompt import prompt
from pilot.utils.tracer import root_tracer, trace

CFG = Config()

Expand Down Expand Up @@ -50,6 +51,7 @@ def __init__(self, chat_param: Dict):
self.plugins_prompt_generator
)

@trace()
async def generate_input_values(self) -> Dict:
input_values = {
"input": self.current_user_input,
Expand Down
2 changes: 2 additions & 0 deletions pilot/scene/chat_knowledge/inner_db_summary/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pilot.configs.config import Config

from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt
from pilot.utils.tracer import root_tracer, trace

CFG = Config()

Expand Down Expand Up @@ -31,6 +32,7 @@ def __init__(
self.db_input = db_select
self.db_summary = db_summary

@trace()
async def generate_input_values(self) -> Dict:
input_values = {
"db_input": self.db_input,
Expand Down
2 changes: 2 additions & 0 deletions pilot/scene/chat_knowledge/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pilot.scene.chat_knowledge.v1.prompt import prompt
from pilot.server.knowledge.service import KnowledgeService
from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace

CFG = Config()

Expand Down Expand Up @@ -92,6 +93,7 @@ def knowledge_reference_call(self, text):
"""return reference"""
return text + f"\n\n{self.parse_source_view(self.sources)}"

@trace()
async def generate_input_values(self) -> Dict:
if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
Expand Down
2 changes: 2 additions & 0 deletions pilot/scene/chat_normal/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pilot.configs.config import Config

from pilot.scene.chat_normal.prompt import prompt
from pilot.utils.tracer import root_tracer, trace

CFG = Config()

Expand All @@ -21,6 +22,7 @@ def __init__(self, chat_param: Dict):
chat_param=chat_param,
)

@trace()
async def generate_input_values(self) -> Dict:
    """Build the prompt input mapping: the raw user input under key "input"."""
    return {"input": self.current_user_input}
Expand Down
Loading

0 comments on commit 1d2b054

Please sign in to comment.