feat(core): Enhance server request processing performance (#722)
Close #720 
**Others**: 
- Fix chat tracer no spans bug
- Modify AutoDL setup script
Aries-ckt authored Oct 24, 2023
2 parents 185d436 + e5e4f54 commit 96a4867
Showing 25 changed files with 201 additions and 67 deletions.
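The thread running through these changes: FastAPI handlers are `async def`, so any synchronous call inside them (prompt-template resolution, output parsing, vector-store queries) stalls the event loop and every other in-flight request with it. The commit moves that work onto a shared thread-pool executor. A minimal, self-contained sketch of the pattern (names here are illustrative, not from this codebase):

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial

executor = ThreadPoolExecutor(max_workers=4)


def blocking_work(query: str) -> str:
    # Stands in for a synchronous call such as a vector-store lookup.
    time.sleep(1.0)
    return f"result for {query}"


async def handler(query: str) -> str:
    loop = asyncio.get_running_loop()
    # The worker thread blocks; the event loop keeps serving other requests.
    return await loop.run_in_executor(executor, partial(blocking_work, query))


print(asyncio.run(handler("example")))
```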
@@ -1,7 +1,6 @@
from pandas import DataFrame

from pilot.base_modules.agent.commands.command_mange import command
from pilot.configs.config import Config
import pandas as pd
import uuid
import os
2 changes: 1 addition & 1 deletion pilot/base_modules/agent/hub/agent_hub.py
@@ -12,7 +12,7 @@
from ..common.schema import PluginStorageType
from ..plugins_util import scan_plugins, update_from_git

logger = logging.getLogger("agent_hub")
logger = logging.getLogger(__name__)
Default_User = "default"
DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git"
TEMP_PLUGIN_PATH = ""
4 changes: 3 additions & 1 deletion pilot/base_modules/agent/plugins_util.py
@@ -9,6 +9,7 @@
import git
import threading
import datetime
import logging
from pathlib import Path
from typing import List
from urllib.parse import urlparse
@@ -19,7 +20,8 @@

from pilot.configs.config import Config
from pilot.configs.model_config import PLUGINS_DIR
from pilot.logs import logger

logger = logging.getLogger(__name__)


def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
2 changes: 1 addition & 1 deletion pilot/base_modules/meta_data/meta_data.py
@@ -20,7 +20,7 @@
from pilot.configs.config import Config


logger = logging.getLogger("meta_data")
logger = logging.getLogger(__name__)

CFG = Config()
default_db_path = os.path.join(os.getcwd(), "meta_data")
3 changes: 2 additions & 1 deletion pilot/logs.py
@@ -249,7 +249,8 @@ def remove_color_codes(s: str) -> str:
return ansi_escape.sub("", s)


logger: Logger = Logger()
# Remove current logger
# logger: Logger = Logger()


def print_assistant_thoughts(
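Retiring the module-level `Logger` singleton in `pilot.logs` goes hand in hand with the `logging.getLogger(__name__)` swaps above: each module now owns a stdlib logger named after its import path, and handler/level configuration happens once at startup. A minimal sketch of that convention (the `setup_logging` hook is hypothetical; DB-GPT's real configuration lives elsewhere):

```python
import logging

# Named after the module, e.g. "pilot.base_modules.agent.hub.agent_hub",
# so handlers and levels can be tuned per package.
logger = logging.getLogger(__name__)


def setup_logging() -> None:
    # Hypothetical one-time startup hook.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(name)s %(levelname)s %(message)s",
    )
```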
11 changes: 10 additions & 1 deletion pilot/memory/chat_history/chat_hisotry_factory.py
@@ -1,5 +1,6 @@
from .base import MemoryStoreType
from pilot.configs.config import Config
from pilot.memory.chat_history.base import BaseChatHistoryMemory

CFG = Config()

@@ -18,7 +19,15 @@ def __init__(self):
self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory
self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory

def get_store_instance(self, chat_session_id):
def get_store_instance(self, chat_session_id: str) -> BaseChatHistoryMemory:
"""New store instance for store chat histories
Args:
chat_session_id (str): conversation session id
Returns:
BaseChatHistoryMemory: Store instance
"""
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(
chat_session_id
)
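With the annotation in place, callers get a typed handle on whichever backend `CFG.CHAT_HISTORY_STORE_TYPE` selects from the class map. A usage sketch, assuming only the methods visible in this diff (`delete` appears in `dialogue_delete` below):

```python
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory

history_factory = ChatHistory()
# The store class is looked up by CFG.CHAT_HISTORY_STORE_TYPE and
# instantiated for this conversation id.
store = history_factory.get_store_instance("some-conversation-uid")
store.delete()  # drop this conversation's history
```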
2 changes: 1 addition & 1 deletion pilot/memory/chat_history/store_type/meta_db_history.py
@@ -14,7 +14,7 @@
from pilot.memory.chat_history.base import MemoryStoreType

CFG = Config()
logger = logging.getLogger("db_chat_history")
logger = logging.getLogger(__name__)


class DbHistoryMemory(BaseChatHistoryMemory):
43 changes: 35 additions & 8 deletions pilot/openapi/api_v1/api_v1.py
@@ -19,6 +19,7 @@
from fastapi.exceptions import RequestValidationError
from typing import List
import tempfile
from concurrent.futures import Executor

from pilot.component import ComponentType
from pilot.openapi.api_view_model import (
@@ -46,6 +47,8 @@
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from pilot.model.base import FlatSupportedModel
from pilot.utils.tracer import root_tracer, SpanType
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async

router = APIRouter()
CFG = Config()
@@ -129,6 +132,13 @@ def get_worker_manager() -> WorkerManager:
return worker_manager


def get_executor() -> Executor:
"""Get the global default executor"""
return CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()


@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
async def db_connect_list():
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
@@ -158,6 +168,7 @@ async def async_db_summary_embedding(db_name, db_type):
@router.post("/v1/chat/db/test/connect", response_model=Result[bool])
async def test_connect(db_config: DBConfig = Body()):
try:
# TODO Change the synchronous call to the asynchronous call
CFG.LOCAL_DB_MANAGE.test_connect(db_config)
return Result.succ(True)
except Exception as e:
@@ -166,6 +177,7 @@ async def test_connect(db_config: DBConfig = Body()):

@router.post("/v1/chat/db/summary", response_model=Result[bool])
async def db_summary(db_name: str, db_type: str):
# TODO Change the synchronous call to the asynchronous call
async_db_summary_embedding(db_name, db_type)
return Result.succ(True)

@@ -185,6 +197,7 @@ async def db_support_types():
async def dialogue_list(user_id: str = None):
dialogues: List = []
chat_history_service = ChatHistory()
# TODO Change the synchronous call to the asynchronous call
datas = chat_history_service.get_store_cls().conv_list(user_id)
for item in datas:
conv_uid = item.get("conv_uid")
@@ -285,7 +298,7 @@ async def params_load(
select_param=doc_file.filename,
model_name=model_name,
)
chat: BaseChat = get_chat_instance(dialogue)
chat: BaseChat = await get_chat_instance(dialogue)
resp = await chat.prepare()

### refresh messages
@@ -299,6 +312,7 @@ async def params_load(
async def dialogue_delete(con_uid: str):
history_fac = ChatHistory()
history_mem = history_fac.get_store_instance(con_uid)
# TODO Change the synchronous call to the asynchronous call
history_mem.delete()
return Result.succ(None)

@@ -324,10 +338,11 @@ def get_hist_messages(conv_uid: str):
@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
async def dialogue_history_messages(con_uid: str):
print(f"dialogue_history_messages:{con_uid}")
# TODO Change the synchronous call to the asynchronous call
return Result.succ(get_hist_messages(con_uid))


def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
logger.info(f"get_chat_instance:{dialogue}")
if not dialogue.chat_mode:
dialogue.chat_mode = ChatScene.ChatNormal.value()
@@ -346,8 +361,14 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
"select_param": dialogue.select_param,
"model_name": dialogue.model_name,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(
dialogue.chat_mode, **{"chat_param": chat_param}
# chat: BaseChat = CHAT_FACTORY.get_implementation(
# dialogue.chat_mode, **{"chat_param": chat_param}
# )
chat: BaseChat = await blocking_func_to_async(
get_executor(),
CHAT_FACTORY.get_implementation,
dialogue.chat_mode,
**{"chat_param": chat_param},
)
return chat

@@ -357,7 +378,7 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
# dialogue.model_name = CFG.LLM_MODEL
logger.info(f"chat_prepare:{dialogue}")
## check conv_uid
chat: BaseChat = get_chat_instance(dialogue)
chat: BaseChat = await get_chat_instance(dialogue)
if len(chat.history_message) > 0:
return Result.succ(None)
resp = await chat.prepare()
@@ -369,7 +390,10 @@ async def chat_completions(dialogue: ConversationVo = Body()):
print(
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
)
chat: BaseChat = get_chat_instance(dialogue)
with root_tracer.start_span(
"get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
):
chat: BaseChat = await get_chat_instance(dialogue)
# background_tasks = BackgroundTasks()
# background_tasks.add_task(release_model_semaphore)
headers = {
@@ -420,8 +444,9 @@ async def model_supports(worker_manager: WorkerManager = Depends(get_worker_manager)):


async def no_stream_generator(chat):
msg = await chat.nostream_call()
yield f"data: {msg}\n\n"
with root_tracer.start_span("no_stream_generator"):
msg = await chat.nostream_call()
yield f"data: {msg}\n\n"


async def stream_generator(chat, incremental: bool, model_name: str):
@@ -438,6 +463,7 @@ async def stream_generator(chat, incremental: bool, model_name: str):
Yields:
_type_: streaming responses
"""
span = root_tracer.start_span("stream_generator")
msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."

stream_id = f"chatcmpl-{str(uuid.uuid1())}"
@@ -463,6 +489,7 @@
await asyncio.sleep(0.02)
if incremental:
yield "data: [DONE]\n\n"
span.end()


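Note the asymmetry with `no_stream_generator` above: there the span is a context manager, while `stream_generator` starts the span eagerly and ends it by hand after the final chunk. A `try/finally` variant would also close the span if the client disconnects mid-stream (a sketch reusing only the `start_span`/`end` calls visible in this diff):

```python
# Sketch: manual span management in an async generator.
span = root_tracer.start_span("stream_generator")
try:
    async for chunk in chat.stream_call():
        yield chunk
finally:
    # Runs even if the consumer stops iterating early.
    span.end()
```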
def message2Vo(message: dict, order, model_name) -> MessageVo:
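`blocking_func_to_async` is the workhorse imported at the top of this file. Its implementation isn't shown in the diff; a plausible sketch, assuming it simply binds the arguments and defers to `run_in_executor`:

```python
import asyncio
from concurrent.futures import Executor
from functools import partial
from typing import Any, Callable


async def blocking_func_to_async(
    executor: Executor, func: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
    """Run a blocking function in the given executor without stalling the loop."""
    if asyncio.iscoroutinefunction(func):
        raise ValueError(f"{func} is a coroutine function, not a blocking one")
    loop = asyncio.get_running_loop()
    # partial binds both positional and keyword arguments up front.
    return await loop.run_in_executor(executor, partial(func, *args, **kwargs))
```

This is why `get_chat_instance` had to become `async def` and why every call site gained an `await`: the chat factory may import scene modules and build prompt templates, which is exactly the kind of work that should not run on the event loop.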
39 changes: 30 additions & 9 deletions pilot/scene/base_chat.py
@@ -12,6 +12,7 @@
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.scene.message import OnceConversation
from pilot.utils import get_or_create_event_loop
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
from pydantic import Extra
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory

@@ -80,6 +81,10 @@ def __init__(self, chat_param: Dict):
self.current_message.param_type = self.chat_mode.param_types()[0]
self.current_message.param_value = chat_param["select_param"]
self.current_tokens_used: int = 0
# The executor to submit blocking function
self._executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()

class Config:
"""Configuration for this pydantic object."""
@@ -92,8 +97,14 @@ def chat_type(self) -> str:
raise NotImplementedError("Not supported for this chat type.")

@abstractmethod
def generate_input_values(self):
pass
async def generate_input_values(self) -> Dict:
"""Generate input to LLM
Please note that you must not perform any blocking operations in this function
Returns:
a dictionary to be formatted by prompt template
"""

def do_action(self, prompt_response):
return prompt_response
@@ -116,8 +127,8 @@ def get_llm_speak(self, prompt_define_response):
speak_to_user = prompt_define_response
return speak_to_user

def __call_base(self):
input_values = self.generate_input_values()
async def __call_base(self):
input_values = await self.generate_input_values()
### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input)
@@ -159,7 +170,7 @@ async def check_iterator_end(iterator):

async def stream_call(self):
# TODO Retry when server connection error
payload = self.__call_base()
payload = await self.__call_base()

self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Requert: \n{payload}")
@@ -190,7 +201,7 @@ async def stream_call(self):
self.memory.append(self.current_message)

async def nostream_call(self):
payload = self.__call_base()
payload = await self.__call_base()
logger.info(f"Request: \n{payload}")
ai_response_text = ""
try:
@@ -216,14 +227,24 @@ async def nostream_call(self):
)
)
### run
result = self.do_action(prompt_define_response)
# result = self.do_action(prompt_define_response)
result = await blocking_func_to_async(
self._executor, self.do_action, prompt_define_response
)

### llm speaker
speak_to_user = self.get_llm_speak(prompt_define_response)

view_message = self.prompt_template.output_parser.parse_view_response(
speak_to_user, result
# view_message = self.prompt_template.output_parser.parse_view_response(
# speak_to_user, result
# )
view_message = await blocking_func_to_async(
self._executor,
self.prompt_template.output_parser.parse_view_response,
speak_to_user,
result,
)

view_message = view_message.replace("\n", "\\n")
self.current_message.add_view_message(view_message)
except Exception as e:
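The abstract `generate_input_values` becoming a coroutine changes the contract for every scene: blocking lookups must now go through `self._executor`. A hypothetical subclass following the new contract (mirroring what `chat_dashboard` does below):

```python
from typing import Dict

from pilot.utils.executor_utils import blocking_func_to_async


class MyChat(BaseChat):  # hypothetical scene, for illustration only
    async def generate_input_values(self) -> Dict:
        # Offload the synchronous lookup so the event loop stays responsive.
        table_info = await blocking_func_to_async(
            self._executor, self._load_table_info
        )
        return {"user_input": self.current_user_input, "table_info": table_info}

    def _load_table_info(self) -> str:
        # Hypothetical blocking helper, e.g. a database or vector-store query.
        return "table schema ..."
```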
2 changes: 1 addition & 1 deletion pilot/scene/chat_agent/chat.py
@@ -51,7 +51,7 @@ def __init__(self, chat_param: Dict):

self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)

def generate_input_values(self):
async def generate_input_values(self) -> Dict[str, str]:
input_values = {
"user_goal": self.current_user_input,
"expand_constraints": self.__list_to_prompt_str(
5 changes: 0 additions & 5 deletions pilot/scene/chat_agent/out_parser.py
@@ -1,11 +1,6 @@
import json
from typing import Dict, NamedTuple
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR


logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")


class PluginAction(NamedTuple):
14 changes: 11 additions & 3 deletions pilot/scene/chat_dashboard/chat.py
@@ -12,6 +12,7 @@
)
from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
from pilot.utils.executor_utils import blocking_func_to_async

CFG = Config()

@@ -52,17 +53,24 @@ def __load_dashboard_template(self, template_name):
data = f.read()
return json.loads(data)

def generate_input_values(self):
async def generate_input_values(self) -> Dict:
try:
from pilot.summary.db_summary_client import DBSummaryClient
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")

client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try:
table_infos = client.get_similar_tables(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
table_infos = await blocking_func_to_async(
self._executor,
client.get_similar_tables,
self.db_name,
self.current_user_input,
self.top_k,
)
# table_infos = client.get_similar_tables(
# dbname=self.db_name, query=self.current_user_input, topk=self.top_k
# )
print("dashboard vector find tables:{}", table_infos)
except Exception as e:
print("db summary find error!" + str(e))
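One subtlety in the rewritten call: the commented-out original passed `dbname=`, `query=`, and `topk=` as keywords, while the wrapped call passes them positionally, so their order must match `get_similar_tables`'s signature. If the helper forwards `**kwargs` (as in the sketch after the `api_v1.py` diff), the keyword form could be kept:

```python
# Hypothetical keyword-preserving variant, assuming blocking_func_to_async
# forwards **kwargs to the wrapped function.
table_infos = await blocking_func_to_async(
    self._executor,
    client.get_similar_tables,
    dbname=self.db_name,
    query=self.current_user_input,
    topk=self.top_k,
)
```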
2 changes: 1 addition & 1 deletion pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
@@ -62,7 +62,7 @@ def _generate_numbered_list(self) -> str:
# ]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))

def generate_input_values(self):
async def generate_input_values(self) -> Dict:
input_values = {
"user_input": self.current_user_input,
"table_name": self.excel_reader.table_name,
(The diff viewer did not load the remaining 12 changed files.)
