diff --git a/assets/schema/knowledge_management.sql b/assets/schema/knowledge_management.sql
index e38f731d6..a6f1bc478 100644
--- a/assets/schema/knowledge_management.sql
+++ b/assets/schema/knowledge_management.sql
@@ -34,6 +34,7 @@ CREATE TABLE `knowledge_document` (
`content` LONGTEXT NOT NULL COMMENT 'knowledge content',
`result` TEXT NULL COMMENT 'knowledge embedding sync result',
`vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
+ `summary` LONGTEXT NULL COMMENT 'knowledge summary',
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),
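For databases created from an earlier version of this schema, the new column can be added in place. A minimal migration sketch, assuming the table already exists as defined above:

    ALTER TABLE `knowledge_document`
        ADD COLUMN `summary` LONGTEXT NULL COMMENT 'knowledge summary';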
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 34f294c31..4529d3cb4 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -13,6 +13,7 @@
from pilot.scene.message import OnceConversation
from pilot.utils import get_or_create_event_loop
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
+from pilot.utils.tracer import root_tracer, trace
from pydantic import Extra
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
@@ -38,6 +39,7 @@ class Config:
arbitrary_types_allowed = True
+ @trace("BaseChat.__init__")
def __init__(self, chat_param: Dict):
"""Chat Module Initialization
Args:
@@ -143,7 +145,14 @@ async def __call_base(self):
)
self.current_message.tokens = 0
if self.prompt_template.template:
- current_prompt = self.prompt_template.format(**input_values)
+ metadata = {
+ "template_scene": self.prompt_template.template_scene,
+ "input_values": input_values,
+ }
+ with root_tracer.start_span(
+ "BaseChat.__call_base.prompt_template.format", metadata=metadata
+ ):
+ current_prompt = self.prompt_template.format(**input_values)
self.current_message.add_system_message(current_prompt)
llm_messages = self.generate_llm_messages()
@@ -175,6 +184,14 @@ async def check_iterator_end(iterator):
except StopAsyncIteration:
return True # the iterator has finished
+ def _get_span_metadata(self, payload: Dict) -> Dict:
+ metadata = dict(payload)
+ # exclude the raw prompt from span metadata
+ metadata.pop("prompt", None)
+ metadata["messages"] = [
+ m if isinstance(m, dict) else m.dict() for m in metadata["messages"]
+ ]
+ return metadata
+
async def stream_call(self):
# TODO Retry when server connection error
payload = await self.__call_base()
@@ -182,6 +199,10 @@ async def stream_call(self):
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
+ span = root_tracer.start_span(
+ "BaseChat.stream_call", metadata=self._get_span_metadata(payload)
+ )
+ payload["span_id"] = span.span_id
try:
from pilot.model.cluster import WorkerManagerFactory
@@ -199,6 +220,7 @@ async def stream_call(self):
self.current_message.add_ai_message(msg)
view_msg = self.knowledge_reference_call(msg)
self.current_message.add_view_message(view_msg)
+ span.end()
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild!" + str(e))
@@ -206,12 +228,17 @@ async def stream_call(self):
f"""ERROR!{str(e)}\n {ai_response_text} """
)
### store current conversation
+ span.end(metadata={"error": str(e)})
self.memory.append(self.current_message)
async def nostream_call(self):
payload = await self.__call_base()
logger.info(f"Request: \n{payload}")
ai_response_text = ""
+ span = root_tracer.start_span(
+ "BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
+ )
+ payload["span_id"] = span.span_id
try:
from pilot.model.cluster import WorkerManagerFactory
@@ -219,7 +246,8 @@ async def nostream_call(self):
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
- model_output = await worker_manager.generate(payload)
+ with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
+ model_output = await worker_manager.generate(payload)
### output parse
ai_response_text = (
@@ -234,11 +262,18 @@ async def nostream_call(self):
ai_response_text
)
)
- ### run
- # result = self.do_action(prompt_define_response)
- result = await blocking_func_to_async(
- self._executor, self.do_action, prompt_define_response
- )
+ metadata = {
+ "model_output": model_output.to_dict(),
+ "ai_response_text": ai_response_text,
+ "prompt_define_response": self._parse_prompt_define_response(
+ prompt_define_response
+ ),
+ }
+ with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
+ ### run
+ result = await blocking_func_to_async(
+ self._executor, self.do_action, prompt_define_response
+ )
### llm speaker
speak_to_user = self.get_llm_speak(prompt_define_response)
@@ -255,12 +290,14 @@ async def nostream_call(self):
view_message = view_message.replace("\n", "\\n")
self.current_message.add_view_message(view_message)
+ span.end()
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild!" + str(e))
self.current_message.add_view_message(
f"""ERROR!{str(e)}\n {ai_response_text} """
)
+ span.end(metadata={"error": str(e)})
### store dialogue
self.memory.append(self.current_message)
return self.current_ai_response()
@@ -513,3 +550,21 @@ def generate(self, p) -> str:
"""
pass
+
+ def _parse_prompt_define_response(self, prompt_define_response: Any) -> Any:
+ if not prompt_define_response:
+ return ""
+ if isinstance(prompt_define_response, (str, dict)):
+ return prompt_define_response
+ if isinstance(prompt_define_response, tuple):
+ if hasattr(prompt_define_response, "_asdict"):
+ # namedtuple
+ return prompt_define_response._asdict()
+ else:
+ # plain tuple: key items by their position
+ return dict(enumerate(prompt_define_response))
+ else:
+ return prompt_define_response
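A quick sketch of what `_parse_prompt_define_response` returns for the tuple cases, which keeps span metadata JSON-serializable; `chat` stands for any BaseChat subclass instance and the values are hypothetical:

    from collections import namedtuple

    Action = namedtuple("Action", ["sql", "thoughts"])

    # namedtuple: converted via _asdict()
    assert chat._parse_prompt_define_response(Action("SELECT 1", {})) == {
        "sql": "SELECT 1",
        "thoughts": {},
    }
    # plain tuple: keyed by position
    assert chat._parse_prompt_define_response(("a", "b")) == {0: "a", 1: "b"}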
diff --git a/pilot/scene/chat_agent/chat.py b/pilot/scene/chat_agent/chat.py
index d9a8f60c1..81af1b3b1 100644
--- a/pilot/scene/chat_agent/chat.py
+++ b/pilot/scene/chat_agent/chat.py
@@ -11,6 +11,7 @@
from .prompt import prompt
from pilot.component import ComponentType
from pilot.base_modules.agent.controller import ModuleAgent
+from pilot.utils.tracer import root_tracer, trace
CFG = Config()
@@ -51,6 +52,7 @@ def __init__(self, chat_param: Dict):
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
+ @trace()
async def generate_input_values(self) -> Dict[str, str]:
input_values = {
"user_goal": self.current_user_input,
@@ -63,7 +65,10 @@ async def generate_input_values(self) -> Dict[str, str]:
def stream_plugin_call(self, text):
text = text.replace("\n", " ")
- return self.api_call.run(text)
+ with root_tracer.start_span(
+ "ChatAgent.stream_plugin_call.api_call", metadata={"text": text}
+ ):
+ return self.api_call.run(text)
def __list_to_prompt_str(self, items: List) -> str:
return "\n".join(f"{i + 1}. {item}" for i, item in enumerate(items))
diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py
index 211aa7c04..6771fb3fc 100644
--- a/pilot/scene/chat_dashboard/chat.py
+++ b/pilot/scene/chat_dashboard/chat.py
@@ -13,6 +13,7 @@
from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
from pilot.utils.executor_utils import blocking_func_to_async
+from pilot.utils.tracer import root_tracer, trace
CFG = Config()
@@ -53,6 +54,7 @@ def __load_dashboard_template(self, template_name):
data = f.read()
return json.loads(data)
+ @trace()
async def generate_input_values(self) -> Dict:
try:
from pilot.summary.db_summary_client import DBSummaryClient
diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
index 064e7586c..fefc8142c 100644
--- a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
+++ b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
@@ -14,6 +14,7 @@
from pilot.common.path_utils import has_path
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.base_modules.agent.common.schema import Status
+from pilot.utils.tracer import root_tracer, trace
CFG = Config()
@@ -62,6 +63,7 @@ def _generate_numbered_list(self) -> str:
# ]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
+ @trace()
async def generate_input_values(self) -> Dict:
input_values = {
"user_input": self.current_user_input,
@@ -88,4 +90,9 @@ async def prepare(self):
def stream_plugin_call(self, text):
text = text.replace("\n", " ")
- return self.api_call.run_display_sql(text, self.excel_reader.get_df_by_sql_ex)
+ with root_tracer.start_span(
+ "ChatExcel.stream_plugin_call.run_display_sql", metadata={"text": text}
+ ):
+ return self.api_call.run_display_sql(
+ text, self.excel_reader.get_df_by_sql_ex
+ )
diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py
index f05221eba..7d1730ad0 100644
--- a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py
+++ b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py
@@ -13,6 +13,7 @@
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
from pilot.json_utils.utilities import DateTimeEncoder
from pilot.utils.executor_utils import blocking_func_to_async
+from pilot.utils.tracer import root_tracer, trace
CFG = Config()
@@ -44,6 +45,7 @@ def __init__(
if parent_mode:
self.current_message.chat_mode = parent_mode.value()
+ @trace()
async def generate_input_values(self) -> Dict:
# colunms, datas = self.excel_reader.get_sample_data()
colunms, datas = await blocking_func_to_async(
diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py
index d9b901772..4d4bf3c0c 100644
--- a/pilot/scene/chat_db/auto_execute/chat.py
+++ b/pilot/scene/chat_db/auto_execute/chat.py
@@ -6,6 +6,7 @@
from pilot.configs.config import Config
from pilot.scene.chat_db.auto_execute.prompt import prompt
from pilot.utils.executor_utils import blocking_func_to_async
+from pilot.utils.tracer import root_tracer, trace
CFG = Config()
@@ -35,10 +36,13 @@ def __init__(self, chat_param: Dict):
raise ValueError(
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
)
-
- self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
+ with root_tracer.start_span(
+ "ChatWithDbAutoExecute.get_connect", metadata={"db_name": self.db_name}
+ ):
+ self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.top_k: int = 200
+ @trace()
async def generate_input_values(self) -> Dict:
"""
generate input values
@@ -55,13 +59,14 @@ async def generate_input_values(self) -> Dict:
# query=self.current_user_input,
# topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
# )
- table_infos = await blocking_func_to_async(
- self._executor,
- client.get_db_summary,
- self.db_name,
- self.current_user_input,
- CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
- )
+ with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"):
+ table_infos = await blocking_func_to_async(
+ self._executor,
+ client.get_db_summary,
+ self.db_name,
+ self.current_user_input,
+ CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
+ )
except Exception as e:
print("db summary find error!" + str(e))
if not table_infos:
@@ -80,4 +85,8 @@ async def generate_input_values(self) -> Dict:
def do_action(self, prompt_response):
print(f"do_action:{prompt_response}")
- return self.database.run(prompt_response.sql)
+ with root_tracer.start_span(
+ "ChatWithDbAutoExecute.do_action.run_sql",
+ metadata=prompt_response.to_dict(),
+ ):
+ return self.database.run(prompt_response.sql)
diff --git a/pilot/scene/chat_db/auto_execute/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py
index 577cac1ef..e583d945a 100644
--- a/pilot/scene/chat_db/auto_execute/out_parser.py
+++ b/pilot/scene/chat_db/auto_execute/out_parser.py
@@ -12,6 +12,9 @@ class SqlAction(NamedTuple):
sql: str
thoughts: Dict
+ def to_dict(self) -> Dict[str, Dict]:
+ return {"sql": self.sql, "thoughts": self.thoughts}
+
logger = logging.getLogger(__name__)
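The new `to_dict` keeps span metadata JSON-friendly, and `ChatWithDbAutoExecute.do_action` above passes it straight to `root_tracer.start_span`. A usage sketch with a hypothetical SQL value:

    action = SqlAction(sql="SELECT COUNT(*) FROM user", thoughts={"plan": "count rows"})
    with root_tracer.start_span(
        "ChatWithDbAutoExecute.do_action.run_sql", metadata=action.to_dict()
    ):
        pass  # run the SQL here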
diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py
index 5ae76d37d..fde28d91b 100644
--- a/pilot/scene/chat_db/professional_qa/chat.py
+++ b/pilot/scene/chat_db/professional_qa/chat.py
@@ -6,6 +6,7 @@
from pilot.configs.config import Config
from pilot.scene.chat_db.professional_qa.prompt import prompt
from pilot.utils.executor_utils import blocking_func_to_async
+from pilot.utils.tracer import root_tracer, trace
CFG = Config()
@@ -39,6 +40,7 @@ def __init__(self, chat_param: Dict):
else len(self.tables)
)
+ @trace()
async def generate_input_values(self) -> Dict:
table_info = ""
dialect = "mysql"
diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py
index bdd78d7b7..2615918ff 100644
--- a/pilot/scene/chat_execution/chat.py
+++ b/pilot/scene/chat_execution/chat.py
@@ -6,6 +6,7 @@
from pilot.base_modules.agent.commands.command import execute_command
from pilot.base_modules.agent import PluginPromptGenerator
from .prompt import prompt
+from pilot.utils.tracer import root_tracer, trace
CFG = Config()
@@ -50,6 +51,7 @@ def __init__(self, chat_param: Dict):
self.plugins_prompt_generator
)
+ @trace()
async def generate_input_values(self) -> Dict:
input_values = {
"input": self.current_user_input,
diff --git a/pilot/scene/chat_knowledge/inner_db_summary/chat.py b/pilot/scene/chat_knowledge/inner_db_summary/chat.py
index 07a64aea9..f7c81bd77 100644
--- a/pilot/scene/chat_knowledge/inner_db_summary/chat.py
+++ b/pilot/scene/chat_knowledge/inner_db_summary/chat.py
@@ -4,6 +4,7 @@
from pilot.configs.config import Config
from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt
+from pilot.utils.tracer import root_tracer, trace
CFG = Config()
@@ -31,6 +32,7 @@ def __init__(
self.db_input = db_select
self.db_summary = db_summary
+ @trace()
async def generate_input_values(self) -> Dict:
input_values = {
"db_input": self.db_input,
diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py
index d57b32b25..a9c63b268 100644
--- a/pilot/scene/chat_knowledge/v1/chat.py
+++ b/pilot/scene/chat_knowledge/v1/chat.py
@@ -15,6 +15,7 @@
from pilot.scene.chat_knowledge.v1.prompt import prompt
from pilot.server.knowledge.service import KnowledgeService
from pilot.utils.executor_utils import blocking_func_to_async
+from pilot.utils.tracer import root_tracer, trace
CFG = Config()
@@ -92,6 +93,7 @@ def knowledge_reference_call(self, text):
"""return reference"""
return text + f"\n\n{self.parse_source_view(self.sources)}"
+ @trace()
async def generate_input_values(self) -> Dict:
if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
diff --git a/pilot/scene/chat_normal/chat.py b/pilot/scene/chat_normal/chat.py
index 5999d5c3c..0191ef943 100644
--- a/pilot/scene/chat_normal/chat.py
+++ b/pilot/scene/chat_normal/chat.py
@@ -5,6 +5,7 @@
from pilot.configs.config import Config
from pilot.scene.chat_normal.prompt import prompt
+from pilot.utils.tracer import root_tracer, trace
CFG = Config()
@@ -21,6 +22,7 @@ def __init__(self, chat_param: Dict):
chat_param=chat_param,
)
+ @trace()
async def generate_input_values(self) -> Dict:
input_values = {"input": self.current_user_input}
return input_values
diff --git a/pilot/utils/executor_utils.py b/pilot/utils/executor_utils.py
index 2aac0d04d..26ee3c66e 100644
--- a/pilot/utils/executor_utils.py
+++ b/pilot/utils/executor_utils.py
@@ -1,5 +1,6 @@
from typing import Callable, Awaitable, Any
import asyncio
+import contextvars
from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor
from functools import partial
@@ -55,6 +56,12 @@ async def blocking_func_to_async(
"""
if asyncio.iscoroutinefunction(func):
raise ValueError(f"The function {func} is not blocking function")
+
+ # Capture the caller's context here; run_with_context restores it in the executor thread
+ ctx = contextvars.copy_context()
+
+ def run_with_context():
+ return ctx.run(partial(func, *args, **kwargs))
+
loop = asyncio.get_event_loop()
- sync_function_noargs = partial(func, *args, **kwargs)
- return await loop.run_in_executor(executor, sync_function_noargs)
+ return await loop.run_in_executor(executor, run_with_context)
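Why the context copy matters: ContextVar values, such as the tracer's current span id, do not propagate into executor threads on their own, so spans started inside the blocking function would otherwise lose their parent. A minimal, self-contained sketch of the mechanism (names hypothetical):

    import asyncio
    import contextvars
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    current_span_id = contextvars.ContextVar("current_span_id", default=None)

    def blocking_work():
        # Runs in a worker thread, yet sees the value set in the coroutine
        return current_span_id.get()

    async def main():
        current_span_id.set("span-123")
        ctx = contextvars.copy_context()
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor() as executor:
            value = await loop.run_in_executor(executor, partial(ctx.run, blocking_work))
        assert value == "span-123"

    asyncio.run(main())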
diff --git a/pilot/utils/tracer/__init__.py b/pilot/utils/tracer/__init__.py
index 16509ff43..cdb536f79 100644
--- a/pilot/utils/tracer/__init__.py
+++ b/pilot/utils/tracer/__init__.py
@@ -10,6 +10,7 @@
from pilot.utils.tracer.span_storage import MemorySpanStorage, FileSpanStorage
from pilot.utils.tracer.tracer_impl import (
root_tracer,
+ trace,
initialize_tracer,
DefaultTracer,
TracerManager,
@@ -26,6 +27,7 @@
"MemorySpanStorage",
"FileSpanStorage",
"root_tracer",
+ "trace",
"initialize_tracer",
"DefaultTracer",
"TracerManager",
diff --git a/pilot/utils/tracer/tracer_cli.py b/pilot/utils/tracer/tracer_cli.py
index 7df18f516..3fb9cba31 100644
--- a/pilot/utils/tracer/tracer_cli.py
+++ b/pilot/utils/tracer/tracer_cli.py
@@ -303,8 +303,6 @@ def chat(
print(table.get_formatted_string(out_format=output, **out_kwargs))
if sys_table:
print(sys_table.get_formatted_string(out_format=output, **out_kwargs))
- if hide_conv:
- return
if not found_trace_id:
print(f"Can't found conversation with trace_id: {trace_id}")
@@ -315,9 +313,12 @@ def chat(
trace_spans = [s for s in reversed(trace_spans)]
hierarchy = _build_trace_hierarchy(trace_spans)
if tree:
- print("\nInvoke Trace Tree:\n")
+ print(f"\nInvoke Trace Tree(trace_id: {trace_id}):\n")
_print_trace_hierarchy(hierarchy)
+ if hide_conv:
+ return
+
trace_spans = _get_ordered_trace_from(hierarchy)
table = PrettyTable(["Key", "Value Value"], title="Chat Trace Details")
split_long_text = output == "text"
@@ -340,36 +341,43 @@ def chat(
table.add_row(["echo", metadata.get("echo")])
elif "error" in metadata:
table.add_row(["BaseChat Error", metadata.get("error")])
- if op == "BaseChat.nostream_call" and not sp["end_time"]:
- if "model_output" in metadata:
- table.add_row(
- [
- "BaseChat model_output",
- split_string_by_terminal_width(
- metadata.get("model_output").get("text"),
- split=split_long_text,
- ),
- ]
- )
- if "ai_response_text" in metadata:
- table.add_row(
- [
- "BaseChat ai_response_text",
- split_string_by_terminal_width(
- metadata.get("ai_response_text"), split=split_long_text
- ),
- ]
- )
- if "prompt_define_response" in metadata:
- table.add_row(
- [
- "BaseChat prompt_define_response",
- split_string_by_terminal_width(
- metadata.get("prompt_define_response"),
- split=split_long_text,
- ),
- ]
+ if op == "BaseChat.do_action" and not sp["end_time"]:
+ if "model_output" in metadata:
+ table.add_row(
+ [
+ "BaseChat model_output",
+ split_string_by_terminal_width(
+ metadata.get("model_output").get("text"),
+ split=split_long_text,
+ ),
+ ]
+ )
+ if "ai_response_text" in metadata:
+ table.add_row(
+ [
+ "BaseChat ai_response_text",
+ split_string_by_terminal_width(
+ metadata.get("ai_response_text"), split=split_long_text
+ ),
+ ]
+ )
+ if "prompt_define_response" in metadata:
+ prompt_define_response = metadata.get("prompt_define_response") or ""
+ if isinstance(prompt_define_response, (dict, list)):
+ prompt_define_response = json.dumps(
+ prompt_define_response, ensure_ascii=False
+ )
+ table.add_row(
+ [
+ "BaseChat prompt_define_response",
+ split_string_by_terminal_width(
+ prompt_define_response,
+ split=split_long_text,
+ ),
+ ]
+ )
if op == "DefaultModelWorker_call.generate_stream_func":
if not sp["end_time"]:
table.add_row(["llm_adapter", metadata.get("llm_adapter")])
diff --git a/pilot/utils/tracer/tracer_impl.py b/pilot/utils/tracer/tracer_impl.py
index bda25ab4d..2358863bf 100644
--- a/pilot/utils/tracer/tracer_impl.py
+++ b/pilot/utils/tracer/tracer_impl.py
@@ -1,6 +1,9 @@
from typing import Dict, Optional
from contextvars import ContextVar
from functools import wraps
+import asyncio
+import inspect
+
from pilot.component import SystemApp, ComponentType
from pilot.utils.tracer.base import (
@@ -154,18 +157,42 @@ def get_current_span_id(self) -> Optional[str]:
root_tracer: TracerManager = TracerManager()
-def trace(operation_name: str, **trace_kwargs):
+def trace(operation_name: Optional[str] = None, **trace_kwargs):
def decorator(func):
@wraps(func)
- async def wrapper(*args, **kwargs):
- with root_tracer.start_span(operation_name, **trace_kwargs):
+ def sync_wrapper(*args, **kwargs):
+ name = (
+ operation_name if operation_name else _parse_operation_name(func, *args)
+ )
+ with root_tracer.start_span(name, **trace_kwargs):
+ return func(*args, **kwargs)
+
+ @wraps(func)
+ async def async_wrapper(*args, **kwargs):
+ name = (
+ operation_name if operation_name else _parse_operation_name(func, *args)
+ )
+ with root_tracer.start_span(name, **trace_kwargs):
return await func(*args, **kwargs)
- return wrapper
+ if asyncio.iscoroutinefunction(func):
+ return async_wrapper
+ else:
+ return sync_wrapper
return decorator
+def _parse_operation_name(func, *args):
+ self_name = None
+ if inspect.signature(func).parameters.get("self"):
+ self_name = args[0].__class__.__name__
+ func_name = func.__name__
+ if self_name:
+ return f"{self_name}.{func_name}"
+ return func_name
+
+
def initialize_tracer(
system_app: SystemApp,
tracer_filename: str,
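Usage sketch for the updated decorator, assuming the tracer has been initialized via initialize_tracer: with no argument the span name is derived from the class and method, and both sync and async callables are now supported (class and method names hypothetical):

    from pilot.utils.tracer import trace

    class MyChat:
        @trace()  # span name derived as "MyChat.generate_input_values"
        async def generate_input_values(self):
            return {"input": "hello"}

        @trace("MyChat.custom_operation")  # explicit name, sync function
        def do_action(self, prompt_response):
            return prompt_response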