diff --git a/assets/schema/knowledge_management.sql b/assets/schema/knowledge_management.sql
index e38f731d6..a6f1bc478 100644
--- a/assets/schema/knowledge_management.sql
+++ b/assets/schema/knowledge_management.sql
@@ -34,6 +34,7 @@ CREATE TABLE `knowledge_document` (
     `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result',
     `result` TEXT NULL COMMENT 'knowledge content',
     `vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
+    `summary` LONGTEXT NULL COMMENT 'knowledge summary',
     `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
     `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
     PRIMARY KEY (`id`),
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 34f294c31..4529d3cb4 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -13,6 +13,7 @@
 from pilot.scene.message import OnceConversation
 from pilot.utils import get_or_create_event_loop
 from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
+from pilot.utils.tracer import root_tracer, trace
 from pydantic import Extra

 from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
@@ -38,6 +39,7 @@ class Config:
         arbitrary_types_allowed = True

+    @trace("BaseChat.__init__")
     def __init__(self, chat_param: Dict):
         """Chat Module Initialization
         Args:
@@ -143,7 +145,14 @@ async def __call_base(self):
         )
         self.current_message.tokens = 0
         if self.prompt_template.template:
-            current_prompt = self.prompt_template.format(**input_values)
+            metadata = {
+                "template_scene": self.prompt_template.template_scene,
+                "input_values": input_values,
+            }
+            with root_tracer.start_span(
+                "BaseChat.__call_base.prompt_template.format", metadata=metadata
+            ):
+                current_prompt = self.prompt_template.format(**input_values)
             self.current_message.add_system_message(current_prompt)

         llm_messages = self.generate_llm_messages()
@@ -175,6 +184,14 @@ async def check_iterator_end(iterator):
         except StopAsyncIteration:
             return True  # The iterator has finished

+    def _get_span_metadata(self, payload: Dict) -> Dict:
+        metadata = {k: v for k, v in payload.items()}
+        del metadata["prompt"]
+        metadata["messages"] = list(
+            map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
+        )
+        return metadata
+
     async def stream_call(self):
         # TODO Retry when server connection error
         payload = await self.__call_base()
@@ -182,6 +199,10 @@
         self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
+        span = root_tracer.start_span(
+            "BaseChat.stream_call", metadata=self._get_span_metadata(payload)
+        )
+        payload["span_id"] = span.span_id
         try:
             from pilot.model.cluster import WorkerManagerFactory
@@ -199,6 +220,7 @@
                 self.current_message.add_ai_message(msg)
                 view_msg = self.knowledge_reference_call(msg)
                 self.current_message.add_view_message(view_msg)
+            span.end()
         except Exception as e:
             print(traceback.format_exc())
             logger.error("model response parase faild!" + str(e))
@@ -206,12 +228,17 @@
             self.current_message.add_view_message(
                 f"""ERROR!{str(e)}\n {ai_response_text} """
             )
             ### store current conversation
+            span.end(metadata={"error": str(e)})
         self.memory.append(self.current_message)

     async def nostream_call(self):
         payload = await self.__call_base()
         logger.info(f"Request: \n{payload}")
         ai_response_text = ""
+        span = root_tracer.start_span(
+            "BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
+        )
+        payload["span_id"] = span.span_id
         try:
             from pilot.model.cluster import WorkerManagerFactory
@@ -219,7 +246,8 @@
                 ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
             ).create()

-            model_output = await worker_manager.generate(payload)
+            with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
+                model_output = await worker_manager.generate(payload)

             ### output parse
             ai_response_text = (
@@ -234,11 +262,18 @@
                     ai_response_text
                 )
             )
-            ### run
-            # result = self.do_action(prompt_define_response)
-            result = await blocking_func_to_async(
-                self._executor, self.do_action, prompt_define_response
-            )
+            metadata = {
+                "model_output": model_output.to_dict(),
+                "ai_response_text": ai_response_text,
+                "prompt_define_response": self._parse_prompt_define_response(
+                    prompt_define_response
+                ),
+            }
+            with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
+                ### run
+                result = await blocking_func_to_async(
+                    self._executor, self.do_action, prompt_define_response
+                )

             ### llm speaker
             speak_to_user = self.get_llm_speak(prompt_define_response)
@@ -255,12 +290,14 @@
                 view_message = view_message.replace("\n", "\\n")
             self.current_message.add_view_message(view_message)
+            span.end()
         except Exception as e:
             print(traceback.format_exc())
             logger.error("model response parase faild!" + str(e))
             self.current_message.add_view_message(
                 f"""ERROR!{str(e)}\n {ai_response_text} """
             )
+            span.end(metadata={"error": str(e)})
         ### store dialogue
         self.memory.append(self.current_message)
         return self.current_ai_response()
@@ -513,3 +550,21 @@ def generate(self, p) -> str:
         """
         pass
+
+    def _parse_prompt_define_response(self, prompt_define_response: Any) -> Any:
+        if not prompt_define_response:
+            return ""
+        if isinstance(prompt_define_response, str) or isinstance(
+            prompt_define_response, dict
+        ):
+            return prompt_define_response
+        if isinstance(prompt_define_response, tuple):
+            if hasattr(prompt_define_response, "_asdict"):
+                # namedtuple
+                return prompt_define_response._asdict()
+            else:
+                return dict(
+                    zip(range(len(prompt_define_response)), prompt_define_response)
+                )
+        else:
+            return prompt_define_response
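Note on _parse_prompt_define_response above: span metadata should be made of plain, JSON-friendly values, so the parser output is normalized before being attached to the "BaseChat.do_action" span. A minimal sketch of the two tuple branches, assuming nothing beyond the standard library (SqlAction here mirrors the NamedTuple extended in out_parser.py later in this diff):

from typing import Dict, NamedTuple

class SqlAction(NamedTuple):
    sql: str
    thoughts: Dict

action = SqlAction(sql="SELECT 1", thoughts={"plan": "probe"})
# A namedtuple exposes _asdict, so it becomes a dict of its fields.
assert hasattr(action, "_asdict")
assert action._asdict() == {"sql": "SELECT 1", "thoughts": {"plan": "probe"}}
# A plain tuple has no _asdict and becomes a dict keyed by position.
assert dict(zip(range(2), ("a", "b"))) == {0: "a", 1: "b"}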
diff --git a/pilot/scene/chat_agent/chat.py b/pilot/scene/chat_agent/chat.py
index d9a8f60c1..81af1b3b1 100644
--- a/pilot/scene/chat_agent/chat.py
+++ b/pilot/scene/chat_agent/chat.py
@@ -11,6 +11,7 @@
 from .prompt import prompt
 from pilot.component import ComponentType
 from pilot.base_modules.agent.controller import ModuleAgent
+from pilot.utils.tracer import root_tracer, trace

 CFG = Config()
@@ -51,6 +52,7 @@ def __init__(self, chat_param: Dict):
         self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)

+    @trace()
     async def generate_input_values(self) -> Dict[str, str]:
         input_values = {
             "user_goal": self.current_user_input,
@@ -63,7 +65,10 @@

     def stream_plugin_call(self, text):
         text = text.replace("\n", " ")
-        return self.api_call.run(text)
+        with root_tracer.start_span(
+            "ChatAgent.stream_plugin_call.api_call", metadata={"text": text}
+        ):
+            return self.api_call.run(text)

     def __list_to_prompt_str(self, list: List) -> str:
         return "\n".join(f"{i + 1}. {item}" for i, item in enumerate(list))
diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py
index 211aa7c04..6771fb3fc 100644
--- a/pilot/scene/chat_dashboard/chat.py
+++ b/pilot/scene/chat_dashboard/chat.py
@@ -13,6 +13,7 @@
 from pilot.scene.chat_dashboard.prompt import prompt
 from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
 from pilot.utils.executor_utils import blocking_func_to_async
+from pilot.utils.tracer import root_tracer, trace

 CFG = Config()
@@ -53,6 +54,7 @@ def __load_dashboard_template(self, template_name):
             data = f.read()
             return json.loads(data)

+    @trace()
     async def generate_input_values(self) -> Dict:
         try:
             from pilot.summary.db_summary_client import DBSummaryClient
diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
index 064e7586c..fefc8142c 100644
--- a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
+++ b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
@@ -14,6 +14,7 @@
 from pilot.common.path_utils import has_path
 from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
 from pilot.base_modules.agent.common.schema import Status
+from pilot.utils.tracer import root_tracer, trace

 CFG = Config()
@@ -62,6 +63,7 @@ def _generate_numbered_list(self) -> str:
         # ]
         return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))

+    @trace()
     async def generate_input_values(self) -> Dict:
         input_values = {
             "user_input": self.current_user_input,
@@ -88,4 +90,9 @@ async def prepare(self):

     def stream_plugin_call(self, text):
         text = text.replace("\n", " ")
-        return self.api_call.run_display_sql(text, self.excel_reader.get_df_by_sql_ex)
+        with root_tracer.start_span(
+            "ChatExcel.stream_plugin_call.run_display_sql", metadata={"text": text}
+        ):
+            return self.api_call.run_display_sql(
+                text, self.excel_reader.get_df_by_sql_ex
+            )
diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py
index f05221eba..7d1730ad0 100644
--- a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py
+++ b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py
@@ -13,6 +13,7 @@
 from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
 from pilot.json_utils.utilities import DateTimeEncoder
 from pilot.utils.executor_utils import blocking_func_to_async
+from pilot.utils.tracer import root_tracer, trace

 CFG = Config()
@@ -44,6 +45,7 @@ def __init__(
         if parent_mode:
             self.current_message.chat_mode = parent_mode.value()

+    @trace()
     async def generate_input_values(self) -> Dict:
         # colunms, datas = self.excel_reader.get_sample_data()
         colunms, datas = await blocking_func_to_async(
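The bare @trace() decorators added above rely on name inference rather than an explicit operation name. A hedged sketch of what that buys, assuming the resolution logic in tracer_impl.py at the end of this diff; the demo class below is hypothetical, and a real run also assumes the tracer has been initialized during app startup:

from pilot.utils.tracer import trace

class ChatAgentDemo:  # hypothetical stand-in for ChatAgent
    @trace()
    async def generate_input_values(self):
        return {"user_goal": "..."}

# The decorator sees a coroutine function, so it returns the async wrapper,
# and because the first parameter is self, _parse_operation_name yields
# "ChatAgentDemo.generate_input_values" as the span's operation name.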
diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py
index d9b901772..4d4bf3c0c 100644
--- a/pilot/scene/chat_db/auto_execute/chat.py
+++ b/pilot/scene/chat_db/auto_execute/chat.py
@@ -6,6 +6,7 @@
 from pilot.configs.config import Config
 from pilot.scene.chat_db.auto_execute.prompt import prompt
 from pilot.utils.executor_utils import blocking_func_to_async
+from pilot.utils.tracer import root_tracer, trace

 CFG = Config()
@@ -35,10 +36,13 @@ def __init__(self, chat_param: Dict):
             raise ValueError(
                 f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
             )
-
-        self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
+        with root_tracer.start_span(
+            "ChatWithDbAutoExecute.get_connect", metadata={"db_name": self.db_name}
+        ):
+            self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
         self.top_k: int = 200

+    @trace()
     async def generate_input_values(self) -> Dict:
         """
         generate input values
@@ -55,13 +59,14 @@
             #     query=self.current_user_input,
             #     topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
             # )
-            table_infos = await blocking_func_to_async(
-                self._executor,
-                client.get_db_summary,
-                self.db_name,
-                self.current_user_input,
-                CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
-            )
+            with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"):
+                table_infos = await blocking_func_to_async(
+                    self._executor,
+                    client.get_db_summary,
+                    self.db_name,
+                    self.current_user_input,
+                    CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
+                )
         except Exception as e:
             print("db summary find error!" + str(e))
         if not table_infos:
@@ -80,4 +85,8 @@

     def do_action(self, prompt_response):
         print(f"do_action:{prompt_response}")
-        return self.database.run(prompt_response.sql)
+        with root_tracer.start_span(
+            "ChatWithDbAutoExecute.do_action.run_sql",
+            metadata=prompt_response.to_dict(),
+        ):
+            return self.database.run(prompt_response.sql)
diff --git a/pilot/scene/chat_db/auto_execute/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py
index 577cac1ef..e583d945a 100644
--- a/pilot/scene/chat_db/auto_execute/out_parser.py
+++ b/pilot/scene/chat_db/auto_execute/out_parser.py
@@ -12,6 +12,9 @@ class SqlAction(NamedTuple):
     sql: str
     thoughts: Dict

+    def to_dict(self) -> Dict[str, Dict]:
+        return {"sql": self.sql, "thoughts": self.thoughts}
+

 logger = logging.getLogger(__name__)
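SqlAction gains to_dict() because span metadata is expected to be a plain dict. A usage sketch of the do_action instrumentation, assuming an initialized tracer; the SQL text and thoughts below are illustrative only:

from pilot.scene.chat_db.auto_execute.out_parser import SqlAction
from pilot.utils.tracer import root_tracer

action = SqlAction(sql="SELECT COUNT(*) FROM users", thoughts={"reason": "count rows"})
with root_tracer.start_span(
    "ChatWithDbAutoExecute.do_action.run_sql", metadata=action.to_dict()
):
    pass  # in do_action, self.database.run(action.sql) executes inside the span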
{item}" for i, item in enumerate(command_strings)) + @trace() async def generate_input_values(self) -> Dict: input_values = { "user_input": self.current_user_input, @@ -88,4 +90,9 @@ async def prepare(self): def stream_plugin_call(self, text): text = text.replace("\n", " ") - return self.api_call.run_display_sql(text, self.excel_reader.get_df_by_sql_ex) + with root_tracer.start_span( + "ChatExcel.stream_plugin_call.run_display_sql", metadata={"text": text} + ): + return self.api_call.run_display_sql( + text, self.excel_reader.get_df_by_sql_ex + ) diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py index f05221eba..7d1730ad0 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py @@ -13,6 +13,7 @@ from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader from pilot.json_utils.utilities import DateTimeEncoder from pilot.utils.executor_utils import blocking_func_to_async +from pilot.utils.tracer import root_tracer, trace CFG = Config() @@ -44,6 +45,7 @@ def __init__( if parent_mode: self.current_message.chat_mode = parent_mode.value() + @trace() async def generate_input_values(self) -> Dict: # colunms, datas = self.excel_reader.get_sample_data() colunms, datas = await blocking_func_to_async( diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index d9b901772..4d4bf3c0c 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -6,6 +6,7 @@ from pilot.configs.config import Config from pilot.scene.chat_db.auto_execute.prompt import prompt from pilot.utils.executor_utils import blocking_func_to_async +from pilot.utils.tracer import root_tracer, trace CFG = Config() @@ -35,10 +36,13 @@ def __init__(self, chat_param: Dict): raise ValueError( f"{ChatScene.ChatWithDbExecute.value} mode should chose db!" ) - - self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) + with root_tracer.start_span( + "ChatWithDbAutoExecute.get_connect", metadata={"db_name": self.db_name} + ): + self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) self.top_k: int = 200 + @trace() async def generate_input_values(self) -> Dict: """ generate input values @@ -55,13 +59,14 @@ async def generate_input_values(self) -> Dict: # query=self.current_user_input, # topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE, # ) - table_infos = await blocking_func_to_async( - self._executor, - client.get_db_summary, - self.db_name, - self.current_user_input, - CFG.KNOWLEDGE_SEARCH_TOP_SIZE, - ) + with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"): + table_infos = await blocking_func_to_async( + self._executor, + client.get_db_summary, + self.db_name, + self.current_user_input, + CFG.KNOWLEDGE_SEARCH_TOP_SIZE, + ) except Exception as e: print("db summary find error!" 
diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py
index d57b32b25..a9c63b268 100644
--- a/pilot/scene/chat_knowledge/v1/chat.py
+++ b/pilot/scene/chat_knowledge/v1/chat.py
@@ -15,6 +15,7 @@
 from pilot.scene.chat_knowledge.v1.prompt import prompt
 from pilot.server.knowledge.service import KnowledgeService
 from pilot.utils.executor_utils import blocking_func_to_async
+from pilot.utils.tracer import root_tracer, trace

 CFG = Config()
@@ -92,6 +93,7 @@ def knowledge_reference_call(self, text):
         """return reference"""
         return text + f"\n\n{self.parse_source_view(self.sources)}"

+    @trace()
     async def generate_input_values(self) -> Dict:
         if self.space_context:
             self.prompt_template.template_define = self.space_context["prompt"]["scene"]
diff --git a/pilot/scene/chat_normal/chat.py b/pilot/scene/chat_normal/chat.py
index 5999d5c3c..0191ef943 100644
--- a/pilot/scene/chat_normal/chat.py
+++ b/pilot/scene/chat_normal/chat.py
@@ -5,6 +5,7 @@

 from pilot.configs.config import Config
 from pilot.scene.chat_normal.prompt import prompt
+from pilot.utils.tracer import root_tracer, trace

 CFG = Config()
@@ -21,6 +22,7 @@ def __init__(self, chat_param: Dict):
             chat_param=chat_param,
         )

+    @trace()
     async def generate_input_values(self) -> Dict:
         input_values = {"input": self.current_user_input}
         return input_values
diff --git a/pilot/utils/executor_utils.py b/pilot/utils/executor_utils.py
index 2aac0d04d..26ee3c66e 100644
--- a/pilot/utils/executor_utils.py
+++ b/pilot/utils/executor_utils.py
@@ -1,5 +1,6 @@
 from typing import Callable, Awaitable, Any
 import asyncio
+import contextvars
 from abc import ABC, abstractmethod
 from concurrent.futures import Executor, ThreadPoolExecutor
 from functools import partial
@@ -55,6 +56,12 @@ async def blocking_func_to_async(
     """
     if asyncio.iscoroutinefunction(func):
         raise ValueError(f"The function {func} is not blocking function")
+
+    # This function will be called within the new thread, capturing the current context
+    ctx = contextvars.copy_context()
+
+    def run_with_context():
+        return ctx.run(partial(func, *args, **kwargs))
+
     loop = asyncio.get_event_loop()
-    sync_function_noargs = partial(func, *args, **kwargs)
-    return await loop.run_in_executor(executor, sync_function_noargs)
+    return await loop.run_in_executor(executor, run_with_context)
diff --git a/pilot/utils/tracer/__init__.py b/pilot/utils/tracer/__init__.py
index 16509ff43..cdb536f79 100644
--- a/pilot/utils/tracer/__init__.py
+++ b/pilot/utils/tracer/__init__.py
@@ -10,6 +10,7 @@
 from pilot.utils.tracer.span_storage import MemorySpanStorage, FileSpanStorage
 from pilot.utils.tracer.tracer_impl import (
     root_tracer,
+    trace,
     initialize_tracer,
     DefaultTracer,
     TracerManager,
@@ -26,6 +27,7 @@
     "MemorySpanStorage",
     "FileSpanStorage",
     "root_tracer",
+    "trace",
     "initialize_tracer",
     "DefaultTracer",
     "TracerManager",
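The contextvars change in blocking_func_to_async is what keeps spans parented correctly across thread hops: the tracer stores the current span id in a ContextVar, and copy_context() snapshots the caller's context before the blocking function is shipped to the executor, so ctx.run restores it inside the worker thread. A self-contained sketch under that assumption (request_id is an illustrative variable, not part of the tracer):

import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar

from pilot.utils.executor_utils import blocking_func_to_async

request_id: ContextVar[str] = ContextVar("request_id", default="unset")

def blocking_work() -> str:
    # Runs in an executor thread; without ctx.run it would see "unset".
    return request_id.get()

async def main():
    request_id.set("trace-123")
    executor = ThreadPoolExecutor(max_workers=1)
    value = await blocking_func_to_async(executor, blocking_work)
    assert value == "trace-123"  # context propagated across the thread hop

asyncio.run(main())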
ai_response_text", - split_string_by_terminal_width( - metadata.get("ai_response_text"), split=split_long_text - ), - ] - ) - if "prompt_define_response" in metadata: - table.add_row( - [ - "BaseChat prompt_define_response", - split_string_by_terminal_width( - metadata.get("prompt_define_response"), - split=split_long_text, - ), - ] + if op == "BaseChat.do_action" and not sp["end_time"]: + if "model_output" in metadata: + table.add_row( + [ + "BaseChat model_output", + split_string_by_terminal_width( + metadata.get("model_output").get("text"), + split=split_long_text, + ), + ] + ) + if "ai_response_text" in metadata: + table.add_row( + [ + "BaseChat ai_response_text", + split_string_by_terminal_width( + metadata.get("ai_response_text"), split=split_long_text + ), + ] + ) + if "prompt_define_response" in metadata: + prompt_define_response = metadata.get("prompt_define_response") or "" + if isinstance(prompt_define_response, dict) or isinstance( + prompt_define_response, type([]) + ): + prompt_define_response = json.dumps( + prompt_define_response, ensure_ascii=False ) + table.add_row( + [ + "BaseChat prompt_define_response", + split_string_by_terminal_width( + prompt_define_response, + split=split_long_text, + ), + ] + ) if op == "DefaultModelWorker_call.generate_stream_func": if not sp["end_time"]: table.add_row(["llm_adapter", metadata.get("llm_adapter")]) diff --git a/pilot/utils/tracer/tracer_impl.py b/pilot/utils/tracer/tracer_impl.py index bda25ab4d..2358863bf 100644 --- a/pilot/utils/tracer/tracer_impl.py +++ b/pilot/utils/tracer/tracer_impl.py @@ -1,6 +1,9 @@ from typing import Dict, Optional from contextvars import ContextVar from functools import wraps +import asyncio +import inspect + from pilot.component import SystemApp, ComponentType from pilot.utils.tracer.base import ( @@ -154,18 +157,42 @@ def get_current_span_id(self) -> Optional[str]: root_tracer: TracerManager = TracerManager() -def trace(operation_name: str, **trace_kwargs): +def trace(operation_name: Optional[str] = None, **trace_kwargs): def decorator(func): @wraps(func) - async def wrapper(*args, **kwargs): - with root_tracer.start_span(operation_name, **trace_kwargs): + def sync_wrapper(*args, **kwargs): + name = ( + operation_name if operation_name else _parse_operation_name(func, *args) + ) + with root_tracer.start_span(name, **trace_kwargs): + return func(*args, **kwargs) + + @wraps(func) + async def async_wrapper(*args, **kwargs): + name = ( + operation_name if operation_name else _parse_operation_name(func, *args) + ) + with root_tracer.start_span(name, **trace_kwargs): return await func(*args, **kwargs) - return wrapper + if asyncio.iscoroutinefunction(func): + return async_wrapper + else: + return sync_wrapper return decorator +def _parse_operation_name(func, *args): + self_name = None + if inspect.signature(func).parameters.get("self"): + self_name = args[0].__class__.__name__ + func_name = func.__name__ + if self_name: + return f"{self_name}.{func_name}" + return func_name + + def initialize_tracer( system_app: SystemApp, tracer_filename: str,