diff --git a/.gitignore b/.gitignore index 15beea989..7f80cceba 100644 --- a/.gitignore +++ b/.gitignore @@ -179,4 +179,8 @@ thirdparty # typescript *.tsbuildinfo -/web/next-env.d.ts \ No newline at end of file +/web/next-env.d.ts + +# Ignore awel DAG visualization files +/examples/**/*.gv +/examples/**/*.gv.pdf \ No newline at end of file diff --git a/Makefile b/Makefile index 74d40075d..2e5f4bff4 100644 --- a/Makefile +++ b/Makefile @@ -67,6 +67,11 @@ pre-commit: fmt test ## Run formatting and unit tests before committing test: $(VENV)/.testenv ## Run unit tests $(VENV_BIN)/pytest dbgpt +.PHONY: test-doc +test-doc: $(VENV)/.testenv ## Run doctests + # -k "not test_" skips tests that are not doctests. + $(VENV_BIN)/pytest --doctest-modules -k "not test_" dbgpt/core + .PHONY: coverage coverage: setup ## Run tests and report coverage $(VENV_BIN)/pytest dbgpt --cov=dbgpt diff --git a/dbgpt/app/scene/base_chat.py b/dbgpt/app/scene/base_chat.py index a5e9fb26a..776c460cc 100644 --- a/dbgpt/app/scene/base_chat.py +++ b/dbgpt/app/scene/base_chat.py @@ -102,6 +102,11 @@ def __init__(self, chat_param: Dict): is_stream=True, dag_name="llm_stream_model_dag" ) + # Get the message version, default is v1 in app + # In v1, we will transform the message to compatible format of specific model + # In the future, we will upgrade the message version to v2, and the message will be compatible with all models + self._message_version = chat_param.get("message_version", "v1") + class Config: """Configuration for this pydantic object.""" @@ -185,6 +190,7 @@ async def __call_base(self): "temperature": float(self.prompt_template.temperature), "max_new_tokens": int(self.prompt_template.max_new_tokens), "echo": self.llm_echo, + "version": self._message_version, } return payload diff --git a/dbgpt/core/__init__.py b/dbgpt/core/__init__.py index b8c787af2..6cee5890e 100644 --- a/dbgpt/core/__init__.py +++ b/dbgpt/core/__init__.py @@ -6,7 +6,10 @@ CacheValue, ) from dbgpt.core.interface.llm import ( + DefaultMessageConverter, LLMClient, + MessageConverter, + ModelExtraMedata, ModelInferenceMetrics, ModelMetadata, ModelOutput, @@ -14,19 +17,28 @@ ModelRequestContext, ) from dbgpt.core.interface.message import ( + AIMessage, + BaseMessage, ConversationIdentifier, + HumanMessage, MessageIdentifier, MessageStorageItem, ModelMessage, ModelMessageRoleType, OnceConversation, StorageConversation, + SystemMessage, ) from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser from dbgpt.core.interface.prompt import ( + BasePromptTemplate, + ChatPromptTemplate, + HumanPromptTemplate, + MessagesPlaceholder, PromptManager, PromptTemplate, StoragePromptTemplate, + SystemPromptTemplate, ) from dbgpt.core.interface.serialization import Serializable, Serializer from dbgpt.core.interface.storage import ( @@ -49,14 +61,26 @@ "ModelMessage", "LLMClient", "ModelMessageRoleType", + "ModelExtraMedata", + "MessageConverter", + "DefaultMessageConverter", "OnceConversation", "StorageConversation", + "BaseMessage", + "SystemMessage", + "AIMessage", + "HumanMessage", "MessageStorageItem", "ConversationIdentifier", "MessageIdentifier", "PromptTemplate", "PromptManager", "StoragePromptTemplate", + "BasePromptTemplate", + "ChatPromptTemplate", + "MessagesPlaceholder", + "SystemPromptTemplate", + "HumanPromptTemplate", "BaseOutputParser", "SQLOutputParser", "Serializable", diff --git a/dbgpt/core/awel/__init__.py b/dbgpt/core/awel/__init__.py index 2fcd657cb..01c91153c 100644 --- a/dbgpt/core/awel/__init__.py +++ 
b/dbgpt/core/awel/__init__.py @@ -7,6 +7,7 @@ """ +import logging from typing import List, Optional from dbgpt.component import SystemApp @@ -39,6 +40,8 @@ ) from .trigger.http_trigger import HttpTrigger +logger = logging.getLogger(__name__) + __all__ = [ "initialize_awel", "DAGContext", @@ -89,14 +92,24 @@ def initialize_awel(system_app: SystemApp, dag_dirs: List[str]): def setup_dev_environment( dags: List[DAG], - host: Optional[str] = "0.0.0.0", + host: Optional[str] = "127.0.0.1", port: Optional[int] = 5555, logging_level: Optional[str] = None, logger_filename: Optional[str] = None, + show_dag_graph: Optional[bool] = True, ) -> None: """Setup a development environment for AWEL. Just using in development environment, not production environment. + + Args: + dags (List[DAG]): The DAGs. + host (Optional[str], optional): The host. Defaults to "127.0.0.1". + port (Optional[int], optional): The port. Defaults to 5555. + logging_level (Optional[str], optional): The logging level. Defaults to None. + logger_filename (Optional[str], optional): The logger filename. Defaults to None. + show_dag_graph (Optional[bool], optional): Whether to show the DAG graph. Defaults to True. + If True, the DAG graph will be saved to a file and opened automatically. """ import uvicorn from fastapi import FastAPI @@ -118,6 +131,15 @@ def setup_dev_environment( system_app.register_instance(trigger_manager) for dag in dags: + if show_dag_graph: + try: + dag_graph_file = dag.visualize_dag() + if dag_graph_file: + logger.info(f"Visualize DAG {str(dag)} to {dag_graph_file}") + except Exception as e: + logger.warning( + f"Visualize DAG {str(dag)} failed: {e}. If graphviz is not installed on your system, install it with `pip install graphviz` or `sudo apt install graphviz`" + ) for trigger in dag.trigger_nodes: trigger_manager.register_trigger(trigger) trigger_manager.after_register() diff --git a/dbgpt/core/awel/dag/base.py b/dbgpt/core/awel/dag/base.py index 07b1db049..70977d536 100644 --- a/dbgpt/core/awel/dag/base.py +++ b/dbgpt/core/awel/dag/base.py @@ -6,8 +6,7 @@ from abc import ABC, abstractmethod from collections import deque from concurrent.futures import Executor -from functools import cache -from typing import Any, Dict, List, Optional, Sequence, Set, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union from dbgpt.component import SystemApp @@ -177,7 +176,10 @@ async def before_dag_run(self): pass async def after_dag_end(self): - """The callback after DAG end""" + """The callback after DAG end. + + This method may be called multiple times, so make sure it is idempotent.
+ """ pass @@ -299,6 +301,20 @@ def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> Non self._downstream.append(node) node._upstream.append(self) + def __repr__(self): + cls_name = self.__class__.__name__ + if self.node_name and self.node_name: + return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})" + if self.node_id: + return f"{cls_name}(node_id={self.node_id})" + if self.node_name: + return f"{cls_name}(node_name={self.node_name})" + else: + return f"{cls_name}" + + def __str__(self): + return self.__repr__() + def _build_task_key(task_name: str, key: str) -> str: return f"{task_name}___$$$$$$___{key}" @@ -496,6 +512,15 @@ async def _after_dag_end(self) -> None: tasks.append(node.after_dag_end()) await asyncio.gather(*tasks) + def print_tree(self) -> None: + """Print the DAG tree""" + _print_format_dag_tree(self) + + def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]: + """Create the DAG graph""" + self.print_tree() + return _visualize_dag(self, view=view, **kwargs) + def __enter__(self): DAGVar.enter_dag(self) return self @@ -516,3 +541,109 @@ def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode for node in stream_nodes: nodes = nodes.union(_get_nodes(node, is_upstream)) return nodes + + +def _print_format_dag_tree(dag: DAG) -> None: + for node in dag.root_nodes: + _print_dag(node) + + +def _print_dag( + node: DAGNode, + level: int = 0, + prefix: str = "", + last: bool = True, + level_dict: Dict[str, Any] = None, +): + if level_dict is None: + level_dict = {} + + connector = " -> " if level != 0 else "" + new_prefix = prefix + if last: + if level != 0: + new_prefix += " " + print(prefix + connector + str(node)) + else: + if level != 0: + new_prefix += "| " + print(prefix + connector + str(node)) + + level_dict[level] = level_dict.get(level, 0) + 1 + num_children = len(node.downstream) + for i, child in enumerate(node.downstream): + _print_dag(child, level + 1, new_prefix, i == num_children - 1, level_dict) + + +def _print_dag_tree(root_nodes: List[DAGNode], level_sep: str = " ") -> None: + def _print_node(node: DAGNode, level: int) -> None: + print(f"{level_sep * level}{node}") + + _apply_root_node(root_nodes, _print_node) + + +def _apply_root_node( + root_nodes: List[DAGNode], + func: Callable[[DAGNode, int], None], +) -> None: + for dag_node in root_nodes: + _handle_dag_nodes(False, 0, dag_node, func) + + +def _handle_dag_nodes( + is_down_to_up: bool, + level: int, + dag_node: DAGNode, + func: Callable[[DAGNode, int], None], +): + if not dag_node: + return + func(dag_node, level) + stream_nodes = dag_node.upstream if is_down_to_up else dag_node.downstream + level += 1 + for node in stream_nodes: + _handle_dag_nodes(is_down_to_up, level, node, func) + + +def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]: + """Visualize the DAG + + Args: + dag (DAG): The DAG to visualize + view (bool, optional): Whether view the DAG graph. Defaults to True. 
+ + Returns: + Optional[str]: The filename of the DAG graph + """ + try: + from graphviz import Digraph + except ImportError: + logger.warn("Can't import graphviz, skip visualize DAG") + return None + + dot = Digraph(name=dag.dag_id) + # Record the added edges to avoid adding duplicate edges + added_edges = set() + + def add_edges(node: DAGNode): + if node.downstream: + for downstream_node in node.downstream: + # Check if the edge has been added + if (str(node), str(downstream_node)) not in added_edges: + dot.edge(str(node), str(downstream_node)) + added_edges.add((str(node), str(downstream_node))) + add_edges(downstream_node) + + for root in dag.root_nodes: + add_edges(root) + filename = f"dag-vis-{dag.dag_id}.gv" + if "filename" in kwargs: + filename = kwargs["filename"] + del kwargs["filename"] + + if not "directory" in kwargs: + from dbgpt.configs.model_config import LOGDIR + + kwargs["directory"] = LOGDIR + + return dot.render(filename, view=view, **kwargs) diff --git a/dbgpt/core/awel/operator/base.py b/dbgpt/core/awel/operator/base.py index c07eabd14..8fa3e6905 100644 --- a/dbgpt/core/awel/operator/base.py +++ b/dbgpt/core/awel/operator/base.py @@ -46,6 +46,7 @@ async def execute_workflow( node: "BaseOperator", call_data: Optional[CALL_DATA] = None, streaming_call: bool = False, + dag_ctx: Optional[DAGContext] = None, ) -> DAGContext: """Execute the workflow starting from a given operator. @@ -53,7 +54,7 @@ async def execute_workflow( node (RunnableDAGNode): The starting node of the workflow to be executed. call_data (CALL_DATA): The data pass to root operator node. streaming_call (bool): Whether the call is a streaming call. - + dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None. Returns: DAGContext: The context after executing the workflow, containing the final state and data. """ @@ -174,18 +175,22 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: TaskOutput[OUT]: The task output after this node has been run. """ - async def call(self, call_data: Optional[CALL_DATA] = None) -> OUT: + async def call( + self, + call_data: Optional[CALL_DATA] = None, + dag_ctx: Optional[DAGContext] = None, + ) -> OUT: """Execute the node and return the output. This method is a high-level wrapper for executing the node. Args: call_data (CALL_DATA): The data pass to root operator node. - + dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None. Returns: OUT: The output of the node after execution. """ - out_ctx = await self._runner.execute_workflow(self, call_data) + out_ctx = await self._runner.execute_workflow(self, call_data, dag_ctx=dag_ctx) return out_ctx.current_task_context.task_output.output def _blocking_call( @@ -209,7 +214,9 @@ def _blocking_call( return loop.run_until_complete(self.call(call_data)) async def call_stream( - self, call_data: Optional[CALL_DATA] = None + self, + call_data: Optional[CALL_DATA] = None, + dag_ctx: Optional[DAGContext] = None, ) -> AsyncIterator[OUT]: """Execute the node and return the output as a stream. @@ -217,12 +224,13 @@ async def call_stream( Args: call_data (CALL_DATA): The data pass to root operator node. + dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None. Returns: AsyncIterator[OUT]: An asynchronous iterator over the output stream. 
""" out_ctx = await self._runner.execute_workflow( - self, call_data, streaming_call=True + self, call_data, streaming_call=True, dag_ctx=dag_ctx ) return out_ctx.current_task_context.task_output.output_stream diff --git a/dbgpt/core/awel/runner/local_runner.py b/dbgpt/core/awel/runner/local_runner.py index 8ad3f417c..680b6f974 100644 --- a/dbgpt/core/awel/runner/local_runner.py +++ b/dbgpt/core/awel/runner/local_runner.py @@ -19,17 +19,21 @@ async def execute_workflow( node: BaseOperator, call_data: Optional[CALL_DATA] = None, streaming_call: bool = False, + dag_ctx: Optional[DAGContext] = None, ) -> DAGContext: # Save node output # dag = node.dag - node_outputs: Dict[str, TaskContext] = {} job_manager = JobManager.build_from_end_node(node, call_data) - # Create DAG context - dag_ctx = DAGContext( - streaming_call=streaming_call, - node_to_outputs=node_outputs, - node_name_to_ids=job_manager._node_name_to_ids, - ) + if not dag_ctx: + # Create DAG context + node_outputs: Dict[str, TaskContext] = {} + dag_ctx = DAGContext( + streaming_call=streaming_call, + node_to_outputs=node_outputs, + node_name_to_ids=job_manager._node_name_to_ids, + ) + else: + node_outputs = dag_ctx._node_to_outputs logger.info( f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}" ) diff --git a/dbgpt/core/interface/llm.py b/dbgpt/core/interface/llm.py index 8ba40b333..7b705dd77 100644 --- a/dbgpt/core/interface/llm.py +++ b/dbgpt/core/interface/llm.py @@ -1,14 +1,21 @@ +import collections import copy +import logging import time from abc import ABC, abstractmethod from dataclasses import asdict, dataclass, field from typing import Any, AsyncIterator, Dict, List, Optional, Union +from cachetools import TTLCache + +from dbgpt._private.pydantic import BaseModel from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType from dbgpt.util import BaseParameters from dbgpt.util.annotations import PublicAPI from dbgpt.util.model_utils import GPUInfo +logger = logging.getLogger(__name__) + @dataclass @PublicAPI(stability="beta") @@ -223,6 +230,29 @@ def get_single_user_message(self) -> Optional[ModelMessage]: raise ValueError("The messages is not a single user message") return messages[0] + @staticmethod + def build_request( + model: str, + messages: List[ModelMessage], + context: Union[ModelRequestContext, Dict[str, Any], BaseModel], + stream: Optional[bool] = False, + **kwargs, + ): + context_dict = None + if isinstance(context, dict): + context_dict = context + elif isinstance(context, BaseModel): + context_dict = context.dict() + if context_dict and "stream" not in context_dict: + context_dict["stream"] = stream + context = ModelRequestContext(**context_dict) + return ModelRequest( + model=model, + messages=messages, + context=context, + **kwargs, + ) + @staticmethod def _build(model: str, prompt: str, **kwargs): return ModelRequest( @@ -271,6 +301,43 @@ def to_openai_messages(self) -> List[Dict[str, Any]]: return ModelMessage.to_openai_messages(messages) +@dataclass +class ModelExtraMedata(BaseParameters): + """A class to represent the extra metadata of a LLM.""" + + prompt_roles: Optional[List[str]] = field( + default_factory=lambda: [ + ModelMessageRoleType.SYSTEM, + ModelMessageRoleType.HUMAN, + ModelMessageRoleType.AI, + ], + metadata={"help": "The roles of the prompt"}, + ) + + prompt_sep: Optional[str] = field( + default="\n", + metadata={"help": "The separator of the prompt between multiple rounds"}, + ) + + # You can see the chat template in your model repo tokenizer 
config, + # typically in the tokenizer_config.json + prompt_chat_template: Optional[str] = field( + default=None, + metadata={ + "help": "The chat template, see: https://huggingface.co/docs/transformers/main/en/chat_templating" + }, + ) + + @property + def support_system_message(self) -> bool: + """Whether the model supports system message. + + Returns: + bool: Whether the model supports system message. + """ + return ModelMessageRoleType.SYSTEM in self.prompt_roles + + @dataclass @PublicAPI(stability="beta") class ModelMetadata(BaseParameters): @@ -295,18 +362,294 @@ class ModelMetadata(BaseParameters): default_factory=dict, metadata={"help": "Model metadata"}, ) + ext_metadata: Optional[ModelExtraMedata] = field( + default_factory=ModelExtraMedata, + metadata={"help": "Model extra metadata"}, + ) + + @classmethod + def from_dict( + cls, data: dict, ignore_extra_fields: bool = False + ) -> "ModelMetadata": + if "ext_metadata" in data: + data["ext_metadata"] = ModelExtraMedata(**data["ext_metadata"]) + return cls(**data) + + +class MessageConverter(ABC): + """An abstract class for message converter. + + Different LLMs may have different message formats, this class is used to convert the messages + to the format of the LLM. + + Examples: + + >>> from typing import List + >>> from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType + >>> from dbgpt.core.interface.llm import MessageConverter, ModelMetadata + >>> class RemoveSystemMessageConverter(MessageConverter): + ... def convert( + ... self, + ... messages: List[ModelMessage], + ... model_metadata: Optional[ModelMetadata] = None, + ... ) -> List[ModelMessage]: + ... # Convert the messages, merge system messages to the last user message. + ... system_message = None + ... other_messages = [] + ... sep = "\\n" + ... for message in messages: + ... if message.role == ModelMessageRoleType.SYSTEM: + ... system_message = message + ... else: + ... other_messages.append(message) + ... if system_message and other_messages: + ... other_messages[-1].content = ( + ... system_message.content + sep + other_messages[-1].content + ... ) + ... return other_messages + ... + >>> messages = [ + ... ModelMessage( + ... role=ModelMessageRoleType.SYSTEM, + ... content="You are a helpful assistant", + ... ), + ... ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are you"), + ... ] + >>> converter = RemoveSystemMessageConverter() + >>> converted_messages = converter.convert(messages, None) + >>> assert converted_messages == [ + ... ModelMessage( + ... role=ModelMessageRoleType.HUMAN, + ... content="You are a helpful assistant\\nWho are you", + ... ), + ... ] + """ + + @abstractmethod + def convert( + self, + messages: List[ModelMessage], + model_metadata: Optional[ModelMetadata] = None, + ) -> List[ModelMessage]: + """Convert the messages. + + Args: + messages(List[ModelMessage]): The messages. + model_metadata(ModelMetadata): The model metadata. + + Returns: + List[ModelMessage]: The converted messages. + """ + + +class DefaultMessageConverter(MessageConverter): + """The default message converter.""" + + def __init__(self, prompt_sep: Optional[str] = None): + self._prompt_sep = prompt_sep + + def convert( + self, + messages: List[ModelMessage], + model_metadata: Optional[ModelMetadata] = None, + ) -> List[ModelMessage]: + """Convert the messages. + + There are three steps to convert the messages: + + 1. Just keep system, human and AI messages + + 2. Move the last user's message to the end of the list + + 3. 
Convert the messages to no system message if the model does not support system message + + Args: + messages(List[ModelMessage]): The messages. + model_metadata(ModelMetadata): The model metadata. + + Returns: + List[ModelMessage]: The converted messages. + """ + # 1. Just keep system, human and AI messages + messages = list(filter(lambda m: m.pass_to_model, messages)) + # 2. Move the last user's message to the end of the list + messages = self.move_last_user_message_to_end(messages) + + if not model_metadata or not model_metadata.ext_metadata: + logger.warning("No model metadata, skip system message conversion") + return messages + if not model_metadata.ext_metadata.support_system_message: + # 3. Convert the messages to no system message + return self.convert_to_no_system_message(messages, model_metadata) + return messages + + def convert_to_no_system_message( + self, + messages: List[ModelMessage], + model_metadata: Optional[ModelMetadata] = None, + ) -> List[ModelMessage]: + """Convert the messages to no system message. + + Examples: + >>> # Convert the messages to no system message, just merge system messages to the last user message + >>> from typing import List + >>> from dbgpt.core.interface.message import ( + ... ModelMessage, + ... ModelMessageRoleType, + ... ) + >>> from dbgpt.core.interface.llm import ( + ... DefaultMessageConverter, + ... ModelMetadata, + ... ) + >>> messages = [ + ... ModelMessage( + ... role=ModelMessageRoleType.SYSTEM, + ... content="You are a helpful assistant", + ... ), + ... ModelMessage( + ... role=ModelMessageRoleType.HUMAN, content="Who are you" + ... ), + ... ] + >>> converter = DefaultMessageConverter() + >>> model_metadata = ModelMetadata(model="test") + >>> converted_messages = converter.convert_to_no_system_message( + ... messages, model_metadata + ... ) + >>> assert converted_messages == [ + ... ModelMessage( + ... role=ModelMessageRoleType.HUMAN, + ... content="You are a helpful assistant\\nWho are you", + ... ), + ... ] + """ + if not model_metadata or not model_metadata.ext_metadata: + logger.warning("No model metadata, skip message conversion") + return messages + ext_metadata = model_metadata.ext_metadata + system_messages = [] + result_messages = [] + for message in messages: + if message.role == ModelMessageRoleType.SYSTEM: + # Collect system messages to merge into the last user message + system_messages.append(message) + elif message.role in [ + ModelMessageRoleType.HUMAN, + ModelMessageRoleType.AI, + ]: + result_messages.append(message) + prompt_sep = self._prompt_sep or ext_metadata.prompt_sep or "\n" + system_message_str = None + if len(system_messages) > 1: + logger.warning("More than one system message found, they will be merged") + system_message_str = prompt_sep.join([m.content for m in system_messages]) + elif len(system_messages) == 1: + system_message_str = system_messages[0].content + + if system_message_str and result_messages: + # The model does not support system messages, so merge them into the last user message + result_messages[-1].content = ( + system_message_str + prompt_sep + result_messages[-1].content + ) + return result_messages + + def move_last_user_message_to_end( + self, messages: List[ModelMessage] + ) -> List[ModelMessage]: + """Move the last user message to the end of the list. + + Examples: + + >>> from typing import List + >>> from dbgpt.core.interface.message import ( + ... ModelMessage, + ... ModelMessageRoleType, + ... 
) + >>> from dbgpt.core.interface.llm import DefaultMessageConverter + >>> messages = [ + ... ModelMessage( + ... role=ModelMessageRoleType.SYSTEM, + ... content="You are a helpful assistant", + ... ), + ... ModelMessage( + ... role=ModelMessageRoleType.HUMAN, content="Who are you" + ... ), + ... ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"), + ... ModelMessage( + ... role=ModelMessageRoleType.HUMAN, content="What's your name" + ... ), + ... ModelMessage( + ... role=ModelMessageRoleType.SYSTEM, + ... content="You are a helpful assistant", + ... ), + ... ] + >>> converter = DefaultMessageConverter() + >>> converted_messages = converter.move_last_user_message_to_end(messages) + >>> assert converted_messages == [ + ... ModelMessage( + ... role=ModelMessageRoleType.SYSTEM, + ... content="You are a helpful assistant", + ... ), + ... ModelMessage( + ... role=ModelMessageRoleType.HUMAN, content="Who are you" + ... ), + ... ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"), + ... ModelMessage( + ... role=ModelMessageRoleType.SYSTEM, + ... content="You are a helpful assistant", + ... ), + ... ModelMessage( + ... role=ModelMessageRoleType.HUMAN, content="What's your name" + ... ), + ... ] + + Args: + messages(List[ModelMessage]): The messages. + + Returns: + List[ModelMessage]: The converted messages. + """ + last_user_input_index = None + for i in range(len(messages) - 1, -1, -1): + if messages[i].role == ModelMessageRoleType.HUMAN: + last_user_input_index = i + break + if last_user_input_index is not None: + last_user_input = messages.pop(last_user_input_index) + messages.append(last_user_input) + return messages @PublicAPI(stability="beta") class LLMClient(ABC): """An abstract class for LLM client.""" + # Cache the model metadata for 60 seconds + _MODEL_CACHE_ = TTLCache(maxsize=100, ttl=60) + + @property + def cache(self) -> collections.abc.MutableMapping: + """The cache object to cache the model metadata. + + You can override this property to use your own cache object. + Returns: + collections.abc.MutableMapping: The cache object. + """ + return self._MODEL_CACHE_ + @abstractmethod - async def generate(self, request: ModelRequest) -> ModelOutput: + async def generate( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> ModelOutput: """Generate a response for a given model request. + Sometimes, different LLMs may have different message formats, + you can use the message converter to convert the messages to the format of the LLM. + Args: request(ModelRequest): The model request. + message_converter(MessageConverter): The message converter. Returns: ModelOutput: The model output. @@ -315,12 +658,18 @@ async def generate(self, request: ModelRequest) -> ModelOutput: @abstractmethod async def generate_stream( - self, request: ModelRequest + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, ) -> AsyncIterator[ModelOutput]: """Generate a stream of responses for a given model request. + Sometimes, different LLMs may have different message formats, + you can use the message converter to convert the messages to the format of the LLM. + Args: request(ModelRequest): The model request. + message_converter(MessageConverter): The message converter. Returns: AsyncIterator[ModelOutput]: The model output stream. @@ -345,3 +694,65 @@ async def count_token(self, model: str, prompt: str) -> int: Returns: int: The number of tokens. 
""" + + async def covert_message( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> ModelRequest: + """Covert the message. + If no message converter is provided, the original request will be returned. + + Args: + request(ModelRequest): The model request. + message_converter(MessageConverter): The message converter. + + Returns: + ModelRequest: The converted model request. + """ + if not message_converter: + return request + new_request = request.copy() + model_metadata = await self.get_model_metadata(request.model) + new_messages = message_converter.convert(request.messages, model_metadata) + new_request.messages = new_messages + return new_request + + async def cached_models(self) -> List[ModelMetadata]: + """Get all the models from the cache or the llm server. + + If the model metadata is not in the cache, it will be fetched from the llm server. + + Returns: + List[ModelMetadata]: A list of model metadata. + """ + key = "____$llm_client_models$____" + if key not in self.cache: + models = await self.models() + self.cache[key] = models + for model in models: + model_metadata_key = ( + f"____$llm_client_models_metadata_{model.model}$____" + ) + self.cache[model_metadata_key] = model + return self.cache[key] + + async def get_model_metadata(self, model: str) -> ModelMetadata: + """Get the model metadata. + + Args: + model(str): The model name. + + Returns: + ModelMetadata: The model metadata. + + Raises: + ValueError: If the model is not found. + """ + model_metadata_key = f"____$llm_client_models_metadata_{model}$____" + if model_metadata_key not in self.cache: + await self.cached_models() + model_metadata = self.cache.get(model_metadata_key) + if not model_metadata: + raise ValueError(f"Model {model} not found") + return model_metadata diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py index 6b0204817..cded3f18a 100755 --- a/dbgpt/core/interface/message.py +++ b/dbgpt/core/interface/message.py @@ -5,7 +5,6 @@ from typing import Callable, Dict, List, Optional, Tuple, Union from dbgpt._private.pydantic import BaseModel, Field -from dbgpt.core.awel import MapOperator from dbgpt.core.interface.storage import ( InMemoryStorage, ResourceIdentifier, @@ -114,6 +113,50 @@ class ModelMessage(BaseModel): content: str round_index: Optional[int] = 0 + @property + def pass_to_model(self) -> bool: + """Whether the message will be passed to the model + + The view message will not be passed to the model + + Returns: + bool: Whether the message will be passed to the model + """ + return self.role in [ + ModelMessageRoleType.SYSTEM, + ModelMessageRoleType.HUMAN, + ModelMessageRoleType.AI, + ] + + @staticmethod + def from_base_messages(messages: List[BaseMessage]) -> List["ModelMessage"]: + result = [] + for message in messages: + content, round_index = message.content, message.round_index + if isinstance(message, HumanMessage): + result.append( + ModelMessage( + role=ModelMessageRoleType.HUMAN, + content=content, + round_index=round_index, + ) + ) + elif isinstance(message, AIMessage): + result.append( + ModelMessage( + role=ModelMessageRoleType.AI, + content=content, + round_index=round_index, + ) + ) + elif isinstance(message, SystemMessage): + result.append( + ModelMessage( + role=ModelMessageRoleType.SYSTEM, content=message.content + ) + ) + return result + @staticmethod def from_openai_messages( messages: Union[str, List[Dict[str, str]]] @@ -142,9 +185,15 @@ def from_openai_messages( return result @staticmethod - def 
to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]: + def to_openai_messages( + messages: List["ModelMessage"], convert_to_compatible_format: bool = False + ) -> List[Dict[str, str]]: """Convert to OpenAI message format and hugggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating) + + Args: + messages (List["ModelMessage"]): The model messages + convert_to_compatible_format (bool): Whether to convert to compatible format """ history = [] # Add history conversation @@ -157,15 +206,16 @@ def to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]: history.append({"role": "assistant", "content": message.content}) else: pass - # Move the last user's information to the end - last_user_input_index = None - for i in range(len(history) - 1, -1, -1): - if history[i]["role"] == "user": - last_user_input_index = i - break - if last_user_input_index: - last_user_input = history.pop(last_user_input_index) - history.append(last_user_input) + if convert_to_compatible_format: + # Move the last user's information to the end + last_user_input_index = None + for i in range(len(history) - 1, -1, -1): + if history[i]["role"] == "user": + last_user_input_index = i + break + if last_user_input_index: + last_user_input = history.pop(last_user_input_index) + history.append(last_user_input) return history @staticmethod @@ -189,8 +239,8 @@ def get_printable_message(messages: List["ModelMessage"]) -> str: return str_msg -_SingleRoundMessage = List[ModelMessage] -_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[ModelMessage]] +_SingleRoundMessage = List[BaseMessage] +_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]] def _message_to_dict(message: BaseMessage) -> Dict: @@ -338,7 +388,8 @@ def start_new_round(self) -> None: """Start a new round of conversation Example: - >>> conversation = OnceConversation() + + >>> conversation = OnceConversation("chat_normal") >>> # The chat order will be 0, then we start a new round of conversation >>> assert conversation.chat_order == 0 >>> conversation.start_new_round() @@ -585,6 +636,28 @@ def get_model_messages(self) -> List[ModelMessage]: ) return messages + def get_history_message( + self, include_system_message: bool = False + ) -> List[BaseMessage]: + """Get the history message + + Not include the system messages. 
+ + Args: + include_system_message (bool): Whether to include the system message + + Returns: + List[BaseMessage]: The history messages + """ + messages = [] + for message in self.messages: + if message.pass_to_model: + if include_system_message: + messages.append(message) + elif message.type != "system": + messages.append(message) + return messages + class ConversationIdentifier(ResourceIdentifier): """Conversation identifier""" diff --git a/dbgpt/core/interface/operator/composer_operator.py b/dbgpt/core/interface/operator/composer_operator.py new file mode 100644 index 000000000..7c0093777 --- /dev/null +++ b/dbgpt/core/interface/operator/composer_operator.py @@ -0,0 +1,114 @@ +import dataclasses +from typing import Any, Dict, List, Optional + +from dbgpt.core import ( + ChatPromptTemplate, + MessageStorageItem, + ModelMessage, + ModelRequest, + StorageConversation, + StorageInterface, +) +from dbgpt.core.awel import ( + DAG, + BaseOperator, + InputOperator, + JoinOperator, + MapOperator, + SimpleCallDataInputSource, +) +from dbgpt.core.interface.operator.prompt_operator import HistoryPromptBuilderOperator + +from .message_operator import ( + BufferedConversationMapperOperator, + ChatHistoryLoadType, + PreChatHistoryLoadOperator, +) + + +@dataclasses.dataclass +class ChatComposerInput: + """The composer input.""" + + prompt_dict: Dict[str, Any] + model_dict: Dict[str, Any] + context: ChatHistoryLoadType + + +class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequest]): + """The chat history prompt composer operator. + + For simple use, you can use this operator to compose the chat history prompt. + """ + + def __init__( + self, + prompt_template: ChatPromptTemplate, + history_key: str = "chat_history", + last_k_round: int = 2, + storage: Optional[StorageInterface[StorageConversation, Any]] = None, + message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self._prompt_template = prompt_template + self._history_key = history_key + self._last_k_round = last_k_round + self._storage = storage + self._message_storage = message_storage + self._sub_compose_dag = self._build_composer_dag() + + async def map(self, input_value: ChatComposerInput) -> ModelRequest: + end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0] + # Sub dag, use the same dag context in the parent dag + return await end_node.call( + call_data={"data": input_value}, dag_ctx=self.current_dag_context + ) + + def _build_composer_dag(self) -> DAG: + with DAG("dbgpt_awel_chat_history_prompt_composer") as composer_dag: + input_task = InputOperator(input_source=SimpleCallDataInputSource()) + # Load and store chat history, default use InMemoryStorage. 
+ chat_history_load_task = PreChatHistoryLoadOperator( + storage=self._storage, message_storage=self._message_storage + ) + # History transform task, here we keep last 5 round messages + history_transform_task = BufferedConversationMapperOperator( + last_k_round=self._last_k_round + ) + history_prompt_build_task = HistoryPromptBuilderOperator( + prompt=self._prompt_template, history_key=self._history_key + ) + model_request_build_task = JoinOperator(self._build_model_request) + + # Build composer dag + ( + input_task + >> MapOperator(lambda x: x.context) + >> chat_history_load_task + >> history_transform_task + >> history_prompt_build_task + ) + ( + input_task + >> MapOperator(lambda x: x.prompt_dict) + >> history_prompt_build_task + ) + + history_prompt_build_task >> model_request_build_task + ( + input_task + >> MapOperator(lambda x: x.model_dict) + >> model_request_build_task + ) + + return composer_dag + + def _build_model_request( + self, messages: List[ModelMessage], model_dict: Dict[str, Any] + ) -> ModelRequest: + return ModelRequest.build_request(messages=messages, **model_dict) + + async def after_dag_end(self): + # Should call after_dag_end() of sub dag + await self._sub_compose_dag._after_dag_end() diff --git a/dbgpt/core/interface/operator/llm_operator.py b/dbgpt/core/interface/operator/llm_operator.py index fc117ddc5..5570aa1d6 100644 --- a/dbgpt/core/interface/operator/llm_operator.py +++ b/dbgpt/core/interface/operator/llm_operator.py @@ -1,11 +1,12 @@ import dataclasses from abc import ABC -from typing import Any, AsyncIterator, Dict, Optional, Union +from typing import Any, AsyncIterator, Dict, List, Optional, Union from dbgpt._private.pydantic import BaseModel from dbgpt.core.awel import ( BranchFunc, BranchOperator, + DAGContext, MapOperator, StreamifyAbsOperator, ) @@ -22,20 +23,30 @@ str, Dict[str, Any], BaseModel, + ModelMessage, + List[ModelMessage], ] -class RequestBuildOperator(MapOperator[RequestInput, ModelRequest], ABC): +class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest], ABC): + """Build the model request from the input value.""" + def __init__(self, model: Optional[str] = None, **kwargs): self._model = model super().__init__(**kwargs) async def map(self, input_value: RequestInput) -> ModelRequest: req_dict = {} + if not input_value: + raise ValueError("input_value is not set") if isinstance(input_value, str): req_dict = {"messages": [ModelMessage.build_human_message(input_value)]} elif isinstance(input_value, dict): req_dict = input_value + elif isinstance(input_value, ModelMessage): + req_dict = {"messages": [input_value]} + elif isinstance(input_value, list) and isinstance(input_value[0], ModelMessage): + req_dict = {"messages": input_value} elif dataclasses.is_dataclass(input_value): req_dict = dataclasses.asdict(input_value) elif isinstance(input_value, BaseModel): @@ -76,6 +87,7 @@ class BaseLLM: """The abstract operator for a LLM.""" SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name" + SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output" def __init__(self, llm_client: Optional[LLMClient] = None): self._llm_client = llm_client @@ -87,8 +99,16 @@ def llm_client(self) -> LLMClient: raise ValueError("llm_client is not set") return self._llm_client + async def save_model_output( + self, current_dag_context: DAGContext, model_output: ModelOutput + ) -> None: + """Save the model output to the share data.""" + await current_dag_context.save_to_share_data( + self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output + ) + -class 
LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC): +class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC): """The operator for a LLM. Args: @@ -105,10 +125,12 @@ async def map(self, request: ModelRequest) -> ModelOutput: await self.current_dag_context.save_to_share_data( self.SHARE_DATA_KEY_MODEL_NAME, request.model ) - return await self.llm_client.generate(request) + model_output = await self.llm_client.generate(request) + await self.save_model_output(self.current_dag_context, model_output) + return model_output -class StreamingLLMOperator( +class BaseStreamingLLMOperator( BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC ): """The streaming operator for a LLM. @@ -127,8 +149,12 @@ async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]: await self.current_dag_context.save_to_share_data( self.SHARE_DATA_KEY_MODEL_NAME, request.model ) + model_output = None async for output in self.llm_client.generate_stream(request): + model_output = output yield output + if model_output: + await self.save_model_output(self.current_dag_context, model_output) class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]): diff --git a/dbgpt/core/interface/operator/message_operator.py b/dbgpt/core/interface/operator/message_operator.py index 1a995da2b..f6eb1b24b 100644 --- a/dbgpt/core/interface/operator/message_operator.py +++ b/dbgpt/core/interface/operator/message_operator.py @@ -1,19 +1,17 @@ import uuid -from abc import ABC, abstractmethod -from typing import Any, AsyncIterator, List, Optional +from abc import ABC +from typing import Any, Dict, List, Optional, Union from dbgpt.core import ( MessageStorageItem, ModelMessage, ModelMessageRoleType, - ModelOutput, - ModelRequest, ModelRequestContext, StorageConversation, StorageInterface, ) -from dbgpt.core.awel import BaseOperator, MapOperator, TransformStreamAbsOperator -from dbgpt.core.interface.message import _MultiRoundMessageMapper +from dbgpt.core.awel import BaseOperator, MapOperator +from dbgpt.core.interface.message import BaseMessage, _MultiRoundMessageMapper class BaseConversationOperator(BaseOperator, ABC): @@ -21,32 +19,41 @@ class BaseConversationOperator(BaseOperator, ABC): SHARE_DATA_KEY_STORAGE_CONVERSATION = "share_data_key_storage_conversation" SHARE_DATA_KEY_MODEL_REQUEST = "share_data_key_model_request" + SHARE_DATA_KEY_MODEL_REQUEST_CONTEXT = "share_data_key_model_request_context" + + _check_storage: bool = True def __init__( self, storage: Optional[StorageInterface[StorageConversation, Any]] = None, message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None, + check_storage: bool = True, **kwargs, ): + self._check_storage = check_storage super().__init__(**kwargs) self._storage = storage self._message_storage = message_storage @property - def storage(self) -> StorageInterface[StorageConversation, Any]: + def storage(self) -> Optional[StorageInterface[StorageConversation, Any]]: """Return the LLM client.""" if not self._storage: - raise ValueError("Storage is not set") + if self._check_storage: + raise ValueError("Storage is not set") + return None return self._storage @property - def message_storage(self) -> StorageInterface[MessageStorageItem, Any]: + def message_storage(self) -> Optional[StorageInterface[MessageStorageItem, Any]]: """Return the LLM client.""" if not self._message_storage: - raise ValueError("Message storage is not set") + if self._check_storage: + raise ValueError("Message storage is not set") + return None return 
self._message_storage - async def get_storage_conversation(self) -> StorageConversation: + async def get_storage_conversation(self) -> Optional[StorageConversation]: """Get the storage conversation from share data. Returns: @@ -58,280 +65,170 @@ async def get_storage_conversation(self) -> StorageConversation: ) ) if not storage_conv: - raise ValueError("Storage conversation is not set") + if self._check_storage: + raise ValueError("Storage conversation is not set") + return None return storage_conv - async def get_model_request(self) -> ModelRequest: - """Get the model request from share data. + def check_messages(self, messages: List[ModelMessage]) -> None: + """Check the messages. - Returns: - ModelRequest: The model request. + Args: + messages (List[ModelMessage]): The messages. + + Raises: + ValueError: If the messages is empty. """ - model_request: ModelRequest = ( - await self.current_dag_context.get_from_share_data( - self.SHARE_DATA_KEY_MODEL_REQUEST - ) - ) - if not model_request: - raise ValueError("Model request is not set") - return model_request + if not messages: + raise ValueError("Input messages is empty") + for message in messages: + if message.role not in [ + ModelMessageRoleType.HUMAN, + ModelMessageRoleType.SYSTEM, + ]: + raise ValueError(f"Message role {message.role} is not supported") -class PreConversationOperator( - BaseConversationOperator, MapOperator[ModelRequest, ModelRequest] +ChatHistoryLoadType = Union[ModelRequestContext, Dict[str, Any]] + + +class PreChatHistoryLoadOperator( + BaseConversationOperator, MapOperator[ChatHistoryLoadType, List[BaseMessage]] ): """The operator to prepare the storage conversation. In DB-GPT, conversation record and the messages in the conversation are stored in the storage, and they can store in different storage(for high performance). + + This operator just load the conversation and messages from storage. """ def __init__( self, storage: Optional[StorageInterface[StorageConversation, Any]] = None, message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None, + include_system_message: bool = False, **kwargs, ): super().__init__(storage=storage, message_storage=message_storage) MapOperator.__init__(self, **kwargs) + self._include_system_message = include_system_message - async def map(self, input_value: ModelRequest) -> ModelRequest: + async def map(self, input_value: ChatHistoryLoadType) -> List[BaseMessage]: """Map the input value to a ModelRequest. Args: - input_value (ModelRequest): The input value. + input_value (ChatHistoryLoadType): The input value. Returns: - ModelRequest: The mapped ModelRequest. + List[BaseMessage]: The messages stored in the storage. 
""" - if input_value.context is None: - input_value.context = ModelRequestContext() - if not input_value.context.conv_uid: - input_value.context.conv_uid = str(uuid.uuid4()) - if not input_value.context.extra: - input_value.context.extra = {} + if not input_value: + raise ValueError("Model request context can't be None") + if isinstance(input_value, dict): + input_value = ModelRequestContext(**input_value) + if not input_value.conv_uid: + input_value.conv_uid = str(uuid.uuid4()) + if not input_value.extra: + input_value.extra = {} - chat_mode = input_value.context.chat_mode + chat_mode = input_value.chat_mode # Create a new storage conversation, this will load the conversation from storage, so we must do this async storage_conv: StorageConversation = await self.blocking_func_to_async( StorageConversation, - conv_uid=input_value.context.conv_uid, + conv_uid=input_value.conv_uid, chat_mode=chat_mode, - user_name=input_value.context.user_name, - sys_code=input_value.context.sys_code, + user_name=input_value.user_name, + sys_code=input_value.sys_code, conv_storage=self.storage, message_storage=self.message_storage, ) - input_messages = input_value.get_messages() - await self.save_to_storage(storage_conv, input_messages) - # Get all messages from current storage conversation, and overwrite the input value - messages: List[ModelMessage] = storage_conv.get_model_messages() - input_value.messages = messages # Save the storage conversation to share data, for the child operators await self.current_dag_context.save_to_share_data( self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv ) await self.current_dag_context.save_to_share_data( - self.SHARE_DATA_KEY_MODEL_REQUEST, input_value + self.SHARE_DATA_KEY_MODEL_REQUEST_CONTEXT, input_value ) - return input_value - - async def save_to_storage( - self, storage_conv: StorageConversation, input_messages: List[ModelMessage] - ) -> None: - """Save the messages to storage. - - Args: - storage_conv (StorageConversation): The storage conversation. - input_messages (List[ModelMessage]): The input messages. - """ - # check first - self.check_messages(input_messages) - storage_conv.start_new_round() - for message in input_messages: - if message.role == ModelMessageRoleType.HUMAN: - storage_conv.add_user_message(message.content) - else: - storage_conv.add_system_message(message.content) - - def check_messages(self, messages: List[ModelMessage]) -> None: - """Check the messages. - - Args: - messages (List[ModelMessage]): The messages. - - Raises: - ValueError: If the messages is empty. - """ - if not messages: - raise ValueError("Input messages is empty") - for message in messages: - if message.role not in [ - ModelMessageRoleType.HUMAN, - ModelMessageRoleType.SYSTEM, - ]: - raise ValueError(f"Message role {message.role} is not supported") - - async def after_dag_end(self): - """The callback after DAG end""" - # Save the storage conversation to storage after the whole DAG finished - storage_conv: StorageConversation = await self.get_storage_conversation() - # TODO dont save if the conversation has some internal error - storage_conv.end_current_round() - - -class PostConversationOperator( - BaseConversationOperator, MapOperator[ModelOutput, ModelOutput] -): - def __init__(self, **kwargs): - MapOperator.__init__(self, **kwargs) - - async def map(self, input_value: ModelOutput) -> ModelOutput: - """Map the input value to a ModelOutput. - - Args: - input_value (ModelOutput): The input value. - - Returns: - ModelOutput: The mapped ModelOutput. 
- """ - # Get the storage conversation from share data - storage_conv: StorageConversation = await self.get_storage_conversation() - storage_conv.add_ai_message(input_value.text) - return input_value - - -class PostStreamingConversationOperator( - BaseConversationOperator, TransformStreamAbsOperator[ModelOutput, ModelOutput] -): - def __init__(self, **kwargs): - TransformStreamAbsOperator.__init__(self, **kwargs) - - async def transform_stream( - self, input_value: AsyncIterator[ModelOutput] - ) -> ModelOutput: - """Transform the input value to a ModelOutput. - - Args: - input_value (ModelOutput): The input value. - - Returns: - ModelOutput: The transformed ModelOutput. - """ - full_text = "" - async for model_output in input_value: - # Now model_output.text if full text, if it is a delta text, we should merge all delta text to a full text - full_text = model_output.text - yield model_output - # Get the storage conversation from share data - storage_conv: StorageConversation = await self.get_storage_conversation() - storage_conv.add_ai_message(full_text) + # Get history messages from storage + history_messages: List[BaseMessage] = storage_conv.get_history_message( + include_system_message=self._include_system_message + ) + return history_messages class ConversationMapperOperator( - BaseConversationOperator, MapOperator[ModelRequest, ModelRequest] + BaseConversationOperator, MapOperator[List[BaseMessage], List[BaseMessage]] ): def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs): MapOperator.__init__(self, **kwargs) self._message_mapper = message_mapper - async def map(self, input_value: ModelRequest) -> ModelRequest: - """Map the input value to a ModelRequest. + async def map(self, input_value: List[BaseMessage]) -> List[BaseMessage]: + return self.map_messages(input_value) - Args: - input_value (ModelRequest): The input value. - - Returns: - ModelRequest: The mapped ModelRequest. - """ - input_value = input_value.copy() - messages: List[ModelMessage] = self.map_messages(input_value.messages) - # Overwrite the input value - input_value.messages = messages - return input_value - - def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]: - """Map the input messages to a list of ModelMessage. - - Args: - messages (List[ModelMessage]): The input messages. - - Returns: - List[ModelMessage]: The mapped ModelMessage. - """ - messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round( + def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]: + messages_by_round: List[List[BaseMessage]] = self._split_messages_by_round( messages ) message_mapper = self._message_mapper or self.map_multi_round_messages return message_mapper(messages_by_round) def map_multi_round_messages( - self, messages_by_round: List[List[ModelMessage]] - ) -> List[ModelMessage]: - """Map multi round messages to a list of ModelMessage + self, messages_by_round: List[List[BaseMessage]] + ) -> List[BaseMessage]: + """Map multi round messages to a list of BaseMessage. - By default, just merge all multi round messages to a list of ModelMessage according origin order. + By default, just merge all multi round messages to a list of BaseMessage according origin order. And you can overwrite this method to implement your own logic. Examples: - Merge multi round messages to a list of ModelMessage according origin order. - - .. 
code-block:: python - - import asyncio - from dbgpt.core.operator import ConversationMapperOperator - - messages_by_round = [ - [ - ModelMessage(role="human", content="Hi", round_index=1), - ModelMessage(role="ai", content="Hello!", round_index=1), - ], - [ - ModelMessage(role="system", content="Error 404", round_index=2), - ModelMessage( - role="human", content="What's the error?", round_index=2 - ), - ModelMessage(role="ai", content="Just a joke.", round_index=2), - ], - [ - ModelMessage(role="human", content="Funny!", round_index=3), - ], - ] - operator = ConversationMapperOperator() - messages = operator.map_multi_round_messages(messages_by_round) - assert messages == [ - ModelMessage(role="human", content="Hi", round_index=1), - ModelMessage(role="ai", content="Hello!", round_index=1), - ModelMessage(role="system", content="Error 404", round_index=2), - ModelMessage( - role="human", content="What's the error?", round_index=2 - ), - ModelMessage(role="ai", content="Just a joke.", round_index=2), - ModelMessage(role="human", content="Funny!", round_index=3), - ] - - Map multi round messages to a list of ModelMessage just keep the last one round. - - .. code-block:: python - - class MyMapper(ConversationMapperOperator): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def map_multi_round_messages( - self, messages_by_round: List[List[ModelMessage]] - ) -> List[ModelMessage]: - return messages_by_round[-1] - - - operator = MyMapper() - messages = operator.map_multi_round_messages(messages_by_round) - assert messages == [ - ModelMessage(role="human", content="Funny!", round_index=3), - ] + Merge multi round messages to a list of BaseMessage according origin order. + + >>> from dbgpt.core.interface.message import ( + ... AIMessage, + ... HumanMessage, + ... SystemMessage, + ... ) + >>> messages_by_round = [ + ... [ + ... HumanMessage(content="Hi", round_index=1), + ... AIMessage(content="Hello!", round_index=1), + ... ], + ... [ + ... HumanMessage(content="What's the error?", round_index=2), + ... AIMessage(content="Just a joke.", round_index=2), + ... ], + ... ] + >>> operator = ConversationMapperOperator() + >>> messages = operator.map_multi_round_messages(messages_by_round) + >>> assert messages == [ + ... HumanMessage(content="Hi", round_index=1), + ... AIMessage(content="Hello!", round_index=1), + ... HumanMessage(content="What's the error?", round_index=2), + ... AIMessage(content="Just a joke.", round_index=2), + ... ] + + Map multi round messages to a list of BaseMessage just keep the last one round. + + >>> class MyMapper(ConversationMapperOperator): + ... def __init__(self, **kwargs): + ... super().__init__(**kwargs) + ... + ... def map_multi_round_messages( + ... self, messages_by_round: List[List[BaseMessage]] + ... ) -> List[BaseMessage]: + ... return messages_by_round[-1] + ... + >>> operator = MyMapper() + >>> messages = operator.map_multi_round_messages(messages_by_round) + >>> assert messages == [ + ... HumanMessage(content="What's the error?", round_index=2), + ... AIMessage(content="Just a joke.", round_index=2), + ... ] Args: """ @@ -340,17 +237,17 @@ def map_multi_round_messages( return sum(messages_by_round, []) def _split_messages_by_round( - self, messages: List[ModelMessage] - ) -> List[List[ModelMessage]]: + self, messages: List[BaseMessage] + ) -> List[List[BaseMessage]]: """Split the messages by round index. Args: - messages (List[ModelMessage]): The input messages. + messages (List[BaseMessage]): The messages. 
Returns: - List[List[ModelMessage]]: The split messages. + List[List[BaseMessage]]: The messages split by round. """ - messages_by_round: List[List[ModelMessage]] = [] + messages_by_round: List[List[BaseMessage]] = [] last_round_index = 0 for message in messages: if not message.round_index: @@ -366,7 +263,7 @@ def _split_messages_by_round( class BufferedConversationMapperOperator(ConversationMapperOperator): """The buffered conversation mapper operator. - This Operator must be used after the PreConversationOperator, + This Operator must be used after the PreChatHistoryLoadOperator, and it will map the messages in the storage conversation. Examples: @@ -419,8 +316,8 @@ def __init__( if message_mapper: def new_message_mapper( - messages_by_round: List[List[ModelMessage]], - ) -> List[ModelMessage]: + messages_by_round: List[List[BaseMessage]], + ) -> List[BaseMessage]: # Apply keep k round messages first, then apply the custom message mapper messages_by_round = self._keep_last_round_messages(messages_by_round) return message_mapper(messages_by_round) @@ -428,23 +325,23 @@ def new_message_mapper( else: def new_message_mapper( - messages_by_round: List[List[ModelMessage]], - ) -> List[ModelMessage]: + messages_by_round: List[List[BaseMessage]], + ) -> List[BaseMessage]: messages_by_round = self._keep_last_round_messages(messages_by_round) return sum(messages_by_round, []) super().__init__(new_message_mapper, **kwargs) def _keep_last_round_messages( - self, messages_by_round: List[List[ModelMessage]] - ) -> List[List[ModelMessage]]: + self, messages_by_round: List[List[BaseMessage]] + ) -> List[List[BaseMessage]]: """Keep the last k round messages. Args: - messages_by_round (List[List[ModelMessage]]): The messages by round. + messages_by_round (List[List[BaseMessage]]): The messages by round. Returns: - List[List[ModelMessage]]: The latest round messages. + List[List[BaseMessage]]: The latest round messages. """ index = self._last_k_round + 1 return messages_by_round[-index:] diff --git a/dbgpt/core/interface/operator/prompt_operator.py b/dbgpt/core/interface/operator/prompt_operator.py new file mode 100644 index 000000000..18c727d14 --- /dev/null +++ b/dbgpt/core/interface/operator/prompt_operator.py @@ -0,0 +1,255 @@ +from abc import ABC +from typing import Any, Dict, List, Optional, Union + +from dbgpt.core import ( + BasePromptTemplate, + ChatPromptTemplate, + ModelMessage, + ModelMessageRoleType, + ModelOutput, + StorageConversation, +) +from dbgpt.core.awel import JoinOperator, MapOperator +from dbgpt.core.interface.message import BaseMessage +from dbgpt.core.interface.operator.llm_operator import BaseLLM +from dbgpt.core.interface.operator.message_operator import BaseConversationOperator +from dbgpt.core.interface.prompt import HumanPromptTemplate, MessageType +from dbgpt.util.function_utils import rearrange_args_by_type + + +class BasePromptBuilderOperator(BaseConversationOperator, ABC): + """The base prompt builder operator.""" + + def __init__(self, check_storage: bool, **kwargs): + super().__init__(check_storage=check_storage, **kwargs) + + async def format_prompt( + self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any] + ) -> List[ModelMessage]: + """Format the prompt. + + Args: + prompt (ChatPromptTemplate): The prompt. + prompt_dict (Dict[str, Any]): The prompt dict. + + Returns: + List[ModelMessage]: The formatted prompt. 
+ """ + kwargs = {} + kwargs.update(prompt_dict) + pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables} + messages = prompt.format_messages(**pass_kwargs) + messages = ModelMessage.from_base_messages(messages) + # Start new round conversation, and save user message to storage + await self.start_new_round_conv(messages) + return messages + + async def start_new_round_conv(self, messages: List[ModelMessage]) -> None: + """Start a new round conversation. + + Args: + messages (List[ModelMessage]): The messages. + """ + + lass_user_message = None + for message in messages[::-1]: + if message.role == ModelMessageRoleType.HUMAN: + lass_user_message = message.content + break + if not lass_user_message: + raise ValueError("No user message") + storage_conv: StorageConversation = await self.get_storage_conversation() + if not storage_conv: + return + # Start new round + storage_conv.start_new_round() + storage_conv.add_user_message(lass_user_message) + + async def after_dag_end(self): + """The callback after DAG end""" + # TODO remove this to start_new_round() + # Save the storage conversation to storage after the whole DAG finished + storage_conv: StorageConversation = await self.get_storage_conversation() + if not storage_conv: + return + model_output: ModelOutput = await self.current_dag_context.get_from_share_data( + BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT + ) + if model_output: + # Save model output message to storage + storage_conv.add_ai_message(model_output.text) + # End current conversation round and flush to storage + storage_conv.end_current_round() + + +PromptTemplateType = Union[ChatPromptTemplate, BasePromptTemplate, MessageType, str] + + +class PromptBuilderOperator( + BasePromptBuilderOperator, MapOperator[Dict[str, Any], List[ModelMessage]] +): + """The operator to build the prompt with static prompt. + + Examples: + + .. 
code-block:: python + + import asyncio + from dbgpt.core.awel import DAG + from dbgpt.core import ( + ModelMessage, + HumanMessage, + SystemMessage, + HumanPromptTemplate, + SystemPromptTemplate, + ChatPromptTemplate, + ) + from dbgpt.core.operator import PromptBuilderOperator + + with DAG("prompt_test") as dag: + str_prompt = PromptBuilderOperator( + "Please write a {dialect} SQL count the length of a field" + ) + tp_prompt = PromptBuilderOperator( + HumanPromptTemplate.from_template( + "Please write a {dialect} SQL count the length of a field" + ) + ) + chat_prompt = PromptBuilderOperator( + ChatPromptTemplate( + messages=[ + HumanPromptTemplate.from_template( + "Please write a {dialect} SQL count the length of a field" + ) + ] + ) + ) + with_sys_prompt = PromptBuilderOperator( + ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template( + "You are a {dialect} SQL expert" + ), + HumanPromptTemplate.from_template( + "Please write a {dialect} SQL count the length of a field" + ), + ], + ) + ) + + single_input = {"data": {"dialect": "mysql"}} + single_expected_messages = [ + ModelMessage( + content="Please write a mysql SQL count the length of a field", + role="human", + ) + ] + with_sys_expected_messages = [ + ModelMessage(content="You are a mysql SQL expert", role="system"), + ModelMessage( + content="Please write a mysql SQL count the length of a field", + role="human", + ), + ] + assert ( + asyncio.run(str_prompt.call(call_data=single_input)) + == single_expected_messages + ) + assert ( + asyncio.run(tp_prompt.call(call_data=single_input)) + == single_expected_messages + ) + assert ( + asyncio.run(chat_prompt.call(call_data=single_input)) + == single_expected_messages + ) + assert ( + asyncio.run(with_sys_prompt.call(call_data=single_input)) + == with_sys_expected_messages + ) + + """ + + def __init__(self, prompt: PromptTemplateType, **kwargs): + if isinstance(prompt, str): + prompt = ChatPromptTemplate( + messages=[HumanPromptTemplate.from_template(prompt)] + ) + elif isinstance(prompt, BasePromptTemplate) and not isinstance( + prompt, ChatPromptTemplate + ): + prompt = ChatPromptTemplate( + messages=[HumanPromptTemplate.from_template(prompt.template)] + ) + elif isinstance(prompt, MessageType): + prompt = ChatPromptTemplate(messages=[prompt]) + self._prompt = prompt + + super().__init__(check_storage=False, **kwargs) + MapOperator.__init__(self, map_function=self.merge_prompt, **kwargs) + + @rearrange_args_by_type + async def merge_prompt(self, prompt_dict: Dict[str, Any]) -> List[ModelMessage]: + return await self.format_prompt(self._prompt, prompt_dict) + + +class DynamicPromptBuilderOperator( + BasePromptBuilderOperator, JoinOperator[List[ModelMessage]] +): + """The operator to build the prompt with dynamic prompt. + + The prompt template is dynamic, and it created by parent operator. 
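    For illustration, a minimal sketch of how this operator might be wired; the two
    upstream operators are hypothetical stand-ins, only DynamicPromptBuilderOperator,
    ChatPromptTemplate and HumanPromptTemplate come from this patch:

    .. code-block:: python

        from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate
        from dbgpt.core.awel import DAG, MapOperator
        from dbgpt.core.operator import DynamicPromptBuilderOperator

        with DAG("dynamic_prompt_dag") as dag:
            # Parent 1: decide the prompt template at runtime.
            prompt_task = MapOperator(
                lambda req: ChatPromptTemplate(
                    messages=[HumanPromptTemplate.from_template("Summarize: {text}")]
                )
            )
            # Parent 2: provide the variables for that template.
            vars_task = MapOperator(lambda req: {"text": req})
            prompt_builder = DynamicPromptBuilderOperator()
            prompt_task >> prompt_builder
            vars_task >> prompt_builder

    Because ``merge_prompt`` is decorated with ``rearrange_args_by_type``, the two
    parent outputs are matched by type, so the wiring order of the parents does not
    have to mirror the argument order.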
+ """ + + def __init__(self, **kwargs): + super().__init__(check_storage=False, **kwargs) + JoinOperator.__init__(self, combine_function=self.merge_prompt, **kwargs) + + @rearrange_args_by_type + async def merge_prompt( + self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any] + ) -> List[ModelMessage]: + return await self.format_prompt(prompt, prompt_dict) + + +class HistoryPromptBuilderOperator( + BasePromptBuilderOperator, JoinOperator[List[ModelMessage]] +): + def __init__( + self, prompt: ChatPromptTemplate, history_key: Optional[str] = None, **kwargs + ): + self._prompt = prompt + self._history_key = history_key + + JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs) + + @rearrange_args_by_type + async def merge_history( + self, history: List[BaseMessage], prompt_dict: Dict[str, Any] + ) -> List[ModelMessage]: + prompt_dict[self._history_key] = history + return await self.format_prompt(self._prompt, prompt_dict) + + +class HistoryDynamicPromptBuilderOperator( + BasePromptBuilderOperator, JoinOperator[List[ModelMessage]] +): + """The operator to build the prompt with dynamic prompt. + + The prompt template is dynamic, and it created by parent operator. + """ + + def __init__(self, history_key: Optional[str] = None, **kwargs): + self._history_key = history_key + + JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs) + + @rearrange_args_by_type + async def merge_history( + self, + prompt: ChatPromptTemplate, + history: List[BaseMessage], + prompt_dict: Dict[str, Any], + ) -> List[ModelMessage]: + prompt_dict[self._history_key] = history + return await self.format_prompt(prompt, prompt_dict) diff --git a/dbgpt/core/interface/prompt.py b/dbgpt/core/interface/prompt.py index f584eda13..21fda87f6 100644 --- a/dbgpt/core/interface/prompt.py +++ b/dbgpt/core/interface/prompt.py @@ -1,11 +1,14 @@ +from __future__ import annotations + import dataclasses import json from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional +from string import Formatter +from typing import Any, Callable, Dict, List, Optional, Set, Union -from dbgpt._private.pydantic import BaseModel +from dbgpt._private.pydantic import BaseModel, root_validator from dbgpt.core._private.example_base import ExampleSelector -from dbgpt.core.awel import MapOperator +from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage from dbgpt.core.interface.output_parser import BaseOutputParser from dbgpt.core.interface.storage import ( InMemoryStorage, @@ -38,15 +41,40 @@ def _jinja2_formatter(template: str, **kwargs: Any) -> str: } -class PromptTemplate(BaseModel, ABC): +class BasePromptTemplate(BaseModel): input_variables: List[str] """A list of the names of the variables the prompt template expects.""" + + template: Optional[str] + """The prompt template.""" + + template_format: Optional[str] = "f-string" + + def format(self, **kwargs: Any) -> str: + """Format the prompt with the inputs.""" + if self.template: + return _DEFAULT_FORMATTER_MAPPING[self.template_format](True)( + self.template, **kwargs + ) + + @classmethod + def from_template( + cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any + ) -> BasePromptTemplate: + """Create a prompt template from a template string.""" + input_variables = get_template_vars(template, template_format) + return cls( + template=template, + input_variables=input_variables, + template_format=template_format, + **kwargs, + ) + + +class PromptTemplate(BasePromptTemplate): 
template_scene: Optional[str] template_define: Optional[str] """this template define""" - template: Optional[str] - """The prompt template.""" - template_format: str = "f-string" """strict template will check template args""" template_is_strict: bool = True """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" @@ -86,12 +114,114 @@ def format(self, **kwargs: Any) -> str: self.template_is_strict )(self.template, **kwargs) - @staticmethod - def from_template(template: str) -> "PromptTemplateOperator": + +class BaseChatPromptTemplate(BaseModel, ABC): + prompt: BasePromptTemplate + + @property + def input_variables(self) -> List[str]: + """A list of the names of the variables the prompt template expects.""" + return self.prompt.input_variables + + @abstractmethod + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format the prompt with the inputs.""" + + @classmethod + def from_template( + cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any + ) -> BaseChatPromptTemplate: """Create a prompt template from a template string.""" - return PromptTemplateOperator( - PromptTemplate(template=template, input_variables=[]) - ) + prompt = BasePromptTemplate.from_template(template, template_format) + return cls(prompt=prompt, **kwargs) + + +class SystemPromptTemplate(BaseChatPromptTemplate): + """The system prompt template.""" + + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + content = self.prompt.format(**kwargs) + return [SystemMessage(content=content)] + + +class HumanPromptTemplate(BaseChatPromptTemplate): + """The human prompt template.""" + + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + content = self.prompt.format(**kwargs) + return [HumanMessage(content=content)] + + +class MessagesPlaceholder(BaseChatPromptTemplate): + """The messages placeholder template. + + Mostly used for the chat history. + """ + + variable_name: str + prompt: BasePromptTemplate = None + + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + messages = kwargs.get(self.variable_name, []) + if not isinstance(messages, list): + raise ValueError( + f"Unsupported messages type: {type(messages)}, should be list." + ) + for message in messages: + if not isinstance(message, BaseMessage): + raise ValueError( + f"Unsupported message type: {type(message)}, should be BaseMessage." + ) + return messages + + @property + def input_variables(self) -> List[str]: + """A list of the names of the variables the prompt template expects. + + Returns: + List[str]: The input variables. 
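        For illustration, a minimal sketch of the placeholder on its own (the
        message classes are the ones exported from ``dbgpt.core`` in this patch):

        .. code-block:: python

            placeholder = MessagesPlaceholder(variable_name="chat_history")
            assert placeholder.input_variables == ["chat_history"]

            history = [HumanMessage(content="Hi"), AIMessage(content="Hello!")]
            assert placeholder.format_messages(chat_history=history) == history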
+ """ + return [self.variable_name] + + +MessageType = Union[BaseChatPromptTemplate, BaseMessage] + + +class ChatPromptTemplate(BasePromptTemplate): + messages: List[MessageType] + + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format the prompt with the inputs.""" + result_messages = [] + for message in self.messages: + if isinstance(message, BaseMessage): + result_messages.append(message) + elif isinstance(message, BaseChatPromptTemplate): + pass_kwargs = { + k: v for k, v in kwargs.items() if k in message.input_variables + } + result_messages.extend(message.format_messages(**pass_kwargs)) + elif isinstance(message, MessagesPlaceholder): + pass_kwargs = { + k: v for k, v in kwargs.items() if k in message.input_variables + } + result_messages.extend(message.format_messages(**pass_kwargs)) + else: + raise ValueError(f"Unsupported message type: {type(message)}") + return result_messages + + @root_validator(pre=True) + def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Pre-fill the messages.""" + input_variables = values.get("input_variables", {}) + messages = values.get("messages", []) + if not input_variables: + input_variables = set() + for message in messages: + if isinstance(message, BaseChatPromptTemplate): + input_variables.update(message.input_variables) + values["input_variables"] = sorted(input_variables) + return values @dataclasses.dataclass @@ -547,10 +677,36 @@ def delete( self.storage.delete(identifier) -class PromptTemplateOperator(MapOperator[Dict, str]): - def __init__(self, prompt_template: PromptTemplate, **kwargs: Any): - super().__init__(**kwargs) - self._prompt_template = prompt_template +def _get_string_template_vars(template_str: str) -> Set[str]: + """Get template variables from a template string.""" + variables = set() + formatter = Formatter() + + for _, variable_name, _, _ in formatter.parse(template_str): + if variable_name: + variables.add(variable_name) + + return variables + + +def _get_jinja2_template_vars(template_str: str) -> Set[str]: + """Get template variables from a template string.""" + from jinja2 import Environment, meta + + env = Environment() + ast = env.parse(template_str) + variables = meta.find_undeclared_variables(ast) + return variables + - async def map(self, input_value: Dict) -> str: - return self._prompt_template.format(**input_value) +def get_template_vars( + template_str: str, template_format: str = "f-string" +) -> List[str]: + """Get template variables from a template string.""" + if template_format == "f-string": + result = _get_string_template_vars(template_str) + elif template_format == "jinja2": + result = _get_jinja2_template_vars(template_str) + else: + raise ValueError(f"Unsupported template format: {template_format}") + return sorted(result) diff --git a/dbgpt/core/interface/tests/test_message.py b/dbgpt/core/interface/tests/test_message.py index 98974e4bf..7221dadb2 100755 --- a/dbgpt/core/interface/tests/test_message.py +++ b/dbgpt/core/interface/tests/test_message.py @@ -413,13 +413,18 @@ def test_to_openai_messages( {"role": "user", "content": human_model_message.content}, ] + +def test_to_openai_messages_convert_to_compatible_format( + human_model_message, ai_model_message, system_model_message +): shuffle_messages = ModelMessage.to_openai_messages( [ system_model_message, human_model_message, human_model_message, ai_model_message, - ] + ], + convert_to_compatible_format=True, ) assert shuffle_messages == [ {"role": "system", "content": system_model_message.content}, diff --git 
a/dbgpt/core/interface/tests/test_prompt.py b/dbgpt/core/interface/tests/test_prompt.py index e1a449013..3d40627c9 100644 --- a/dbgpt/core/interface/tests/test_prompt.py +++ b/dbgpt/core/interface/tests/test_prompt.py @@ -99,12 +99,6 @@ def test_format_with_response_format(self): formatted_output = prompt.format(response="hello") assert "Response: " in formatted_output - def test_from_template(self): - template_str = "Hello {name}" - prompt = PromptTemplate.from_template(template_str) - assert prompt._prompt_template.template == template_str - assert prompt._prompt_template.input_variables == [] - def test_format_missing_variable(self): template_str = "Hello {name}" prompt = PromptTemplate( diff --git a/dbgpt/core/operator/__init__.py b/dbgpt/core/operator/__init__.py index 0d0287d61..952b89143 100644 --- a/dbgpt/core/operator/__init__.py +++ b/dbgpt/core/operator/__init__.py @@ -1,31 +1,41 @@ +from dbgpt.core.interface.operator.composer_operator import ( + ChatComposerInput, + ChatHistoryPromptComposerOperator, +) from dbgpt.core.interface.operator.llm_operator import ( BaseLLM, + BaseLLMOperator, + BaseStreamingLLMOperator, LLMBranchOperator, - LLMOperator, - RequestBuildOperator, - StreamingLLMOperator, + RequestBuilderOperator, ) from dbgpt.core.interface.operator.message_operator import ( BaseConversationOperator, BufferedConversationMapperOperator, ConversationMapperOperator, - PostConversationOperator, - PostStreamingConversationOperator, - PreConversationOperator, + PreChatHistoryLoadOperator, +) +from dbgpt.core.interface.operator.prompt_operator import ( + DynamicPromptBuilderOperator, + HistoryDynamicPromptBuilderOperator, + HistoryPromptBuilderOperator, + PromptBuilderOperator, ) -from dbgpt.core.interface.prompt import PromptTemplateOperator __ALL__ = [ "BaseLLM", "LLMBranchOperator", - "LLMOperator", - "RequestBuildOperator", - "StreamingLLMOperator", + "BaseLLMOperator", + "RequestBuilderOperator", + "BaseStreamingLLMOperator", "BaseConversationOperator", "BufferedConversationMapperOperator", "ConversationMapperOperator", - "PostConversationOperator", - "PostStreamingConversationOperator", - "PreConversationOperator", - "PromptTemplateOperator", + "PreChatHistoryLoadOperator", + "PromptBuilderOperator", + "DynamicPromptBuilderOperator", + "HistoryPromptBuilderOperator", + "HistoryDynamicPromptBuilderOperator", + "ChatComposerInput", + "ChatHistoryPromptComposerOperator", ] diff --git a/dbgpt/model/__init__.py b/dbgpt/model/__init__.py index e13ec4adc..28054fe9d 100644 --- a/dbgpt/model/__init__.py +++ b/dbgpt/model/__init__.py @@ -1,13 +1,7 @@ from dbgpt.model.cluster.client import DefaultLLMClient -from dbgpt.model.utils.chatgpt_utils import ( - OpenAILLMClient, - OpenAIStreamingOperator, - MixinLLMOperator, -) +from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient __ALL__ = [ "DefaultLLMClient", "OpenAILLMClient", - "OpenAIStreamingOperator", - "MixinLLMOperator", ] diff --git a/dbgpt/model/adapter/base.py b/dbgpt/model/adapter/base.py index a456826a2..e3874393c 100644 --- a/dbgpt/model/adapter/base.py +++ b/dbgpt/model/adapter/base.py @@ -152,7 +152,7 @@ def get_default_message_separator(self) -> str: return "\n" def transform_model_messages( - self, messages: List[ModelMessage] + self, messages: List[ModelMessage], convert_to_compatible_format: bool = False ) -> List[Dict[str, str]]: """Transform the model messages @@ -174,15 +174,19 @@ def transform_model_messages( ] Args: messages (List[ModelMessage]): The model messages + convert_to_compatible_format (bool, 
optional): Whether to convert to compatible format. Defaults to False. Returns: List[Dict[str, str]]: The transformed model messages """ logger.info(f"support_system_message: {self.support_system_message}") - if not self.support_system_message: + if not self.support_system_message and convert_to_compatible_format: + # We will not do any transform in the future return self._transform_to_no_system_messages(messages) else: - return ModelMessage.to_openai_messages(messages) + return ModelMessage.to_openai_messages( + messages, convert_to_compatible_format=convert_to_compatible_format + ) def _transform_to_no_system_messages( self, messages: List[ModelMessage] @@ -237,6 +241,7 @@ def get_str_prompt( messages: List[ModelMessage], tokenizer: Any, prompt_template: str = None, + convert_to_compatible_format: bool = False, ) -> Optional[str]: """Get the string prompt from the given parameters and messages @@ -247,6 +252,7 @@ def get_str_prompt( messages (List[ModelMessage]): The model messages tokenizer (Any): The tokenizer of model, in huggingface chat model, we can create the prompt by tokenizer prompt_template (str, optional): The prompt template. Defaults to None. + convert_to_compatible_format (bool, optional): Whether to convert to compatible format. Defaults to False. Returns: Optional[str]: The string prompt @@ -262,6 +268,7 @@ def get_prompt_with_template( model_context: Dict, prompt_template: str = None, ): + convert_to_compatible_format = params.get("convert_to_compatible_format") conv: ConversationAdapter = self.get_default_conv_template( model_name, model_path ) @@ -277,6 +284,72 @@ def get_prompt_with_template( return None, None, None conv = conv.copy() + if convert_to_compatible_format: + # In old version, we will convert the messages to compatible format + conv = self._set_conv_converted_messages(conv, messages) + else: + # In new version, we will use the messages directly + conv = self._set_conv_messages(conv, messages) + + # Add a blank message for the assistant. 
+ conv.append_message(conv.roles[1], None) + new_prompt = conv.get_prompt() + return new_prompt, conv.stop_str, conv.stop_token_ids + + def _set_conv_messages( + self, conv: ConversationAdapter, messages: List[ModelMessage] + ) -> ConversationAdapter: + """Set the messages to the conversation template + + Args: + conv (ConversationAdapter): The conversation template + messages (List[ModelMessage]): The model messages + + Returns: + ConversationAdapter: The conversation template with messages + """ + system_messages = [] + for message in messages: + if isinstance(message, ModelMessage): + role = message.role + content = message.content + elif isinstance(message, dict): + role = message["role"] + content = message["content"] + else: + raise ValueError(f"Invalid message type: {message}") + + if role == ModelMessageRoleType.SYSTEM: + system_messages.append(content) + elif role == ModelMessageRoleType.HUMAN: + conv.append_message(conv.roles[0], content) + elif role == ModelMessageRoleType.AI: + conv.append_message(conv.roles[1], content) + else: + raise ValueError(f"Unknown role: {role}") + if len(system_messages) > 1: + raise ValueError( + f"Your system messages have more than one message: {system_messages}" + ) + if system_messages: + conv.set_system_message(system_messages[0]) + return conv + + def _set_conv_converted_messages( + self, conv: ConversationAdapter, messages: List[ModelMessage] + ) -> ConversationAdapter: + """Set the messages to the conversation template + + In the old version, we will convert the messages to compatible format. + This method will be deprecated in the future. + + Args: + conv (ConversationAdapter): The conversation template + messages (List[ModelMessage]): The model messages + + Returns: + ConversationAdapter: The conversation template with messages + """ system_messages = [] user_messages = [] ai_messages = [] @@ -295,10 +368,8 @@ def get_prompt_with_template( # Support for multiple system messages system_messages.append(content) elif role == ModelMessageRoleType.HUMAN: - # conv.append_message(conv.roles[0], content) user_messages.append(content) elif role == ModelMessageRoleType.AI: - # conv.append_message(conv.roles[1], content) ai_messages.append(content) else: raise ValueError(f"Unknown role: {role}") @@ -320,10 +391,7 @@ def get_prompt_with_template( # TODO join all system messages may not be a good idea conv.set_system_message("".join(can_use_systems)) - # Add a blank message for the assistant. 
- conv.append_message(conv.roles[1], None) - new_prompt = conv.get_prompt() - return new_prompt, conv.stop_str, conv.stop_token_ids + return conv def model_adaptation( self, @@ -335,6 +403,15 @@ def model_adaptation( ) -> Tuple[Dict, Dict]: """Params adaptation""" messages = params.get("messages") + convert_to_compatible_format = params.get("convert_to_compatible_format") + message_version = params.get("version", "v2").lower() + logger.info(f"Message version is {message_version}") + if convert_to_compatible_format is None: + # Support convert messages to compatible format when message version is v1 + convert_to_compatible_format = message_version == "v1" + # Save to params + params["convert_to_compatible_format"] = convert_to_compatible_format + # Some model context to dbgpt server model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False} if messages: @@ -345,7 +422,9 @@ def model_adaptation( ] params["messages"] = messages - new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template) + new_prompt = self.get_str_prompt( + params, messages, tokenizer, prompt_template, convert_to_compatible_format + ) conv_stop_str, conv_stop_token_ids = None, None if not new_prompt: ( diff --git a/dbgpt/model/adapter/hf_adapter.py b/dbgpt/model/adapter/hf_adapter.py index 3fd9eb877..204133b68 100644 --- a/dbgpt/model/adapter/hf_adapter.py +++ b/dbgpt/model/adapter/hf_adapter.py @@ -87,6 +87,7 @@ def get_str_prompt( messages: List[ModelMessage], tokenizer: Any, prompt_template: str = None, + convert_to_compatible_format: bool = False, ) -> Optional[str]: from transformers import AutoTokenizer @@ -94,7 +95,7 @@ def get_str_prompt( raise ValueError("tokenizer is is None") tokenizer: AutoTokenizer = tokenizer - messages = self.transform_model_messages(messages) + messages = self.transform_model_messages(messages, convert_to_compatible_format) logger.debug(f"The messages after transform: \n{messages}") str_prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True diff --git a/dbgpt/model/cluster/base.py b/dbgpt/model/cluster/base.py index cb3f34732..97b76fd30 100644 --- a/dbgpt/model/cluster/base.py +++ b/dbgpt/model/cluster/base.py @@ -22,6 +22,8 @@ class PromptRequest(BaseModel): span_id: str = None metrics: bool = False """Whether to return metrics of inference""" + version: str = "v2" + """Message version, default to v2""" class EmbeddingsRequest(BaseModel): diff --git a/dbgpt/model/cluster/client.py b/dbgpt/model/cluster/client.py index 10b1cfb7d..426188312 100644 --- a/dbgpt/model/cluster/client.py +++ b/dbgpt/model/cluster/client.py @@ -1,20 +1,35 @@ -from typing import AsyncIterator, List import asyncio -from dbgpt.core.interface.llm import LLMClient, ModelRequest, ModelOutput, ModelMetadata -from dbgpt.model.parameter import WorkerType +from typing import AsyncIterator, List, Optional + +from dbgpt.core.interface.llm import ( + LLMClient, + MessageConverter, + ModelMetadata, + ModelOutput, + ModelRequest, +) from dbgpt.model.cluster.manager_base import WorkerManager +from dbgpt.model.parameter import WorkerType class DefaultLLMClient(LLMClient): def __init__(self, worker_manager: WorkerManager): self._worker_manager = worker_manager - async def generate(self, request: ModelRequest) -> ModelOutput: + async def generate( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> ModelOutput: + request = await self.covert_message(request, message_converter) return await 
self._worker_manager.generate(request.to_dict()) async def generate_stream( - self, request: ModelRequest + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, ) -> AsyncIterator[ModelOutput]: + request = await self.covert_message(request, message_converter) async for output in self._worker_manager.generate_stream(request.to_dict()): yield output diff --git a/dbgpt/model/cluster/worker/default_worker.py b/dbgpt/model/cluster/worker/default_worker.py index 26f2b40ae..453071004 100644 --- a/dbgpt/model/cluster/worker/default_worker.py +++ b/dbgpt/model/cluster/worker/default_worker.py @@ -8,7 +8,12 @@ from dbgpt.configs.model_config import get_device from dbgpt.model.adapter.base import LLMModelAdapter from dbgpt.model.adapter.model_adapter import get_llm_model_adapter -from dbgpt.core import ModelOutput, ModelInferenceMetrics, ModelMetadata +from dbgpt.core import ( + ModelOutput, + ModelInferenceMetrics, + ModelMetadata, + ModelExtraMedata, +) from dbgpt.model.loader import ModelLoader, _get_model_real_path from dbgpt.model.parameter import ModelParameters from dbgpt.model.cluster.worker_base import ModelWorker @@ -196,9 +201,13 @@ async def async_count_token(self, prompt: str) -> int: raise NotImplementedError def get_model_metadata(self, params: Dict) -> ModelMetadata: + ext_metadata = ModelExtraMedata( + prompt_sep=self.llm_adapter.get_default_message_separator() + ) return ModelMetadata( model=self.model_name, context_length=self.context_len, + ext_metadata=ext_metadata, ) async def async_get_model_metadata(self, params: Dict) -> ModelMetadata: diff --git a/dbgpt/model/cluster/worker/remote_worker.py b/dbgpt/model/cluster/worker/remote_worker.py index 895d998d3..00e707768 100644 --- a/dbgpt/model/cluster/worker/remote_worker.py +++ b/dbgpt/model/cluster/worker/remote_worker.py @@ -122,7 +122,7 @@ async def async_get_model_metadata(self, params: Dict) -> ModelMetadata: json=params, timeout=self.timeout, ) - return ModelMetadata(**response.json()) + return ModelMetadata.from_dict(response.json()) def get_model_metadata(self, params: Dict) -> ModelMetadata: """Get model metadata""" diff --git a/dbgpt/model/operator/__init__.py b/dbgpt/model/operator/__init__.py index e69de29bb..b0b81d551 100644 --- a/dbgpt/model/operator/__init__.py +++ b/dbgpt/model/operator/__init__.py @@ -0,0 +1,13 @@ +from dbgpt.model.operator.llm_operator import ( + LLMOperator, + MixinLLMOperator, + StreamingLLMOperator, +) +from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator + +__ALL__ = [ + "MixinLLMOperator", + "LLMOperator", + "StreamingLLMOperator", + "OpenAIStreamingOutputOperator", +] diff --git a/dbgpt/model/operator/llm_operator.py b/dbgpt/model/operator/llm_operator.py new file mode 100644 index 000000000..c1d6ef068 --- /dev/null +++ b/dbgpt/model/operator/llm_operator.py @@ -0,0 +1,75 @@ +import logging +from abc import ABC +from typing import Optional + +from dbgpt.component import ComponentType +from dbgpt.core import LLMClient +from dbgpt.core.awel import BaseOperator +from dbgpt.core.operator import BaseLLM, BaseLLMOperator, BaseStreamingLLMOperator +from dbgpt.model.cluster import WorkerManagerFactory + +logger = logging.getLogger(__name__) + + +class MixinLLMOperator(BaseLLM, BaseOperator, ABC): + """Mixin class for LLM operator. + + This class extends BaseOperator by adding LLM capabilities. 
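    For illustration, a minimal sketch of how the concrete operators below obtain
    their client (passing ``OpenAILLMClient`` explicitly is optional and shown only
    as an example):

    .. code-block:: python

        from dbgpt.model.operator import LLMOperator
        from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient

        # The client is resolved lazily: the model serving cluster is used when it
        # is registered in the system app, otherwise the default client is used.
        llm_task = LLMOperator()

        # Or pin a client explicitly.
        llm_task_with_client = LLMOperator(OpenAILLMClient())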
+ """ + + def __init__(self, default_client: Optional[LLMClient] = None, **kwargs): + super().__init__(default_client) + self._default_llm_client = default_client + + @property + def llm_client(self) -> LLMClient: + if not self._llm_client: + worker_manager_factory: WorkerManagerFactory = ( + self.system_app.get_component( + ComponentType.WORKER_MANAGER_FACTORY, + WorkerManagerFactory, + default_component=None, + ) + ) + if worker_manager_factory: + from dbgpt.model.cluster.client import DefaultLLMClient + + self._llm_client = DefaultLLMClient(worker_manager_factory.create()) + else: + if self._default_llm_client is None: + from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient + + self._default_llm_client = OpenAILLMClient() + logger.info( + f"Can't find worker manager factory, use default llm client {self._default_llm_client}." + ) + self._llm_client = self._default_llm_client + return self._llm_client + + +class LLMOperator(MixinLLMOperator, BaseLLMOperator): + """Default LLM operator. + + Args: + llm_client (Optional[LLMClient], optional): The LLM client. Defaults to None. + If llm_client is None, we will try to connect to the model serving cluster deploy by DB-GPT, + and if we can't connect to the model serving cluster, we will use the :class:`OpenAILLMClient` as the llm_client. + """ + + def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + super().__init__(llm_client) + BaseLLMOperator.__init__(self, llm_client, **kwargs) + + +class StreamingLLMOperator(MixinLLMOperator, BaseStreamingLLMOperator): + """Default streaming LLM operator. + + Args: + llm_client (Optional[LLMClient], optional): The LLM client. Defaults to None. + If llm_client is None, we will try to connect to the model serving cluster deploy by DB-GPT, + and if we can't connect to the model serving cluster, we will use the :class:`OpenAILLMClient` as the llm_client. 
+ """ + + def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + super().__init__(llm_client) + BaseStreamingLLMOperator.__init__(self, llm_client, **kwargs) diff --git a/dbgpt/model/operator/model_operator.py b/dbgpt/model/operator/model_operator.py index 9ac82d026..061d6fdf0 100644 --- a/dbgpt/model/operator/model_operator.py +++ b/dbgpt/model/operator/model_operator.py @@ -1,17 +1,18 @@ -from typing import AsyncIterator, Dict, List, Union import logging +from typing import AsyncIterator, Dict, List, Union + +from dbgpt.component import ComponentType +from dbgpt.core import ModelOutput from dbgpt.core.awel import ( BranchFunc, - StreamifyAbsOperator, BranchOperator, MapOperator, + StreamifyAbsOperator, TransformStreamAbsOperator, ) -from dbgpt.component import ComponentType from dbgpt.core.awel.operator.base import BaseOperator -from dbgpt.core import ModelOutput from dbgpt.model.cluster import WorkerManager, WorkerManagerFactory -from dbgpt.storage.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue +from dbgpt.storage.cache import CacheManager, LLMCacheClient, LLMCacheKey, LLMCacheValue logger = logging.getLogger(__name__) diff --git a/dbgpt/model/proxy/llms/bard.py b/dbgpt/model/proxy/llms/bard.py index fc398fe8b..7e43661d8 100755 --- a/dbgpt/model/proxy/llms/bard.py +++ b/dbgpt/model/proxy/llms/bard.py @@ -13,6 +13,8 @@ def bard_generate_stream( proxy_api_key = model_params.proxy_api_key proxy_server_url = model_params.proxy_server_url + convert_to_compatible_format = params.get("convert_to_compatible_format", False) + history = [] messages: List[ModelMessage] = params["messages"] for message in messages: @@ -25,14 +27,15 @@ def bard_generate_stream( else: pass - last_user_input_index = None - for i in range(len(history) - 1, -1, -1): - if history[i]["role"] == "user": - last_user_input_index = i - break - if last_user_input_index: - last_user_input = history.pop(last_user_input_index) - history.append(last_user_input) + if convert_to_compatible_format: + last_user_input_index = None + for i in range(len(history) - 1, -1, -1): + if history[i]["role"] == "user": + last_user_input_index = i + break + if last_user_input_index: + last_user_input = history.pop(last_user_input_index) + history.append(last_user_input) msgs = [] for msg in history: diff --git a/dbgpt/model/proxy/llms/chatgpt.py b/dbgpt/model/proxy/llms/chatgpt.py index ece9c780c..5d1882141 100755 --- a/dbgpt/model/proxy/llms/chatgpt.py +++ b/dbgpt/model/proxy/llms/chatgpt.py @@ -128,7 +128,10 @@ def _build_request(model: ProxyModel, params): messages: List[ModelMessage] = params["messages"] # history = __convert_2_gpt_messages(messages) - history = ModelMessage.to_openai_messages(messages) + convert_to_compatible_format = params.get("convert_to_compatible_format", False) + history = ModelMessage.to_openai_messages( + messages, convert_to_compatible_format=convert_to_compatible_format + ) payloads = { "temperature": params.get("temperature"), "max_tokens": params.get("max_new_tokens"), diff --git a/dbgpt/model/proxy/llms/gemini.py b/dbgpt/model/proxy/llms/gemini.py index 068c19cfb..04975817c 100644 --- a/dbgpt/model/proxy/llms/gemini.py +++ b/dbgpt/model/proxy/llms/gemini.py @@ -12,7 +12,6 @@ def gemini_generate_stream( """Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview""" model_params = model.get_params() print(f"Model: {model}, model_params: {model_params}") - global history # TODO proxy model use unified config? 
proxy_api_key = model_params.proxy_api_key diff --git a/dbgpt/model/proxy/llms/spark.py b/dbgpt/model/proxy/llms/spark.py index 0b39c6cf8..aedf8a951 100644 --- a/dbgpt/model/proxy/llms/spark.py +++ b/dbgpt/model/proxy/llms/spark.py @@ -56,6 +56,9 @@ def spark_generate_stream( del messages[index] break + # TODO: Support convert_to_compatible_format config + convert_to_compatible_format = params.get("convert_to_compatible_format", False) + history = [] # Add history conversation for message in messages: diff --git a/dbgpt/model/proxy/llms/tongyi.py b/dbgpt/model/proxy/llms/tongyi.py index 902b26c4a..bbcec2f42 100644 --- a/dbgpt/model/proxy/llms/tongyi.py +++ b/dbgpt/model/proxy/llms/tongyi.py @@ -53,8 +53,12 @@ def tongyi_generate_stream( proxyllm_backend = Generation.Models.qwen_turbo # By Default qwen_turbo messages: List[ModelMessage] = params["messages"] + convert_to_compatible_format = params.get("convert_to_compatible_format", False) - history = __convert_2_tongyi_messages(messages) + if convert_to_compatible_format: + history = __convert_2_tongyi_messages(messages) + else: + history = ModelMessage.to_openai_messages(messages) gen = Generation() res = gen.call( proxyllm_backend, diff --git a/dbgpt/model/proxy/llms/wenxin.py b/dbgpt/model/proxy/llms/wenxin.py index 9d31fac60..28e5a65d0 100644 --- a/dbgpt/model/proxy/llms/wenxin.py +++ b/dbgpt/model/proxy/llms/wenxin.py @@ -25,8 +25,29 @@ def _build_access_token(api_key: str, secret_key: str) -> str: return res.json().get("access_token") +def _to_wenxin_messages(messages: List[ModelMessage]): + """Convert messages to wenxin compatible format + + See https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11 + """ + wenxin_messages = [] + system_messages = [] + for message in messages: + if message.role == ModelMessageRoleType.HUMAN: + wenxin_messages.append({"role": "user", "content": message.content}) + elif message.role == ModelMessageRoleType.SYSTEM: + system_messages.append(message.content) + elif message.role == ModelMessageRoleType.AI: + wenxin_messages.append({"role": "assistant", "content": message.content}) + else: + pass + if len(system_messages) > 1: + raise ValueError("Wenxin only support one system message") + str_system_message = system_messages[0] if len(system_messages) > 0 else "" + return wenxin_messages, str_system_message + + def __convert_2_wenxin_messages(messages: List[ModelMessage]): - chat_round = 0 wenxin_messages = [] last_usr_message = "" @@ -57,7 +78,8 @@ def __convert_2_wenxin_messages(messages: List[ModelMessage]): last_message = messages[-1] end_message = last_message.content wenxin_messages.append({"role": "user", "content": end_message}) - return wenxin_messages, system_messages + str_system_message = system_messages[0] if len(system_messages) > 0 else "" + return wenxin_messages, str_system_message def wenxin_generate_stream( @@ -87,13 +109,14 @@ def wenxin_generate_stream( messages: List[ModelMessage] = params["messages"] - history, systems = __convert_2_wenxin_messages(messages) - system = "" - if systems and len(systems) > 0: - system = systems[0] + convert_to_compatible_format = params.get("convert_to_compatible_format", False) + if convert_to_compatible_format: + history, system_message = __convert_2_wenxin_messages(messages) + else: + history, system_message = _to_wenxin_messages(messages) payload = { "messages": history, - "system": system, + "system": system_message, "temperature": params.get("temperature"), "stream": True, } diff --git a/dbgpt/model/proxy/llms/zhipu.py 
b/dbgpt/model/proxy/llms/zhipu.py index 90f1d3d2b..2f854e17a 100644 --- a/dbgpt/model/proxy/llms/zhipu.py +++ b/dbgpt/model/proxy/llms/zhipu.py @@ -57,6 +57,10 @@ def zhipu_generate_stream( zhipuai.api_key = proxy_api_key messages: List[ModelMessage] = params["messages"] + + # TODO: Support convert_to_compatible_format config, zhipu not support system message + convert_to_compatible_format = params.get("convert_to_compatible_format", False) + history, systems = __convert_2_zhipu_messages(messages) res = zhipuai.model_api.sse_invoke( model=proxyllm_backend, diff --git a/dbgpt/model/utils/chatgpt_utils.py b/dbgpt/model/utils/chatgpt_utils.py index 4e9cfc353..7d1344aec 100644 --- a/dbgpt/model/utils/chatgpt_utils.py +++ b/dbgpt/model/utils/chatgpt_utils.py @@ -20,8 +20,13 @@ from dbgpt.component import ComponentType from dbgpt.core.operator import BaseLLM from dbgpt.core.awel import TransformStreamAbsOperator, BaseOperator -from dbgpt.core.interface.llm import ModelMetadata, LLMClient -from dbgpt.core.interface.llm import ModelOutput, ModelRequest +from dbgpt.core.interface.llm import ( + ModelOutput, + ModelRequest, + ModelMetadata, + LLMClient, + MessageConverter, +) from dbgpt.model.cluster.client import DefaultLLMClient from dbgpt.model.cluster import WorkerManagerFactory from dbgpt._private.pydantic import model_to_json @@ -175,7 +180,13 @@ def _build_request( payload["max_tokens"] = request.max_new_tokens return payload - async def generate(self, request: ModelRequest) -> ModelOutput: + async def generate( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> ModelOutput: + request = await self.covert_message(request, message_converter) + messages = request.to_openai_messages() payload = self._build_request(request) logger.info( @@ -195,8 +206,11 @@ async def generate(self, request: ModelRequest) -> ModelOutput: ) async def generate_stream( - self, request: ModelRequest + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, ) -> AsyncIterator[ModelOutput]: + request = await self.covert_message(request, message_converter) messages = request.to_openai_messages() payload = self._build_request(request, True) logger.info( @@ -247,7 +261,7 @@ async def count_token(self, model: str, prompt: str) -> int: return self._tokenizer.count_token(prompt, model) -class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]): +class OpenAIStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]): """Transform ModelOutput to openai stream format.""" async def transform_stream( @@ -266,40 +280,6 @@ async def model_caller() -> str: yield output -class MixinLLMOperator(BaseLLM, BaseOperator, ABC): - """Mixin class for LLM operator. - - This class extends BaseOperator by adding LLM capabilities. 
- """ - - def __init__(self, default_client: Optional[LLMClient] = None, **kwargs): - super().__init__(default_client) - self._default_llm_client = default_client - - @property - def llm_client(self) -> LLMClient: - if not self._llm_client: - worker_manager_factory: WorkerManagerFactory = ( - self.system_app.get_component( - ComponentType.WORKER_MANAGER_FACTORY, - WorkerManagerFactory, - default_component=None, - ) - ) - if worker_manager_factory: - self._llm_client = DefaultLLMClient(worker_manager_factory.create()) - else: - if self._default_llm_client is None: - from dbgpt.model import OpenAILLMClient - - self._default_llm_client = OpenAILLMClient() - logger.info( - f"Can't find worker manager factory, use default llm client {self._default_llm_client}." - ) - self._llm_client = self._default_llm_client - return self._llm_client - - async def _to_openai_stream( output_iter: AsyncIterator[ModelOutput], model: Optional[str] = None, diff --git a/dbgpt/serve/conversation/operator.py b/dbgpt/serve/conversation/operator.py new file mode 100644 index 000000000..ad289815c --- /dev/null +++ b/dbgpt/serve/conversation/operator.py @@ -0,0 +1,71 @@ +import logging +from typing import Any, Optional + +from dbgpt.core import ( + InMemoryStorage, + MessageStorageItem, + StorageConversation, + StorageInterface, +) +from dbgpt.core.operator import PreChatHistoryLoadOperator + +from .serve import Serve + +logger = logging.getLogger(__name__) + + +class ServePreChatHistoryLoadOperator(PreChatHistoryLoadOperator): + """Pre-chat history load operator for DB-GPT serve component + + Args: + storage (Optional[StorageInterface[StorageConversation, Any]], optional): + The conversation storage, store the conversation items. Defaults to None. + message_storage (Optional[StorageInterface[MessageStorageItem, Any]], optional): + The message storage, store the messages of one conversation. Defaults to None. + + If the storage or message_storage is not None, the storage or message_storage will be used first. + Otherwise, we will try get current serve component from system app, + and use the storage or message_storage of the serve component. + If we can't get the storage, we will use the InMemoryStorage as the storage or message_storage. + """ + + def __init__( + self, + storage: Optional[StorageInterface[StorageConversation, Any]] = None, + message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None, + **kwargs, + ): + super().__init__(storage, message_storage, **kwargs) + + @property + def storage(self): + if self._storage: + return self._storage + storage = Serve.call_on_current_serve( + self.system_app, lambda serve: serve.conv_storage + ) + if not storage: + logger.warning( + "Can't get the conversation storage from current serve component, " + "use the InMemoryStorage as the conversation storage." + ) + self._storage = InMemoryStorage() + return self._storage + return storage + + @property + def message_storage(self): + if self._message_storage: + return self._message_storage + storage = Serve.call_on_current_serve( + self.system_app, + lambda serve: serve.message_storage, + ) + if not storage: + logger.warning( + "Can't get the message storage from current serve component, " + "use the InMemoryStorage as the message storage." 
+ ) + self._message_storage = InMemoryStorage() + return self._message_storage + return storage diff --git a/dbgpt/serve/core/serve.py b/dbgpt/serve/core/serve.py index b2bcfdb1f..e909ad4cb 100644 --- a/dbgpt/serve/core/serve.py +++ b/dbgpt/serve/core/serve.py @@ -1,6 +1,6 @@ import logging from abc import ABC -from typing import List, Optional, Union +from typing import Any, Callable, List, Optional, Union from sqlalchemy import URL @@ -60,3 +60,44 @@ def create_or_get_db_manager(self) -> DatabaseManager: finally: self._not_create_table = False return init_db + + @classmethod + def get_current_serve(cls, system_app: SystemApp) -> Optional["BaseServe"]: + """Get the current serve component. + + None if the serve component is not exist. + + Args: + system_app (SystemApp): The system app + + Returns: + Optional[BaseServe]: The current serve component. + """ + return system_app.get_component(cls.name, cls, default_component=None) + + @classmethod + def call_on_current_serve( + cls, + system_app: SystemApp, + func: Callable[["BaseServe"], Optional[Any]], + default_value: Optional[Any] = None, + ) -> Optional[Any]: + """Call the function on the current serve component. + + Return default_value if the serve component is not exist or the function return None. + + Args: + system_app (SystemApp): The system app + func (Callable[[BaseServe], Any]): The function to call + default_value (Optional[Any], optional): The default value. Defaults to None. + + Returns: + Optional[Any]: The result of the function + """ + serve = cls.get_current_serve(system_app) + if not serve: + return default_value + result = func(serve) + if not result: + result = default_value + return result diff --git a/dbgpt/util/function_utils.py b/dbgpt/util/function_utils.py new file mode 100644 index 000000000..5bfd578ea --- /dev/null +++ b/dbgpt/util/function_utils.py @@ -0,0 +1,87 @@ +from typing import Any, get_type_hints, get_origin, get_args +from functools import wraps +import inspect +import asyncio + + +def _is_instance_of_generic_type(obj, generic_type): + """Check if an object is an instance of a generic type.""" + if generic_type is Any: + return True # Any type is compatible with any object + + origin = get_origin(generic_type) + if origin is None: + return isinstance(obj, generic_type) # Handle non-generic types + + args = get_args(generic_type) + if not args: + return isinstance(obj, origin) + + # Check if object matches the generic origin (like list, dict) + if not isinstance(obj, origin): + return False + + # For each item in the object, check if it matches the corresponding type argument + for sub_obj, arg in zip(obj, args): + # Skip check if the type argument is Any + if arg is not Any and not isinstance(sub_obj, arg): + return False + + return True + + +def _sort_args(func, args, kwargs): + sig = inspect.signature(func) + type_hints = get_type_hints(func) + + arg_types = [ + type_hints[param_name] + for param_name in sig.parameters + if param_name != "return" and param_name != "self" + ] + + if "self" in sig.parameters: + self_arg = [args[0]] + other_args = args[1:] + else: + self_arg = [] + other_args = args + + sorted_args = sorted( + other_args, + key=lambda x: next( + i for i, t in enumerate(arg_types) if _is_instance_of_generic_type(x, t) + ), + ) + return (*self_arg, *sorted_args), kwargs + + +def rearrange_args_by_type(func): + """Decorator to rearrange the arguments of a function by type. + + Examples: + + .. 
code-block:: python + + from dbgpt.util.function_utils import rearrange_args_by_type + + @rearrange_args_by_type + def sync_regular_function(a: int, b: str, c: float): + return a, b, c + + assert instance.sync_class_method(1, "b", 3.0) == (1, "b", 3.0) + assert instance.sync_class_method("b", 3.0, 1) == (1, "b", 3.0) + + """ + + @wraps(func) + def sync_wrapper(*args, **kwargs): + sorted_args, sorted_kwargs = _sort_args(func, args, kwargs) + return func(*sorted_args, **sorted_kwargs) + + @wraps(func) + async def async_wrapper(*args, **kwargs): + sorted_args, sorted_kwargs = _sort_args(func, args, kwargs) + return await func(*sorted_args, **sorted_kwargs) + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper diff --git a/dbgpt/util/prompt_util.py b/dbgpt/util/prompt_util.py index 17d994d32..e0c0a3846 100644 --- a/dbgpt/util/prompt_util.py +++ b/dbgpt/util/prompt_util.py @@ -10,11 +10,12 @@ import logging from string import Formatter -from typing import Callable, List, Optional, Sequence +from typing import Callable, List, Optional, Sequence, Set from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel from dbgpt.util.global_helper import globals_helper +from dbgpt.core.interface.prompt import get_template_vars from dbgpt._private.llm_metadata import LLMMetadata from dbgpt.rag.text_splitter.token_splitter import TokenTextSplitter @@ -230,15 +231,3 @@ def get_empty_prompt_txt(template: str) -> str: all_kwargs = {**partial_kargs, **empty_kwargs} prompt = template.format(**all_kwargs) return prompt - - -def get_template_vars(template_str: str) -> List[str]: - """Get template variables from a template string.""" - variables = [] - formatter = Formatter() - - for _, variable_name, _, _ in formatter.parse(template_str): - if variable_name: - variables.append(variable_name) - - return variables diff --git a/dbgpt/util/tests/test_function_utils.py b/dbgpt/util/tests/test_function_utils.py new file mode 100644 index 000000000..f245b8c14 --- /dev/null +++ b/dbgpt/util/tests/test_function_utils.py @@ -0,0 +1,120 @@ +from typing import List, Dict, Any + +import pytest +from dbgpt.util.function_utils import rearrange_args_by_type + + +class ChatPromptTemplate: + pass + + +class BaseMessage: + pass + + +class ModelMessage: + pass + + +class DummyClass: + @rearrange_args_by_type + async def class_method(self, a: int, b: str, c: float): + return a, b, c + + @rearrange_args_by_type + async def merge_history( + self, + prompt: ChatPromptTemplate, + history: List[BaseMessage], + prompt_dict: Dict[str, Any], + ) -> List[ModelMessage]: + return [type(prompt), type(history), type(prompt_dict)] + + @rearrange_args_by_type + def sync_class_method(self, a: int, b: str, c: float): + return a, b, c + + +@rearrange_args_by_type +def sync_regular_function(a: int, b: str, c: float): + return a, b, c + + +@rearrange_args_by_type +async def regular_function(a: int, b: str, c: float): + return a, b, c + + +@pytest.mark.asyncio +async def test_class_method_correct_order(): + instance = DummyClass() + result = await instance.class_method(1, "b", 3.0) + assert result == (1, "b", 3.0), "Class method failed with correct order" + + +@pytest.mark.asyncio +async def test_class_method_incorrect_order(): + instance = DummyClass() + result = await instance.class_method("b", 3.0, 1) + assert result == (1, "b", 3.0), "Class method failed with incorrect order" + + +@pytest.mark.asyncio +async def test_regular_function_correct_order(): + result = await regular_function(1, "b", 3.0) + assert result == (1, 
"b", 3.0), "Regular function failed with correct order" + + +@pytest.mark.asyncio +async def test_regular_function_incorrect_order(): + result = await regular_function("b", 3.0, 1) + assert result == (1, "b", 3.0), "Regular function failed with incorrect order" + + +@pytest.mark.asyncio +async def test_merge_history_correct_order(): + instance = DummyClass() + result = await instance.merge_history( + ChatPromptTemplate(), [BaseMessage()], {"key": "value"} + ) + assert result == [ChatPromptTemplate, list, dict], "Failed with correct order" + + +@pytest.mark.asyncio +async def test_merge_history_incorrect_order_1(): + instance = DummyClass() + result = await instance.merge_history( + [BaseMessage()], ChatPromptTemplate(), {"key": "value"} + ) + assert result == [ChatPromptTemplate, list, dict], "Failed with incorrect order 1" + + +@pytest.mark.asyncio +async def test_merge_history_incorrect_order_2(): + instance = DummyClass() + result = await instance.merge_history( + {"key": "value"}, [BaseMessage()], ChatPromptTemplate() + ) + assert result == [ChatPromptTemplate, list, dict], "Failed with incorrect order 2" + + +def test_sync_class_method_correct_order(): + instance = DummyClass() + result = instance.sync_class_method(1, "b", 3.0) + assert result == (1, "b", 3.0), "Sync class method failed with correct order" + + +def test_sync_class_method_incorrect_order(): + instance = DummyClass() + result = instance.sync_class_method("b", 3.0, 1) + assert result == (1, "b", 3.0), "Sync class method failed with incorrect order" + + +def test_sync_regular_function_correct_order(): + result = sync_regular_function(1, "b", 3.0) + assert result == (1, "b", 3.0), "Sync regular function failed with correct order" + + +def test_sync_regular_function_incorrect_order(): + result = sync_regular_function("b", 3.0, 1) + assert result == (1, "b", 3.0), "Sync regular function failed with incorrect order" diff --git a/examples/awel/data_analyst_assistant.py b/examples/awel/data_analyst_assistant.py index f40a5be9e..7997a60e0 100644 --- a/examples/awel/data_analyst_assistant.py +++ b/examples/awel/data_analyst_assistant.py @@ -26,7 +26,7 @@ curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/data_analyst/copilot \ -H "Content-Type: application/json" -d '{ "command": "dbgpt_awel_data_analyst_code_fix", - "model": "gpt-3.5-turbo", + "model": "'"$MODEL"'", "stream": false, "context": { "conv_uid": "uuid_conv_copilot_1234", @@ -37,43 +37,55 @@ """ import logging +import os from functools import cache from typing import Any, Dict, List, Optional from dbgpt._private.pydantic import BaseModel, Field from dbgpt.core import ( - InMemoryStorage, - LLMClient, - MessageStorageItem, + ChatPromptTemplate, + HumanPromptTemplate, + MessagesPlaceholder, ModelMessage, - ModelMessageRoleType, + ModelRequest, + ModelRequestContext, PromptManager, PromptTemplate, - StorageConversation, - StorageInterface, + SystemPromptTemplate, ) from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator from dbgpt.core.operator import ( BufferedConversationMapperOperator, + HistoryDynamicPromptBuilderOperator, LLMBranchOperator, + RequestBuilderOperator, +) +from dbgpt.model.operator import ( LLMOperator, - PostConversationOperator, - PostStreamingConversationOperator, - PreConversationOperator, - RequestBuildOperator, + OpenAIStreamingOutputOperator, StreamingLLMOperator, ) -from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator -from dbgpt.util.utils import colored +from dbgpt.serve.conversation.operator import 
ServePreChatHistoryLoadOperator logger = logging.getLogger(__name__) +PROMPT_LANG_ZH = "zh" +PROMPT_LANG_EN = "en" + +CODE_DEFAULT = "dbgpt_awel_data_analyst_code_default" CODE_FIX = "dbgpt_awel_data_analyst_code_fix" CODE_PERF = "dbgpt_awel_data_analyst_code_perf" CODE_EXPLAIN = "dbgpt_awel_data_analyst_code_explain" CODE_COMMENT = "dbgpt_awel_data_analyst_code_comment" CODE_TRANSLATE = "dbgpt_awel_data_analyst_code_translate" +CODE_DEFAULT_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师。 +你可以根据最佳实践来优化代码, 也可以对代码进行修复, 解释, 添加注释, 以及将代码翻译成其他语言。""" +CODE_DEFAULT_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst. +You can optimize the code according to best practices, or fix, explain, add comments to the code, +and you can also translate the code into other languages. +""" + CODE_FIX_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师, 这里有一段 {language} 代码。请按照最佳实践检查代码,找出并修复所有错误。请给出修复后的代码,并且提供对您所做的每一行更正的逐行解释,请使用和用户相同的语言进行回答。""" CODE_FIX_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst, @@ -126,7 +138,9 @@ class ReqContext(BaseModel): class TriggerReqBody(BaseModel): messages: str = Field(..., description="User input messages") - command: Optional[str] = Field(default="fix", description="Command name") + command: Optional[str] = Field( + default=None, description="Command name, None if common chat" + ) model: Optional[str] = Field(default="gpt-3.5-turbo", description="Model name") stream: Optional[bool] = Field(default=False, description="Whether return stream") language: Optional[str] = Field(default="hive", description="Language") @@ -140,109 +154,89 @@ class TriggerReqBody(BaseModel): @cache def load_or_save_prompt_template(pm: PromptManager): - ext_params = { + zh_ext_params = { + "chat_scene": "chat_with_code", + "sub_chat_scene": "data_analyst", + "prompt_type": "common", + "prompt_language": PROMPT_LANG_ZH, + } + en_ext_params = { "chat_scene": "chat_with_code", "sub_chat_scene": "data_analyst", "prompt_type": "common", + "prompt_language": PROMPT_LANG_EN, } + + pm.query_or_save( + PromptTemplate.from_template(CODE_DEFAULT_TEMPLATE_ZH), + prompt_name=CODE_DEFAULT, + **zh_ext_params, + ) pm.query_or_save( - PromptTemplate( - input_variables=["language"], - template=CODE_FIX_TEMPLATE_ZH, - ), + PromptTemplate.from_template(CODE_DEFAULT_TEMPLATE_EN), + prompt_name=CODE_DEFAULT, + **en_ext_params, + ) + pm.query_or_save( + PromptTemplate.from_template(CODE_FIX_TEMPLATE_ZH), prompt_name=CODE_FIX, - prompt_language="zh", - **ext_params, + **zh_ext_params, ) pm.query_or_save( - PromptTemplate( - input_variables=["language"], - template=CODE_FIX_TEMPLATE_EN, - ), + PromptTemplate.from_template(CODE_FIX_TEMPLATE_EN), prompt_name=CODE_FIX, - prompt_language="en", - **ext_params, + **en_ext_params, ) pm.query_or_save( - PromptTemplate( - input_variables=["language"], - template=CODE_PERF_TEMPLATE_ZH, - ), + PromptTemplate.from_template(CODE_PERF_TEMPLATE_ZH), prompt_name=CODE_PERF, - prompt_language="zh", - **ext_params, + **zh_ext_params, ) pm.query_or_save( - PromptTemplate( - input_variables=["language"], - template=CODE_PERF_TEMPLATE_EN, - ), + PromptTemplate.from_template(CODE_PERF_TEMPLATE_EN), prompt_name=CODE_PERF, - prompt_language="en", - **ext_params, + **en_ext_params, ) pm.query_or_save( - PromptTemplate( - input_variables=["language"], - template=CODE_EXPLAIN_TEMPLATE_ZH, - ), + PromptTemplate.from_template(CODE_EXPLAIN_TEMPLATE_ZH), prompt_name=CODE_EXPLAIN, - prompt_language="zh", - **ext_params, + **zh_ext_params, ) pm.query_or_save( - 
PromptTemplate( - input_variables=["language"], - template=CODE_EXPLAIN_TEMPLATE_EN, - ), + PromptTemplate.from_template(CODE_EXPLAIN_TEMPLATE_EN), prompt_name=CODE_EXPLAIN, - prompt_language="en", - **ext_params, + **en_ext_params, ) pm.query_or_save( - PromptTemplate( - input_variables=["language"], - template=CODE_COMMENT_TEMPLATE_ZH, - ), + PromptTemplate.from_template(CODE_COMMENT_TEMPLATE_ZH), prompt_name=CODE_COMMENT, - prompt_language="zh", - **ext_params, + **zh_ext_params, ) pm.query_or_save( - PromptTemplate( - input_variables=["language"], - template=CODE_COMMENT_TEMPLATE_EN, - ), + PromptTemplate.from_template(CODE_COMMENT_TEMPLATE_EN), prompt_name=CODE_COMMENT, - prompt_language="en", - **ext_params, + **en_ext_params, ) pm.query_or_save( - PromptTemplate( - input_variables=["source_language", "target_language"], - template=CODE_TRANSLATE_TEMPLATE_ZH, - ), + PromptTemplate.from_template(CODE_TRANSLATE_TEMPLATE_ZH), prompt_name=CODE_TRANSLATE, - prompt_language="zh", - **ext_params, + **zh_ext_params, ) pm.query_or_save( - PromptTemplate( - input_variables=["source_language", "target_language"], - template=CODE_TRANSLATE_TEMPLATE_EN, - ), + PromptTemplate.from_template(CODE_TRANSLATE_TEMPLATE_EN), prompt_name=CODE_TRANSLATE, - prompt_language="en", - **ext_params, + **en_ext_params, ) -class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]): +class PromptTemplateBuilderOperator(MapOperator[TriggerReqBody, ChatPromptTemplate]): + """Build prompt template for chat with code.""" + def __init__(self, **kwargs): super().__init__(**kwargs) self._default_prompt_manager = PromptManager() - async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]: + async def map(self, input_value: TriggerReqBody) -> ChatPromptTemplate: from dbgpt.serve.prompt.serve import SERVE_APP_NAME as PROMPT_SERVE_APP_NAME from dbgpt.serve.prompt.serve import Serve as PromptServe @@ -256,7 +250,24 @@ async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]: load_or_save_prompt_template(pm) user_language = self.system_app.config.get_current_lang(default="en") - + if not input_value.command: + # No command, just chat, not include system prompt. 
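+            # That is, no command-specific system prompt is used here: we fall back
+            # to the generic CODE_DEFAULT prompt for the user's preferred language,
+            # combined with the "chat_history" placeholder and the raw "{user_input}".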
+ default_prompt_list = pm.prefer_query( + CODE_DEFAULT, prefer_prompt_language=user_language + ) + default_prompt_template = ( + default_prompt_list[0].to_prompt_template().template + ) + prompt = ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template(default_prompt_template), + MessagesPlaceholder(variable_name="chat_history"), + HumanPromptTemplate.from_template("{user_input}"), + ] + ) + return prompt + + # Query prompt template from prompt manager by command name prompt_list = pm.prefer_query( input_value.command, prefer_prompt_language=user_language ) @@ -264,109 +275,38 @@ async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]: error_msg = f"Prompt not found for command {input_value.command}, user_language: {user_language}" logger.error(error_msg) raise ValueError(error_msg) - prompt = prompt_list[0].to_prompt_template() - if input_value.command == CODE_TRANSLATE: - format_params = { - "source_language": input_value.language, - "target_language": input_value.target_language, - } - else: - format_params = {"language": input_value.language} - - system_message = prompt.format(**format_params) - messages = [ - ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_message), - ModelMessage(role=ModelMessageRoleType.HUMAN, content=input_value.messages), - ] - context = input_value.context.dict() if input_value.context else {} - return { - "messages": messages, - "stream": input_value.stream, - "model": input_value.model, - "context": context, - } - - -class MyConversationOperator(PreConversationOperator): - def __init__( - self, - storage: Optional[StorageInterface[StorageConversation, Any]] = None, - message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None, - **kwargs, - ): - super().__init__(storage, message_storage, **kwargs) - - def _get_conversion_serve(self): - from dbgpt.serve.conversation.serve import ( - SERVE_APP_NAME as CONVERSATION_SERVE_APP_NAME, + prompt_template = prompt_list[0].to_prompt_template() + + return ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template(prompt_template.template), + MessagesPlaceholder(variable_name="chat_history"), + HumanPromptTemplate.from_template("{user_input}"), + ] ) - from dbgpt.serve.conversation.serve import Serve as ConversationServe - conversation_serve: ConversationServe = self.system_app.get_component( - CONVERSATION_SERVE_APP_NAME, ConversationServe, default_component=None - ) - return conversation_serve - - @property - def storage(self): - if self._storage: - return self._storage - conversation_serve = self._get_conversion_serve() - if conversation_serve: - return conversation_serve.conv_storage - else: - logger.info("Conversation storage not found, use InMemoryStorage default") - self._storage = InMemoryStorage() - return self._storage - - @property - def message_storage(self): - if self._message_storage: - return self._message_storage - conversation_serve = self._get_conversion_serve() - if conversation_serve: - return conversation_serve.message_storage - else: - logger.info("Message storage not found, use InMemoryStorage default") - self._message_storage = InMemoryStorage() - return self._message_storage - - -class MyLLMOperator(MixinLLMOperator, LLMOperator): - def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): - super().__init__(llm_client) - LLMOperator.__init__(self, llm_client, **kwargs) - - -class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator): - def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): - 
super().__init__(llm_client) - StreamingLLMOperator.__init__(self, llm_client, **kwargs) - - -def history_message_mapper( - messages_by_round: List[List[ModelMessage]], -) -> List[ModelMessage]: - """Mapper for history conversation. - - If there are multi system messages, just keep the first system message. - """ - has_system_message = False - mapper_messages = [] - for messages in messages_by_round: - for message in messages: - if message.role == ModelMessageRoleType.SYSTEM: - if has_system_message: - continue - else: - mapper_messages.append(message) - has_system_message = True - else: - mapper_messages.append(message) - print("history_message_mapper start:" + "=" * 70) - print(colored(ModelMessage.get_printable_message(mapper_messages), "green")) - print("history_message_mapper end:" + "=" * 72) - return mapper_messages + +def parse_prompt_args(req: TriggerReqBody) -> Dict[str, Any]: + prompt_args = {"user_input": req.messages} + if not req.command: + return prompt_args + if req.command == CODE_TRANSLATE: + prompt_args["source_language"] = req.language + prompt_args["target_language"] = req.target_language + else: + prompt_args["language"] = req.language + return prompt_args + + +async def build_model_request( + messages: List[ModelMessage], req_body: TriggerReqBody +) -> ModelRequest: + return ModelRequest.build_request( + model=req_body.model, + messages=messages, + context=req_body.context, + stream=req_body.stream, + ) with DAG("dbgpt_awel_data_analyst_assistant") as dag: @@ -377,57 +317,59 @@ def history_message_mapper( streaming_predict_func=lambda x: x.stream, ) - copilot_task = CopilotOperator() - request_handle_task = RequestBuildOperator() + prompt_template_load_task = PromptTemplateBuilderOperator() + request_handle_task = RequestBuilderOperator() - # Pre-process conversation - pre_conversation_task = MyConversationOperator() - # Keep last k round conversation. - history_conversation_task = BufferedConversationMapperOperator( - last_k_round=5, message_mapper=history_message_mapper + # Load and store chat history + chat_history_load_task = ServePreChatHistoryLoadOperator() + last_k_round = int(os.getenv("DBGPT_AWEL_DATA_ANALYST_LAST_K_ROUND", 5)) + # History transform task, here we keep last k round messages + history_transform_task = BufferedConversationMapperOperator( + last_k_round=last_k_round + ) + history_prompt_build_task = HistoryDynamicPromptBuilderOperator( + history_key="chat_history" ) - # Save conversation to storage. - post_conversation_task = PostConversationOperator() - # Save streaming conversation to storage. - post_streaming_conversation_task = PostStreamingConversationOperator() + model_request_build_task = JoinOperator(build_model_request) - # Use LLMOperator to generate response. - llm_task = MyLLMOperator(task_name="llm_task") - streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task") + # Use BaseLLMOperator to generate response. 
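+    # LLMOperator handles the non-streaming request and StreamingLLMOperator the
+    # streaming one; LLMBranchOperator below routes each request to exactly one of
+    # them based on req.stream, and result_join_task returns whichever branch
+    # actually produced output.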
+ llm_task = LLMOperator(task_name="llm_task") + streaming_llm_task = StreamingLLMOperator(task_name="streaming_llm_task") branch_task = LLMBranchOperator( stream_task_name="streaming_llm_task", no_stream_task_name="llm_task" ) model_parse_task = MapOperator(lambda out: out.to_dict()) - openai_format_stream_task = OpenAIStreamingOperator() + openai_format_stream_task = OpenAIStreamingOutputOperator() result_join_task = JoinOperator( combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out ) + trigger >> prompt_template_load_task >> history_prompt_build_task ( trigger - >> copilot_task - >> request_handle_task - >> pre_conversation_task - >> history_conversation_task - >> branch_task + >> MapOperator( + lambda req: ModelRequestContext( + conv_uid=req.context.conv_uid, + stream=req.stream, + chat_mode=req.context.chat_mode, + ) + ) + >> chat_history_load_task + >> history_transform_task + >> history_prompt_build_task ) + + trigger >> MapOperator(parse_prompt_args) >> history_prompt_build_task + + history_prompt_build_task >> model_request_build_task + trigger >> model_request_build_task + + model_request_build_task >> branch_task # The branch of no streaming response. - ( - branch_task - >> llm_task - >> post_conversation_task - >> model_parse_task - >> result_join_task - ) + (branch_task >> llm_task >> model_parse_task >> result_join_task) # The branch of streaming response. - ( - branch_task - >> streaming_llm_task - >> post_streaming_conversation_task - >> openai_format_stream_task - >> result_join_task - ) + (branch_task >> streaming_llm_task >> openai_format_stream_task >> result_join_task) if __name__ == "__main__": if dag.leaf_nodes[0].dev_mode: diff --git a/examples/awel/simple_chat_history_example.py b/examples/awel/simple_chat_history_example.py index d1121aa09..c1977117c 100644 --- a/examples/awel/simple_chat_history_example.py +++ b/examples/awel/simple_chat_history_example.py @@ -12,7 +12,7 @@ # Fist round curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \ -H "Content-Type: application/json" -d '{ - "model": "gpt-3.5-turbo", + "model": "'"$MODEL"'", "context": { "conv_uid": "uuid_conv_1234" }, @@ -22,7 +22,7 @@ # Second round curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \ -H "Content-Type: application/json" -d '{ - "model": "gpt-3.5-turbo", + "model": "'"$MODEL"'", "context": { "conv_uid": "uuid_conv_1234" }, @@ -34,7 +34,7 @@ curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \ -H "Content-Type: application/json" -d '{ - "model": "gpt-3.5-turbo", + "model": "'"$MODEL"'", "context": { "conv_uid": "uuid_conv_stream_1234" }, @@ -45,7 +45,7 @@ # Second round curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \ -H "Content-Type: application/json" -d '{ - "model": "gpt-3.5-turbo", + "model": "'"$MODEL"'", "context": { "conv_uid": "uuid_conv_stream_1234" }, @@ -59,19 +59,27 @@ from typing import Dict, List, Optional, Union from dbgpt._private.pydantic import BaseModel, Field -from dbgpt.core import InMemoryStorage, LLMClient +from dbgpt.core import ( + ChatPromptTemplate, + HumanPromptTemplate, + InMemoryStorage, + MessagesPlaceholder, + ModelMessage, + ModelRequest, + ModelRequestContext, + SystemPromptTemplate, +) from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator from dbgpt.core.operator import ( - BufferedConversationMapperOperator, + 
ChatComposerInput, + ChatHistoryPromptComposerOperator, LLMBranchOperator, +) +from dbgpt.model.operator import ( LLMOperator, - PostConversationOperator, - PostStreamingConversationOperator, - PreConversationOperator, - RequestBuildOperator, + OpenAIStreamingOutputOperator, StreamingLLMOperator, ) -from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator logger = logging.getLogger(__name__) @@ -100,16 +108,15 @@ class TriggerReqBody(BaseModel): ) -class MyLLMOperator(MixinLLMOperator, LLMOperator): - def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): - super().__init__(llm_client) - LLMOperator.__init__(self, llm_client, **kwargs) - - -class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator): - def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): - super().__init__(llm_client) - StreamingLLMOperator.__init__(self, llm_client, **kwargs) +async def build_model_request( + messages: List[ModelMessage], req_body: TriggerReqBody +) -> ModelRequest: + return ModelRequest.build_request( + model=req_body.model, + messages=messages, + context=req_body.context, + stream=req_body.stream, + ) with DAG("dbgpt_awel_simple_chat_history") as multi_round_dag: @@ -120,56 +127,53 @@ def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): request_body=TriggerReqBody, streaming_predict_func=lambda req: req.stream, ) - # Transform request body to model request. - request_handle_task = RequestBuildOperator() - # Pre-process conversation, use InMemoryStorage to store conversation. - pre_conversation_task = PreConversationOperator( - storage=InMemoryStorage(), message_storage=InMemoryStorage() + prompt = ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template("You are a helpful chatbot."), + MessagesPlaceholder(variable_name="chat_history"), + HumanPromptTemplate.from_template("{user_input}"), + ] ) - # Keep last k round conversation. - history_conversation_task = BufferedConversationMapperOperator(last_k_round=5) - # Save conversation to storage. - post_conversation_task = PostConversationOperator() - # Save streaming conversation to storage. - post_streaming_conversation_task = PostStreamingConversationOperator() + composer_operator = ChatHistoryPromptComposerOperator( + prompt_template=prompt, + last_k_round=5, + storage=InMemoryStorage(), + message_storage=InMemoryStorage(), + ) - # Use LLMOperator to generate response. - llm_task = MyLLMOperator(task_name="llm_task") - streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task") + # Use BaseLLMOperator to generate response. 
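+    # As in the data analyst example, LLMBranchOperator picks the streaming or
+    # non-streaming path based on req.stream; the streaming path is converted to
+    # OpenAI-compatible streaming output by OpenAIStreamingOutputOperator before
+    # both paths are joined back into a single response.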
+ llm_task = LLMOperator(task_name="llm_task") + streaming_llm_task = StreamingLLMOperator(task_name="streaming_llm_task") branch_task = LLMBranchOperator( stream_task_name="streaming_llm_task", no_stream_task_name="llm_task" ) model_parse_task = MapOperator(lambda out: out.to_dict()) - openai_format_stream_task = OpenAIStreamingOperator() + openai_format_stream_task = OpenAIStreamingOutputOperator() result_join_task = JoinOperator( combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out ) - ( - trigger - >> request_handle_task - >> pre_conversation_task - >> history_conversation_task - >> branch_task + req_handle_task = MapOperator( + lambda req: ChatComposerInput( + context=ModelRequestContext( + conv_uid=req.context.conv_uid, stream=req.stream + ), + prompt_dict={"user_input": req.messages}, + model_dict={ + "model": req.model, + "context": req.context, + "stream": req.stream, + }, + ) ) + trigger >> req_handle_task >> composer_operator >> branch_task + # The branch of no streaming response. - ( - branch_task - >> llm_task - >> post_conversation_task - >> model_parse_task - >> result_join_task - ) + branch_task >> llm_task >> model_parse_task >> result_join_task # The branch of streaming response. - ( - branch_task - >> streaming_llm_task - >> post_streaming_conversation_task - >> openai_format_stream_task - >> result_join_task - ) + branch_task >> streaming_llm_task >> openai_format_stream_task >> result_join_task if __name__ == "__main__": if multi_round_dag.leaf_nodes[0].dev_mode: diff --git a/examples/awel/simple_dag_example.py b/examples/awel/simple_dag_example.py index cfe7cd078..64a28ca08 100644 --- a/examples/awel/simple_dag_example.py +++ b/examples/awel/simple_dag_example.py @@ -31,3 +31,11 @@ async def map(self, input_value: TriggerReqBody) -> str: trigger = HttpTrigger("/examples/hello", request_body=TriggerReqBody) map_node = RequestHandleOperator() trigger >> map_node + +if __name__ == "__main__": + if dag.leaf_nodes[0].dev_mode: + from dbgpt.core.awel import setup_dev_environment + + setup_dev_environment([dag]) + else: + pass diff --git a/examples/awel/simple_llm_client_example.py b/examples/awel/simple_llm_client_example.py index ee7f4292a..71277cd0b 100644 --- a/examples/awel/simple_llm_client_example.py +++ b/examples/awel/simple_llm_client_example.py @@ -8,9 +8,10 @@ .. 
code-block:: shell DBGPT_SERVER="http://127.0.0.1:5555" + MODEL="gpt-3.5-turbo" curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \ -H "Content-Type: application/json" -d '{ - "model": "proxyllm", + "model": "'"$MODEL"'", "messages": "hello" }' @@ -19,7 +20,7 @@ curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \ -H "Content-Type: application/json" -d '{ - "model": "proxyllm", + "model": "'"$MODEL"'", "messages": "hello", "stream": true }' @@ -29,7 +30,7 @@ curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/count_token \ -H "Content-Type: application/json" -d '{ - "model": "proxyllm", + "model": "'"$MODEL"'", "messages": "hello" }' @@ -40,13 +41,13 @@ from dbgpt._private.pydantic import BaseModel, Field from dbgpt.core import LLMClient from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator -from dbgpt.core.operator import ( - LLMBranchOperator, +from dbgpt.core.operator import LLMBranchOperator, RequestBuilderOperator +from dbgpt.model.operator import ( LLMOperator, - RequestBuildOperator, + MixinLLMOperator, + OpenAIStreamingOutputOperator, StreamingLLMOperator, ) -from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator logger = logging.getLogger(__name__) @@ -59,18 +60,6 @@ class TriggerReqBody(BaseModel): stream: Optional[bool] = Field(default=False, description="Whether return stream") -class MyLLMOperator(MixinLLMOperator, LLMOperator): - def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): - super().__init__(llm_client) - LLMOperator.__init__(self, llm_client, **kwargs) - - -class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator): - def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): - super().__init__(llm_client) - StreamingLLMOperator.__init__(self, llm_client, **kwargs) - - class MyModelToolOperator( MixinLLMOperator, MapOperator[TriggerReqBody, Dict[str, Any]] ): @@ -97,14 +86,14 @@ async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]: request_body=TriggerReqBody, streaming_predict_func=lambda req: req.stream, ) - request_handle_task = RequestBuildOperator() - llm_task = MyLLMOperator(task_name="llm_task") - streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task") + request_handle_task = RequestBuilderOperator() + llm_task = LLMOperator(task_name="llm_task") + streaming_llm_task = StreamingLLMOperator(task_name="streaming_llm_task") branch_task = LLMBranchOperator( stream_task_name="streaming_llm_task", no_stream_task_name="llm_task" ) model_parse_task = MapOperator(lambda out: out.to_dict()) - openai_format_stream_task = OpenAIStreamingOperator() + openai_format_stream_task = OpenAIStreamingOutputOperator() result_join_task = JoinOperator( combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out ) diff --git a/examples/sdk/simple_sdk_llm_example.py b/examples/sdk/simple_sdk_llm_example.py index ea479f8aa..ef3c43966 100644 --- a/examples/sdk/simple_sdk_llm_example.py +++ b/examples/sdk/simple_sdk_llm_example.py @@ -1,16 +1,20 @@ import asyncio -from dbgpt.core import BaseOutputParser, PromptTemplate +from dbgpt.core import BaseOutputParser from dbgpt.core.awel import DAG -from dbgpt.core.operator import LLMOperator, RequestBuildOperator +from dbgpt.core.operator import ( + BaseLLMOperator, + PromptBuilderOperator, + RequestBuilderOperator, +) from dbgpt.model import OpenAILLMClient with DAG("simple_sdk_llm_example_dag") as dag: - prompt_task = 
PromptTemplate.from_template( + prompt_task = PromptBuilderOperator( "Write a SQL of {dialect} to query all data of {table_name}." ) - model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo") - llm_task = LLMOperator(OpenAILLMClient()) + model_pre_handle_task = RequestBuilderOperator(model="gpt-3.5-turbo") + llm_task = BaseLLMOperator(OpenAILLMClient()) out_parse_task = BaseOutputParser() prompt_task >> model_pre_handle_task >> llm_task >> out_parse_task diff --git a/examples/sdk/simple_sdk_llm_sql_example.py b/examples/sdk/simple_sdk_llm_sql_example.py index 7a12ec8b2..e42705e42 100644 --- a/examples/sdk/simple_sdk_llm_sql_example.py +++ b/examples/sdk/simple_sdk_llm_sql_example.py @@ -2,7 +2,7 @@ import json from typing import Dict, List -from dbgpt.core import PromptTemplate, SQLOutputParser +from dbgpt.core import SQLOutputParser from dbgpt.core.awel import ( DAG, InputOperator, @@ -10,7 +10,11 @@ MapOperator, SimpleCallDataInputSource, ) -from dbgpt.core.operator import LLMOperator, RequestBuildOperator +from dbgpt.core.operator import ( + BaseLLMOperator, + PromptBuilderOperator, + RequestBuilderOperator, +) from dbgpt.datasource.operator.datasource_operator import DatasourceOperator from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect from dbgpt.model import OpenAILLMClient @@ -116,9 +120,9 @@ def _combine_result(self, sql_result_df, model_result: Dict) -> Dict: retriever_task = DatasourceRetrieverOperator(connection=db_connection) # Merge the input data and the table structure information. prompt_input_task = JoinOperator(combine_function=_join_func) - prompt_task = PromptTemplate.from_template(_sql_prompt()) - model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo") - llm_task = LLMOperator(OpenAILLMClient()) + prompt_task = PromptBuilderOperator(_sql_prompt()) + model_pre_handle_task = RequestBuilderOperator(model="gpt-3.5-turbo") + llm_task = BaseLLMOperator(OpenAILLMClient()) out_parse_task = SQLOutputParser() sql_parse_task = MapOperator(map_function=lambda x: x["sql"]) db_query_task = DatasourceOperator(connection=db_connection) diff --git a/setup.py b/setup.py index b4b931749..2f318c9a0 100644 --- a/setup.py +++ b/setup.py @@ -411,6 +411,8 @@ def core_requires(): "aiofiles", # for agent "GitPython", + # For AWEL dag visualization, graphviz is a small package, also we can move it to default. + "graphviz", ]
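For reference, the `graphviz` Python package added above only generates Graphviz source; rendering to PDF also requires the system Graphviz binaries (e.g. `sudo apt install graphviz`). A minimal, hedged sketch of the kind of DAG rendering this dependency enables — node and file names are illustrative, not DB-GPT's own visualization code:

.. code-block:: python

    from graphviz import Digraph

    # Build a toy directed graph mirroring a small AWEL flow
    # (names are made up for illustration).
    g = Digraph(name="example_awel_dag")
    g.node("trigger", "HttpTrigger")
    g.node("request_handle", "RequestBuilderOperator")
    g.node("llm", "LLMOperator")
    g.edge("trigger", "request_handle")
    g.edge("request_handle", "llm")

    # Writes example_awel_dag.gv plus example_awel_dag.gv.pdf; this step needs
    # the Graphviz system binaries on PATH.
    g.render(filename="example_awel_dag.gv", view=False)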