From a99179d748aa39b24c13f58aab68019b86fa500a Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Wed, 27 Dec 2023 23:26:28 +0800 Subject: [PATCH] feat(core): Support multi round conversation operator (#986) --- assets/schema/knowledge_management.sql | 1 + dbgpt/_private/pydantic.py | 10 + dbgpt/app/base.py | 24 +- dbgpt/app/component_configs.py | 14 +- dbgpt/app/dbgpt_server.py | 8 +- .../initialization/serve_initialization.py | 23 +- dbgpt/app/openapi/api_v1/api_v1.py | 4 + dbgpt/component.py | 33 +- dbgpt/core/awel/__init__.py | 51 +- dbgpt/core/awel/dag/base.py | 16 +- dbgpt/core/awel/dag/dag_manager.py | 10 +- dbgpt/core/awel/dag/loader.py | 28 +- dbgpt/core/awel/dag/tests/test_dag.py | 6 +- dbgpt/core/awel/operator/base.py | 34 +- dbgpt/core/awel/operator/common_operator.py | 24 +- dbgpt/core/awel/operator/stream_operator.py | 5 +- dbgpt/core/awel/runner/job_manager.py | 8 +- dbgpt/core/awel/runner/local_runner.py | 5 +- dbgpt/core/awel/task/base.py | 10 +- dbgpt/core/awel/task/task_impl.py | 18 +- dbgpt/core/awel/tests/conftest.py | 12 +- dbgpt/core/awel/tests/test_http_operator.py | 16 +- dbgpt/core/awel/tests/test_run_dag.py | 16 +- dbgpt/core/awel/trigger/http_trigger.py | 8 +- dbgpt/core/awel/trigger/trigger_manager.py | 6 +- dbgpt/core/interface/llm.py | 13 +- dbgpt/core/interface/message.py | 18 +- .../interface/operator/message_operator.py | 179 +++++-- dbgpt/core/interface/prompt.py | 31 +- dbgpt/model/cluster/apiserver/api.py | 4 +- dbgpt/model/utils/chatgpt_utils.py | 13 +- dbgpt/serve/conversation/__init__.py | 2 + dbgpt/serve/conversation/api/__init__.py | 2 + dbgpt/serve/conversation/api/endpoints.py | 176 +++++++ dbgpt/serve/conversation/api/schemas.py | 20 + dbgpt/serve/conversation/config.py | 23 + dbgpt/serve/conversation/dependencies.py | 1 + dbgpt/serve/conversation/models/__init__.py | 2 + dbgpt/serve/conversation/models/models.py | 68 +++ dbgpt/serve/conversation/serve.py | 99 ++++ dbgpt/serve/conversation/service/__init__.py | 0 dbgpt/serve/conversation/service/service.py | 116 +++++ dbgpt/serve/conversation/tests/__init__.py | 0 .../conversation/tests/test_endpoints.py | 124 +++++ dbgpt/serve/conversation/tests/test_models.py | 109 +++++ .../serve/conversation/tests/test_service.py | 76 +++ dbgpt/serve/core/__init__.py | 3 +- dbgpt/serve/core/serve.py | 60 +++ dbgpt/serve/core/tests/conftest.py | 30 +- dbgpt/serve/prompt/api/schemas.py | 9 + dbgpt/serve/prompt/models/models.py | 3 + dbgpt/serve/prompt/serve.py | 48 +- .../default_serve_template/serve.py | 44 +- dbgpt/storage/metadata/__init__.py | 2 + dbgpt/storage/metadata/db_factory.py | 21 + dbgpt/storage/metadata/db_manager.py | 7 +- dbgpt/util/config_utils.py | 17 + dbgpt/util/openai_utils.py | 17 +- dbgpt/util/tracer/span_storage.py | 2 +- dbgpt/util/utils.py | 8 + examples/awel/data_analyst_assistant.py | 438 ++++++++++++++++++ .../agent_auto_plan_dialogue_example.ipynb | 119 ++--- setup.py | 1 + 63 files changed, 1996 insertions(+), 299 deletions(-) create mode 100644 dbgpt/serve/conversation/__init__.py create mode 100644 dbgpt/serve/conversation/api/__init__.py create mode 100644 dbgpt/serve/conversation/api/endpoints.py create mode 100644 dbgpt/serve/conversation/api/schemas.py create mode 100644 dbgpt/serve/conversation/config.py create mode 100644 dbgpt/serve/conversation/dependencies.py create mode 100644 dbgpt/serve/conversation/models/__init__.py create mode 100644 dbgpt/serve/conversation/models/models.py create mode 100644 dbgpt/serve/conversation/serve.py create mode 100644 dbgpt/serve/conversation/service/__init__.py create mode 100644 dbgpt/serve/conversation/service/service.py create mode 100644 dbgpt/serve/conversation/tests/__init__.py create mode 100644 dbgpt/serve/conversation/tests/test_endpoints.py create mode 100644 dbgpt/serve/conversation/tests/test_models.py create mode 100644 dbgpt/serve/conversation/tests/test_service.py create mode 100644 dbgpt/serve/core/serve.py create mode 100644 dbgpt/storage/metadata/db_factory.py create mode 100644 examples/awel/data_analyst_assistant.py diff --git a/assets/schema/knowledge_management.sql b/assets/schema/knowledge_management.sql index b0cf2178a..3a5e1302a 100644 --- a/assets/schema/knowledge_management.sql +++ b/assets/schema/knowledge_management.sql @@ -175,6 +175,7 @@ CREATE TABLE IF NOT EXISTS `prompt_manage` `model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)', `prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)', `prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)', + `prompt_desc` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt description', `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name', `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', diff --git a/dbgpt/_private/pydantic.py b/dbgpt/_private/pydantic.py index df0cfaacb..c2f04928f 100644 --- a/dbgpt/_private/pydantic.py +++ b/dbgpt/_private/pydantic.py @@ -31,3 +31,13 @@ validator, PrivateAttr, ) + + +def model_to_json(model, **kwargs): + """Convert a pydantic model to json""" + if PYDANTIC_VERSION == 1: + return model.json(**kwargs) + else: + if "ensure_ascii" in kwargs: + del kwargs["ensure_ascii"] + return model.model_dump_json(**kwargs) diff --git a/dbgpt/app/base.py b/dbgpt/app/base.py index 6a2acd210..fde70098e 100644 --- a/dbgpt/app/base.py +++ b/dbgpt/app/base.py @@ -40,7 +40,7 @@ def server_init(param: "WebServerParameters", system_app: SystemApp): cfg = Config() cfg.SYSTEM_APP = system_app # Initialize db storage first - _initialize_db_storage(param) + _initialize_db_storage(param, system_app) # load_native_plugins(cfg) signal.signal(signal.SIGINT, signal_handler) @@ -86,12 +86,14 @@ def startup_event(wh): return startup_event -def _initialize_db_storage(param: "WebServerParameters"): +def _initialize_db_storage(param: "WebServerParameters", system_app: SystemApp): """Initialize the db storage. Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`. """ - _initialize_db(try_to_create_db=not param.disable_alembic_upgrade) + _initialize_db( + try_to_create_db=not param.disable_alembic_upgrade, system_app=system_app + ) def _migration_db_storage(param: "WebServerParameters"): @@ -114,7 +116,9 @@ def _migration_db_storage(param: "WebServerParameters"): _ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade) -def _initialize_db(try_to_create_db: Optional[bool] = False) -> str: +def _initialize_db( + try_to_create_db: Optional[bool] = False, system_app: Optional[SystemApp] = None +) -> str: """Initialize the database Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`. @@ -147,7 +151,11 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str: "pool_recycle": 3600, "pool_pre_ping": True, } - initialize_db(db_url, db_name, engine_args) + db = initialize_db(db_url, db_name, engine_args) + if system_app: + from dbgpt.storage.metadata import UnifiedDBManagerFactory + + system_app.register(UnifiedDBManagerFactory, db) return default_meta_data_path @@ -273,3 +281,9 @@ class WebServerParameters(BaseParameters): "help": "Whether to disable alembic to initialize and upgrade database metadata", }, ) + awel_dirs: Optional[str] = field( + default=None, + metadata={ + "help": "The directories to search awel files, split by `,`", + }, + ) diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py index d477a19e4..4b31cc46e 100644 --- a/dbgpt/app/component_configs.py +++ b/dbgpt/app/component_configs.py @@ -46,9 +46,9 @@ def initialize_components( param, system_app, embedding_model_name, embedding_model_path ) _initialize_model_cache(system_app) - _initialize_awel(system_app) + _initialize_awel(system_app, param) # Register serve apps - register_serve_apps(system_app) + register_serve_apps(system_app, CFG) def _initialize_model_cache(system_app: SystemApp): @@ -64,8 +64,14 @@ def _initialize_model_cache(system_app: SystemApp): initialize_cache(system_app, storage_type, max_memory_mb, persist_dir) -def _initialize_awel(system_app: SystemApp): +def _initialize_awel(system_app: SystemApp, param: WebServerParameters): from dbgpt.core.awel import initialize_awel from dbgpt.configs.model_config import _DAG_DEFINITION_DIR - initialize_awel(system_app, _DAG_DEFINITION_DIR) + # Add default dag definition dir + dag_dirs = [_DAG_DEFINITION_DIR] + if param.awel_dirs: + dag_dirs += param.awel_dirs.strip().split(",") + dag_dirs = [x.strip() for x in dag_dirs] + + initialize_awel(system_app, dag_dirs) diff --git a/dbgpt/app/dbgpt_server.py b/dbgpt/app/dbgpt_server.py index c8e7893e9..694cc03dd 100644 --- a/dbgpt/app/dbgpt_server.py +++ b/dbgpt/app/dbgpt_server.py @@ -146,14 +146,13 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None): mount_routers(app) model_start_listener = _create_model_start_listener(system_app) initialize_components(param, system_app, embedding_model_name, embedding_model_path) + system_app.on_init() - # Before start, after initialize_components - # TODO: initialize_worker_manager_in_client as a component register in system_app - system_app.before_start() # Migration db storage, so you db models must be imported before this _migration_db_storage(param) model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name) + # TODO: initialize_worker_manager_in_client as a component register in system_app if not param.light: print("Model Unified Deployment Mode!") if not param.remote_embedding: @@ -186,6 +185,9 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None): CFG.SERVER_LIGHT_MODE = True mount_static_files(app) + + # Before start, after on_init + system_app.before_start() return param diff --git a/dbgpt/app/initialization/serve_initialization.py b/dbgpt/app/initialization/serve_initialization.py index 50d098ad7..35fe26844 100644 --- a/dbgpt/app/initialization/serve_initialization.py +++ b/dbgpt/app/initialization/serve_initialization.py @@ -1,13 +1,28 @@ from dbgpt.component import SystemApp +from dbgpt._private.config import Config -def register_serve_apps(system_app: SystemApp): +def register_serve_apps(system_app: SystemApp, cfg: Config): """Register serve apps""" - from dbgpt.serve.prompt.serve import Serve as PromptServe, SERVE_CONFIG_KEY_PREFIX + system_app.config.set("dbgpt.app.global.language", cfg.LANGUAGE) + + # ################################ Prompt Serve Register Begin ###################################### + from dbgpt.serve.prompt.serve import ( + Serve as PromptServe, + SERVE_CONFIG_KEY_PREFIX as PROMPT_SERVE_CONFIG_KEY_PREFIX, + ) # Replace old prompt serve # Set config - system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_user", "dbgpt") - system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_sys_code", "dbgpt") + system_app.config.set(f"{PROMPT_SERVE_CONFIG_KEY_PREFIX}default_user", "dbgpt") + system_app.config.set(f"{PROMPT_SERVE_CONFIG_KEY_PREFIX}default_sys_code", "dbgpt") # Register serve app system_app.register(PromptServe, api_prefix="/prompt") + # ################################ Prompt Serve Register End ######################################## + + # ################################ Conversation Serve Register Begin ###################################### + from dbgpt.serve.conversation.serve import Serve as ConversationServe + + # Register serve app + system_app.register(ConversationServe) + # ################################ Conversation Serve Register End ######################################## diff --git a/dbgpt/app/openapi/api_v1/api_v1.py b/dbgpt/app/openapi/api_v1/api_v1.py index 6c25e6d2c..607642f8c 100644 --- a/dbgpt/app/openapi/api_v1/api_v1.py +++ b/dbgpt/app/openapi/api_v1/api_v1.py @@ -217,6 +217,10 @@ async def dialogue_list( model_name = item.get("model_name", CFG.LLM_MODEL) user_name = item.get("user_name") sys_code = item.get("sys_code") + if not item.get("messages"): + # Skip the empty messages + # TODO support new conversation and message mode + continue messages = json.loads(item.get("messages")) last_round = max(messages, key=lambda x: x["chat_order"]) diff --git a/dbgpt/component.py b/dbgpt/component.py index d10d500e0..466d01c80 100644 --- a/dbgpt/component.py +++ b/dbgpt/component.py @@ -17,10 +17,28 @@ class LifeCycle: - """This class defines hooks for lifecycle events of a component.""" + """This class defines hooks for lifecycle events of a component. + + Execution order of lifecycle hooks: + 1. on_init + 2. before_start(async_before_start) + 3. after_start(async_after_start) + 4. before_stop(async_before_stop) + """ + + def on_init(self): + """Called when the component is being initialized.""" + pass + + async def async_on_init(self): + """Asynchronous version of on_init.""" + pass def before_start(self): - """Called before the component starts.""" + """Called before the component starts. + + This method is called after the component has been initialized and before it is started. + """ pass async def async_before_start(self): @@ -59,6 +77,7 @@ class ComponentType(str, Enum): RAG_GRAPH_DEFAULT = "dbgpt_rag_engine_default" AWEL_TRIGGER_MANAGER = "dbgpt_awel_trigger_manager" AWEL_DAG_MANAGER = "dbgpt_awel_dag_manager" + UNIFIED_METADATA_DB_MANAGER_FACTORY = "dbgpt_unified_metadata_db_manager_factory" @PublicAPI(stability="beta") @@ -177,6 +196,16 @@ def get_component( raise TypeError(f"Component {name} is not of type {component_type}") return component + def on_init(self): + """Invoke the on_init hooks for all registered components.""" + for _, v in self.components.items(): + v.on_init() + + async def async_on_init(self): + """Asynchronously invoke the on_init hooks for all registered components.""" + tasks = [v.async_on_init() for _, v in self.components.items()] + await asyncio.gather(*tasks) + def before_start(self): """Invoke the before_start hooks for all registered components.""" for _, v in self.components.items(): diff --git a/dbgpt/core/awel/__init__.py b/dbgpt/core/awel/__init__.py index 9331d2621..2fcd657cb 100644 --- a/dbgpt/core/awel/__init__.py +++ b/dbgpt/core/awel/__init__.py @@ -8,38 +8,36 @@ """ from typing import List, Optional -from dbgpt.component import SystemApp -from .dag.base import DAGContext, DAG +from dbgpt.component import SystemApp +from .dag.base import DAG, DAGContext from .operator.base import BaseOperator, WorkflowRunner from .operator.common_operator import ( - JoinOperator, - ReduceStreamOperator, - MapOperator, + BranchFunc, BranchOperator, InputOperator, - BranchFunc, + JoinOperator, + MapOperator, + ReduceStreamOperator, ) - from .operator.stream_operator import ( StreamifyAbsOperator, - UnstreamifyAbsOperator, TransformStreamAbsOperator, + UnstreamifyAbsOperator, ) - -from .task.base import TaskState, TaskOutput, TaskContext, InputContext, InputSource +from .runner.local_runner import DefaultWorkflowRunner +from .task.base import InputContext, InputSource, TaskContext, TaskOutput, TaskState from .task.task_impl import ( - SimpleInputSource, - SimpleCallDataInputSource, - DefaultTaskContext, DefaultInputContext, - SimpleTaskOutput, + DefaultTaskContext, + SimpleCallDataInputSource, + SimpleInputSource, SimpleStreamTaskOutput, + SimpleTaskOutput, _is_async_iterator, ) from .trigger.http_trigger import HttpTrigger -from .runner.local_runner import DefaultWorkflowRunner __all__ = [ "initialize_awel", @@ -73,16 +71,16 @@ ] -def initialize_awel(system_app: SystemApp, dag_filepath: str): - from .dag.dag_manager import DAGManager +def initialize_awel(system_app: SystemApp, dag_dirs: List[str]): from .dag.base import DAGVar - from .trigger.trigger_manager import DefaultTriggerManager + from .dag.dag_manager import DAGManager from .operator.base import initialize_runner + from .trigger.trigger_manager import DefaultTriggerManager DAGVar.set_current_system_app(system_app) system_app.register(DefaultTriggerManager) - dag_manager = DAGManager(system_app, dag_filepath) + dag_manager = DAGManager(system_app, dag_dirs) system_app.register_instance(dag_manager) initialize_runner(DefaultWorkflowRunner()) # Load all dags @@ -90,7 +88,11 @@ def initialize_awel(system_app: SystemApp, dag_filepath: str): def setup_dev_environment( - dags: List[DAG], host: Optional[str] = "0.0.0.0", port: Optional[int] = 5555 + dags: List[DAG], + host: Optional[str] = "0.0.0.0", + port: Optional[int] = 5555, + logging_level: Optional[str] = None, + logger_filename: Optional[str] = None, ) -> None: """Setup a development environment for AWEL. @@ -98,9 +100,16 @@ def setup_dev_environment( """ import uvicorn from fastapi import FastAPI + from dbgpt.component import SystemApp - from .trigger.trigger_manager import DefaultTriggerManager + from dbgpt.util.utils import setup_logging + from .dag.base import DAGVar + from .trigger.trigger_manager import DefaultTriggerManager + + if not logger_filename: + logger_filename = "dbgpt_awel_dev.log" + setup_logging("dbgpt", logging_level=logging_level, logger_filename=logger_filename) app = FastAPI() system_app = SystemApp(app) diff --git a/dbgpt/core/awel/dag/base.py b/dbgpt/core/awel/dag/base.py index 5f182a97b..07b1db049 100644 --- a/dbgpt/core/awel/dag/base.py +++ b/dbgpt/core/awel/dag/base.py @@ -1,15 +1,16 @@ -from abc import ABC, abstractmethod -from typing import Optional, Dict, List, Sequence, Union, Any, Set -import uuid -import contextvars -import threading import asyncio +import contextvars import logging +import threading +import uuid +from abc import ABC, abstractmethod from collections import deque -from functools import cache from concurrent.futures import Executor +from functools import cache +from typing import Any, Dict, List, Optional, Sequence, Set, Union from dbgpt.component import SystemApp + from ..resource.base import ResourceGroup from ..task.base import TaskContext, TaskOutput @@ -502,6 +503,9 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): DAGVar.exit_dag() + def __repr__(self): + return f"DAG(dag_id={self.dag_id})" + def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]: nodes = set() diff --git a/dbgpt/core/awel/dag/dag_manager.py b/dbgpt/core/awel/dag/dag_manager.py index ff6148828..90214e934 100644 --- a/dbgpt/core/awel/dag/dag_manager.py +++ b/dbgpt/core/awel/dag/dag_manager.py @@ -1,8 +1,10 @@ -from typing import Dict, Optional import logging +from typing import Dict, List + from dbgpt.component import BaseComponent, ComponentType, SystemApp -from .loader import DAGLoader, LocalFileDAGLoader + from .base import DAG +from .loader import LocalFileDAGLoader logger = logging.getLogger(__name__) @@ -10,9 +12,9 @@ class DAGManager(BaseComponent): name = ComponentType.AWEL_DAG_MANAGER - def __init__(self, system_app: SystemApp, dag_filepath: str): + def __init__(self, system_app: SystemApp, dag_dirs: List[str]): super().__init__(system_app) - self.dag_loader = LocalFileDAGLoader(dag_filepath) + self.dag_loader = LocalFileDAGLoader(dag_dirs) self.system_app = system_app self.dag_map: Dict[str, DAG] = {} diff --git a/dbgpt/core/awel/dag/loader.py b/dbgpt/core/awel/dag/loader.py index 2eb89f8bc..325d4733a 100644 --- a/dbgpt/core/awel/dag/loader.py +++ b/dbgpt/core/awel/dag/loader.py @@ -1,10 +1,10 @@ -from abc import ABC, abstractmethod -from typing import List -import os import hashlib -import sys import logging +import os +import sys import traceback +from abc import ABC, abstractmethod +from typing import List from .base import DAG @@ -18,17 +18,19 @@ def load_dags(self) -> List[DAG]: class LocalFileDAGLoader(DAGLoader): - def __init__(self, filepath: str) -> None: - super().__init__() - self._filepath = filepath + def __init__(self, dag_dirs: List[str]) -> None: + self._dag_dirs = dag_dirs def load_dags(self) -> List[DAG]: - if not os.path.exists(self._filepath): - return [] - if os.path.isdir(self._filepath): - return _process_directory(self._filepath) - else: - return _process_file(self._filepath) + dags = [] + for filepath in self._dag_dirs: + if not os.path.exists(filepath): + continue + if os.path.isdir(filepath): + dags += _process_directory(filepath) + else: + dags += _process_file(filepath) + return dags def _process_directory(directory: str) -> List[DAG]: diff --git a/dbgpt/core/awel/dag/tests/test_dag.py b/dbgpt/core/awel/dag/tests/test_dag.py index 8a058f318..2eabbbe52 100644 --- a/dbgpt/core/awel/dag/tests/test_dag.py +++ b/dbgpt/core/awel/dag/tests/test_dag.py @@ -1,6 +1,8 @@ -import pytest -import threading import asyncio +import threading + +import pytest + from ..base import DAG, DAGVar diff --git a/dbgpt/core/awel/operator/base.py b/dbgpt/core/awel/operator/base.py index cc106dbf6..c07eabd14 100644 --- a/dbgpt/core/awel/operator/base.py +++ b/dbgpt/core/awel/operator/base.py @@ -1,32 +1,32 @@ -from abc import ABC, abstractmethod, ABCMeta - +import asyncio +import functools +from abc import ABC, ABCMeta, abstractmethod +from inspect import signature from types import FunctionType from typing import ( - List, - Generic, - TypeVar, - AsyncIterator, - Iterator, - Union, Any, + AsyncIterator, Dict, + Generic, + Iterator, + List, Optional, + TypeVar, + Union, cast, ) -import functools -from inspect import signature -import asyncio -from dbgpt.component import SystemApp, ComponentType + +from dbgpt.component import ComponentType, SystemApp from dbgpt.util.executor_utils import ( - ExecutorFactory, + AsyncToSyncIterator, + BlockingFunction, DefaultExecutorFactory, + ExecutorFactory, blocking_func_to_async, - BlockingFunction, - AsyncToSyncIterator, ) -from ..dag.base import DAGNode, DAGContext, DAGVar, DAG -from ..task.base import TaskOutput, OUT, T +from ..dag.base import DAG, DAGContext, DAGNode, DAGVar +from ..task.base import OUT, T, TaskOutput F = TypeVar("F", bound=FunctionType) diff --git a/dbgpt/core/awel/operator/common_operator.py b/dbgpt/core/awel/operator/common_operator.py index bd1199aa7..4fbb266f9 100644 --- a/dbgpt/core/awel/operator/common_operator.py +++ b/dbgpt/core/awel/operator/common_operator.py @@ -1,27 +1,19 @@ +import asyncio +import logging from typing import ( - Generic, - Dict, - List, - Union, - Callable, Any, AsyncIterator, Awaitable, + Callable, + Dict, + Generic, + List, Optional, + Union, ) -import asyncio -import logging from ..dag.base import DAGContext -from ..task.base import ( - TaskContext, - TaskOutput, - IN, - OUT, - InputContext, - InputSource, -) - +from ..task.base import IN, OUT, InputContext, InputSource, TaskContext, TaskOutput from .base import BaseOperator logger = logging.getLogger(__name__) diff --git a/dbgpt/core/awel/operator/stream_operator.py b/dbgpt/core/awel/operator/stream_operator.py index 7de916a83..73eb49479 100644 --- a/dbgpt/core/awel/operator/stream_operator.py +++ b/dbgpt/core/awel/operator/stream_operator.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod -from typing import Generic, AsyncIterator -from ..task.base import OUT, IN, TaskOutput, TaskContext +from typing import AsyncIterator, Generic + from ..dag.base import DAGContext +from ..task.base import IN, OUT, TaskContext, TaskOutput from .base import BaseOperator diff --git a/dbgpt/core/awel/runner/job_manager.py b/dbgpt/core/awel/runner/job_manager.py index 44c1af01b..b015f85e7 100644 --- a/dbgpt/core/awel/runner/job_manager.py +++ b/dbgpt/core/awel/runner/job_manager.py @@ -1,10 +1,10 @@ import asyncio -from typing import List, Set, Optional, Dict -import uuid import logging -from ..dag.base import DAG, DAGLifecycle +import uuid +from typing import Dict, List, Optional, Set -from ..operator.base import BaseOperator, CALL_DATA +from ..dag.base import DAG, DAGLifecycle +from ..operator.base import CALL_DATA, BaseOperator logger = logging.getLogger(__name__) diff --git a/dbgpt/core/awel/runner/local_runner.py b/dbgpt/core/awel/runner/local_runner.py index 1bc8fdc9d..8ad3f417c 100644 --- a/dbgpt/core/awel/runner/local_runner.py +++ b/dbgpt/core/awel/runner/local_runner.py @@ -1,9 +1,10 @@ -from typing import Dict, Optional, Set, List import logging +from typing import Dict, List, Optional, Set from dbgpt.component import SystemApp + from ..dag.base import DAGContext, DAGVar -from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA +from ..operator.base import CALL_DATA, BaseOperator, WorkflowRunner from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator from ..task.base import TaskContext, TaskState from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput diff --git a/dbgpt/core/awel/task/base.py b/dbgpt/core/awel/task/base.py index 58f3863b8..603bc8609 100644 --- a/dbgpt/core/awel/task/base.py +++ b/dbgpt/core/awel/task/base.py @@ -1,15 +1,15 @@ from abc import ABC, abstractmethod from enum import Enum from typing import ( - TypeVar, - Generic, - Optional, + Any, AsyncIterator, - Union, Callable, - Any, Dict, + Generic, List, + Optional, + TypeVar, + Union, ) IN = TypeVar("IN") diff --git a/dbgpt/core/awel/task/task_impl.py b/dbgpt/core/awel/task/task_impl.py index a7bf542d5..5f113aeec 100644 --- a/dbgpt/core/awel/task/task_impl.py +++ b/dbgpt/core/awel/task/task_impl.py @@ -1,22 +1,22 @@ +import asyncio +import logging from abc import ABC, abstractmethod from typing import ( + Any, + AsyncIterator, Callable, Coroutine, + Dict, + Generic, Iterator, - AsyncIterator, List, - Generic, - TypeVar, - Any, + Optional, Tuple, - Dict, + TypeVar, Union, - Optional, ) -import asyncio -import logging -from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T +from .base import InputContext, InputSource, T, TaskContext, TaskOutput, TaskState logger = logging.getLogger(__name__) diff --git a/dbgpt/core/awel/tests/conftest.py b/dbgpt/core/awel/tests/conftest.py index 2279cceba..a6fbb1c76 100644 --- a/dbgpt/core/awel/tests/conftest.py +++ b/dbgpt/core/awel/tests/conftest.py @@ -1,14 +1,16 @@ +from contextlib import asynccontextmanager, contextmanager +from typing import AsyncIterator, List + import pytest import pytest_asyncio -from typing import AsyncIterator, List -from contextlib import contextmanager, asynccontextmanager + from .. import ( - WorkflowRunner, - InputOperator, DAGContext, - TaskState, DefaultWorkflowRunner, + InputOperator, SimpleInputSource, + TaskState, + WorkflowRunner, ) from ..task.task_impl import _is_async_iterator diff --git a/dbgpt/core/awel/tests/test_http_operator.py b/dbgpt/core/awel/tests/test_http_operator.py index c57e70fe1..adfe05e47 100644 --- a/dbgpt/core/awel/tests/test_http_operator.py +++ b/dbgpt/core/awel/tests/test_http_operator.py @@ -1,24 +1,26 @@ -import pytest from typing import List + +import pytest + from .. import ( DAG, - WorkflowRunner, + BranchOperator, DAGContext, - TaskState, InputOperator, - MapOperator, JoinOperator, - BranchOperator, + MapOperator, ReduceStreamOperator, SimpleInputSource, + TaskState, + WorkflowRunner, ) from .conftest import ( - runner, + _is_async_iterator, input_node, input_nodes, + runner, stream_input_node, stream_input_nodes, - _is_async_iterator, ) diff --git a/dbgpt/core/awel/tests/test_run_dag.py b/dbgpt/core/awel/tests/test_run_dag.py index c0ea8e7ad..f797c6ccc 100644 --- a/dbgpt/core/awel/tests/test_run_dag.py +++ b/dbgpt/core/awel/tests/test_run_dag.py @@ -1,24 +1,26 @@ -import pytest from typing import List + +import pytest + from .. import ( DAG, - WorkflowRunner, + BranchOperator, DAGContext, - TaskState, InputOperator, - MapOperator, JoinOperator, - BranchOperator, + MapOperator, ReduceStreamOperator, SimpleInputSource, + TaskState, + WorkflowRunner, ) from .conftest import ( - runner, + _is_async_iterator, input_node, input_nodes, + runner, stream_input_node, stream_input_nodes, - _is_async_iterator, ) diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index 192165f10..33a6e3ad9 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -1,14 +1,16 @@ from __future__ import annotations -from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict, Callable +import logging +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union + from starlette.requests import Request from starlette.responses import Response + from dbgpt._private.pydantic import BaseModel -import logging -from .base import Trigger from ..dag.base import DAG from ..operator.base import BaseOperator +from .base import Trigger if TYPE_CHECKING: from fastapi import APIRouter, FastAPI diff --git a/dbgpt/core/awel/trigger/trigger_manager.py b/dbgpt/core/awel/trigger/trigger_manager.py index efa12f7fb..95b4b89ab 100644 --- a/dbgpt/core/awel/trigger/trigger_manager.py +++ b/dbgpt/core/awel/trigger/trigger_manager.py @@ -1,11 +1,11 @@ -from abc import ABC, abstractmethod -from typing import Any, TYPE_CHECKING, Optional import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: from fastapi import APIRouter -from dbgpt.component import SystemApp, BaseComponent, ComponentType +from dbgpt.component import BaseComponent, ComponentType, SystemApp logger = logging.getLogger(__name__) diff --git a/dbgpt/core/interface/llm.py b/dbgpt/core/interface/llm.py index ebc6088c2..8ba40b333 100644 --- a/dbgpt/core/interface/llm.py +++ b/dbgpt/core/interface/llm.py @@ -114,6 +114,9 @@ class ModelRequestContext: span_id: Optional[str] = None """The span id of the model inference.""" + chat_mode: Optional[str] = None + """The chat mode of the model inference.""" + extra: Optional[Dict[str, Any]] = field(default_factory=dict) """The extra information of the model inference.""" @@ -195,7 +198,13 @@ def to_dict(self) -> Dict[str, Any]: # Skip None fields return {k: v for k, v in asdict(new_reqeust).items() if v} - def _get_messages(self) -> List[ModelMessage]: + def get_messages(self) -> List[ModelMessage]: + """Get the messages. + + If the messages is not a list of ModelMessage, it will be converted to a list of ModelMessage. + Returns: + List[ModelMessage]: The messages. + """ return list( map( lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m), @@ -209,7 +218,7 @@ def get_single_user_message(self) -> Optional[ModelMessage]: Returns: Optional[ModelMessage]: The single user message. """ - messages = self._get_messages() + messages = self.get_messages() if len(messages) != 1 and messages[0].role != ModelMessageRoleType.HUMAN: raise ValueError("The messages is not a single user message") return messages[0] diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py index 005452db2..6b0204817 100755 --- a/dbgpt/core/interface/message.py +++ b/dbgpt/core/interface/message.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from datetime import datetime -from typing import Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union from dbgpt._private.pydantic import BaseModel, Field from dbgpt.core.awel import MapOperator @@ -176,6 +176,22 @@ def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]: def build_human_message(content: str) -> "ModelMessage": return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content) + @staticmethod + def get_printable_message(messages: List["ModelMessage"]) -> str: + """Get the printable message""" + str_msg = "" + for message in messages: + curr_message = ( + f"(Round {message.round_index}) {message.role}: {message.content} " + ) + str_msg += curr_message.rstrip() + "\n" + + return str_msg + + +_SingleRoundMessage = List[ModelMessage] +_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[ModelMessage]] + def _message_to_dict(message: BaseMessage) -> Dict: return message.to_dict() diff --git a/dbgpt/core/interface/operator/message_operator.py b/dbgpt/core/interface/operator/message_operator.py index d775e8401..8e5b0131d 100644 --- a/dbgpt/core/interface/operator/message_operator.py +++ b/dbgpt/core/interface/operator/message_operator.py @@ -5,6 +5,7 @@ from dbgpt.core import ( MessageStorageItem, ModelMessage, + ModelMessageRoleType, ModelOutput, ModelRequest, ModelRequestContext, @@ -12,6 +13,7 @@ StorageInterface, ) from dbgpt.core.awel import BaseOperator, MapOperator, TransformStreamAbsOperator +from dbgpt.core.interface.message import _MultiRoundMessageMapper class BaseConversationOperator(BaseOperator, ABC): @@ -24,7 +26,7 @@ def __init__( self, storage: Optional[StorageInterface[StorageConversation, Any]] = None, message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self._storage = storage @@ -88,7 +90,7 @@ def __init__( self, storage: Optional[StorageInterface[StorageConversation, Any]] = None, message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None, - **kwargs + **kwargs, ): super().__init__(storage=storage, message_storage=message_storage) MapOperator.__init__(self, **kwargs) @@ -109,7 +111,7 @@ async def map(self, input_value: ModelRequest) -> ModelRequest: if not input_value.context.extra: input_value.context.extra = {} - chat_mode = input_value.context.extra.get("chat_mode") + chat_mode = input_value.context.chat_mode # Create a new storage conversation, this will load the conversation from storage, so we must do this async storage_conv: StorageConversation = await self.blocking_func_to_async( @@ -121,11 +123,8 @@ async def map(self, input_value: ModelRequest) -> ModelRequest: conv_storage=self.storage, message_storage=self.message_storage, ) - # The input message must be a single user message - single_human_message: ModelMessage = input_value.get_single_user_message() - storage_conv.start_new_round() - storage_conv.add_user_message(single_human_message.content) - + input_messages = input_value.get_messages() + await self.save_to_storage(storage_conv, input_messages) # Get all messages from current storage conversation, and overwrite the input value messages: List[ModelMessage] = storage_conv.get_model_messages() input_value.messages = messages @@ -139,6 +138,42 @@ async def map(self, input_value: ModelRequest) -> ModelRequest: ) return input_value + async def save_to_storage( + self, storage_conv: StorageConversation, input_messages: List[ModelMessage] + ) -> None: + """Save the messages to storage. + + Args: + storage_conv (StorageConversation): The storage conversation. + input_messages (List[ModelMessage]): The input messages. + """ + # check first + self.check_messages(input_messages) + storage_conv.start_new_round() + for message in input_messages: + if message.role == ModelMessageRoleType.HUMAN: + storage_conv.add_user_message(message.content) + else: + storage_conv.add_system_message(message.content) + + def check_messages(self, messages: List[ModelMessage]) -> None: + """Check the messages. + + Args: + messages (List[ModelMessage]): The messages. + + Raises: + ValueError: If the messages is empty. + """ + if not messages: + raise ValueError("Input messages is empty") + for message in messages: + if message.role not in [ + ModelMessageRoleType.HUMAN, + ModelMessageRoleType.SYSTEM, + ]: + raise ValueError(f"Message role {message.role} is not supported") + async def after_dag_end(self): """The callback after DAG end""" # Save the storage conversation to storage after the whole DAG finished @@ -198,8 +233,9 @@ async def transform_stream( class ConversationMapperOperator( BaseConversationOperator, MapOperator[ModelRequest, ModelRequest] ): - def __init__(self, **kwargs): + def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs): MapOperator.__init__(self, **kwargs) + self._message_mapper = message_mapper async def map(self, input_value: ModelRequest) -> ModelRequest: """Map the input value to a ModelRequest. @@ -211,12 +247,12 @@ async def map(self, input_value: ModelRequest) -> ModelRequest: ModelRequest: The mapped ModelRequest. """ input_value = input_value.copy() - messages: List[ModelMessage] = await self.map_messages(input_value.messages) + messages: List[ModelMessage] = self.map_messages(input_value.messages) # Overwrite the input value input_value.messages = messages return input_value - async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]: + def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]: """Map the input messages to a list of ModelMessage. Args: @@ -225,7 +261,73 @@ async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage] Returns: List[ModelMessage]: The mapped ModelMessage. """ - return messages + messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round( + messages + ) + message_mapper = self._message_mapper or self.map_multi_round_messages + return message_mapper(messages_by_round) + + def map_multi_round_messages( + self, messages_by_round: List[List[ModelMessage]] + ) -> List[ModelMessage]: + """Map multi round messages to a list of ModelMessage + + By default, just merge all multi round messages to a list of ModelMessage according origin order. + And you can overwrite this method to implement your own logic. + + Examples: + + Merge multi round messages to a list of ModelMessage according origin order. + + .. code-block:: python + + import asyncio + from dbgpt.core.operator import ConversationMapperOperator + messages_by_round = [ + [ + ModelMessage(role="human", content="Hi", round_index=1), + ModelMessage(role="ai", content="Hello!", round_index=1), + ], + [ + ModelMessage(role="system", content="Error 404", round_index=2), + ModelMessage(role="human", content="What's the error?", round_index=2), + ModelMessage(role="ai", content="Just a joke.", round_index=2), + ], + [ + ModelMessage(role="human", content="Funny!", round_index=3), + ], + ] + operator = ConversationMapperOperator() + messages = operator.map_multi_round_messages(messages_by_round) + assert messages == [ + ModelMessage(role="human", content="Hi", round_index=1), + ModelMessage(role="ai", content="Hello!", round_index=1), + ModelMessage(role="system", content="Error 404", round_index=2), + ModelMessage(role="human", content="What's the error?", round_index=2), + ModelMessage(role="ai", content="Just a joke.", round_index=2), + ModelMessage(role="human", content="Funny!", round_index=3), + ] + + Map multi round messages to a list of ModelMessage just keep the last one round. + + .. code-block:: python + + class MyMapper(ConversationMapperOperator): + def __init__(self, **kwargs): + super().__init__(**kwargs) + def map_multi_round_messages(self, messages_by_round: List[List[ModelMessage]]) -> List[ModelMessage]: + return messages_by_round[-1] + operator = MyMapper() + messages = operator.map_multi_round_messages(messages_by_round) + assert messages == [ + ModelMessage(role="human", content="Funny!", round_index=3), + ] + + Args: + """ + # Just merge and return + # e.g. assert sum([[1, 2], [3, 4], [5, 6]], []) == [1, 2, 3, 4, 5, 6] + return sum(messages_by_round, []) def _split_messages_by_round( self, messages: List[ModelMessage] @@ -236,7 +338,7 @@ def _split_messages_by_round( messages (List[ModelMessage]): The input messages. Returns: - List[List[ModelMessage]]: The splitted messages. + List[List[ModelMessage]]: The split messages. """ messages_by_round: List[List[ModelMessage]] = [] last_round_index = 0 @@ -263,15 +365,13 @@ class BufferedConversationMapperOperator(ConversationMapperOperator): .. code-block:: python - import asyncio from dbgpt.core import ModelMessage from dbgpt.core.operator import BufferedConversationMapperOperator # No history messages = [ModelMessage(role="human", content="Hello", round_index=1)] operator = BufferedConversationMapperOperator(last_k_round=1) - messages = asyncio.run(operator.map_messages(messages)) - assert messages == [ModelMessage(role="human", content="Hello", round_index=1)] + assert operator.map_messages(messages) == [ModelMessage(role="human", content="Hello", round_index=1)] Transform with history messages @@ -287,10 +387,9 @@ class BufferedConversationMapperOperator(ConversationMapperOperator): ModelMessage(role="human", content="Funny!", round_index=3), ] operator = BufferedConversationMapperOperator(last_k_round=1) - messages = asyncio.run(operator.map_messages(messages)) # Just keep the last one round, so the first round messages will be removed # Note: The round index 3 is not a complete round - assert messages == [ + assert operator.map_messages(messages) == [ ModelMessage(role="system", content="Error 404", round_index=2), ModelMessage(role="human", content="What's the error?", round_index=2), ModelMessage(role="ai", content="Just a joke.", round_index=2), @@ -298,24 +397,42 @@ class BufferedConversationMapperOperator(ConversationMapperOperator): ] """ - def __init__(self, last_k_round: Optional[int] = 2, **kwargs): - super().__init__(**kwargs) + def __init__( + self, + last_k_round: Optional[int] = 2, + message_mapper: _MultiRoundMessageMapper = None, + **kwargs, + ): self._last_k_round = last_k_round + if message_mapper: - async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]: - """Map the input messages to a list of ModelMessage. + def new_message_mapper( + messages_by_round: List[List[ModelMessage]], + ) -> List[ModelMessage]: + # Apply keep k round messages first, then apply the custom message mapper + messages_by_round = self._keep_last_round_messages(messages_by_round) + return message_mapper(messages_by_round) + + else: + + def new_message_mapper( + messages_by_round: List[List[ModelMessage]], + ) -> List[ModelMessage]: + messages_by_round = self._keep_last_round_messages(messages_by_round) + return sum(messages_by_round, []) + + super().__init__(new_message_mapper, **kwargs) + + def _keep_last_round_messages( + self, messages_by_round: List[List[ModelMessage]] + ) -> List[List[ModelMessage]]: + """Keep the last k round messages. Args: - messages (List[ModelMessage]): The input messages. + messages_by_round (List[List[ModelMessage]]): The messages by round. Returns: - List[ModelMessage]: The mapped ModelMessage. + List[List[ModelMessage]]: The latest round messages. """ - messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round( - messages - ) - # Get the last k round messages index = self._last_k_round + 1 - messages_by_round = messages_by_round[-index:] - messages: List[ModelMessage] = sum(messages_by_round, []) - return messages + return messages_by_round[-index:] diff --git a/dbgpt/core/interface/prompt.py b/dbgpt/core/interface/prompt.py index e55869953..f584eda13 100644 --- a/dbgpt/core/interface/prompt.py +++ b/dbgpt/core/interface/prompt.py @@ -169,9 +169,7 @@ def __post_init__(self): def to_prompt_template(self) -> PromptTemplate: """Convert the storage prompt template to a prompt template.""" input_variables = ( - None - if not self.input_variables - else self.input_variables.strip().split(",") + [] if not self.input_variables else self.input_variables.strip().split(",") ) return PromptTemplate( input_variables=input_variables, @@ -458,6 +456,33 @@ def save(self, prompt_template: PromptTemplate, prompt_name: str, **kwargs) -> N ) self.storage.save(storage_prompt_template) + def query_or_save( + self, prompt_template: PromptTemplate, prompt_name: str, **kwargs + ) -> StoragePromptTemplate: + """Query a prompt template from storage, if not found, save it. + + Args: + prompt_template (PromptTemplate): The prompt template to save. + prompt_name (str): The name of the prompt template. + kwargs (Dict): Other params to build the storage prompt template. + More details in :meth:`~StoragePromptTemplate.from_prompt_template`. + + Returns: + StoragePromptTemplate: The storage prompt template. + """ + storage_prompt_template = StoragePromptTemplate.from_prompt_template( + prompt_template, prompt_name, **kwargs + ) + exist_prompt_template = self.storage.load( + storage_prompt_template.identifier, StoragePromptTemplate + ) + if exist_prompt_template: + return exist_prompt_template + self.save(prompt_template, prompt_name, **kwargs) + return self.storage.load( + storage_prompt_template.identifier, StoragePromptTemplate + ) + def list(self, **kwargs) -> List[StoragePromptTemplate]: """List prompt templates from storage. diff --git a/dbgpt/model/cluster/apiserver/api.py b/dbgpt/model/cluster/apiserver/api.py index 22263051a..43fa62646 100644 --- a/dbgpt/model/cluster/apiserver/api.py +++ b/dbgpt/model/cluster/apiserver/api.py @@ -16,7 +16,6 @@ from fastapi.responses import StreamingResponse, JSONResponse from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer -from pydantic import BaseSettings from fastchat.protocol.openai_api_protocol import ( ChatCompletionResponse, @@ -42,6 +41,7 @@ from dbgpt.model.cluster import ModelRegistry from dbgpt.model.cluster.manager_base import WorkerManager, WorkerManagerFactory from dbgpt.util.utils import setup_logging +from dbgpt._private.pydantic import BaseModel logger = logging.getLogger(__name__) @@ -52,7 +52,7 @@ def __init__(self, code: int, message: str): self.message = message -class APISettings(BaseSettings): +class APISettings(BaseModel): api_keys: Optional[List[str]] = None diff --git a/dbgpt/model/utils/chatgpt_utils.py b/dbgpt/model/utils/chatgpt_utils.py index 32d6e29e0..02333b2b1 100644 --- a/dbgpt/model/utils/chatgpt_utils.py +++ b/dbgpt/model/utils/chatgpt_utils.py @@ -24,6 +24,7 @@ from dbgpt.core.interface.llm import ModelOutput, ModelRequest from dbgpt.model.cluster.client import DefaultLLMClient from dbgpt.model.cluster import WorkerManagerFactory +from dbgpt._private.pydantic import model_to_json if TYPE_CHECKING: import httpx @@ -175,6 +176,9 @@ def _build_request( async def generate(self, request: ModelRequest) -> ModelOutput: messages = request.to_openai_messages() payload = self._build_request(request) + logger.info( + f"Send request to openai, payload: {payload}\n\n messages:\n{messages}" + ) try: chat_completion = await self.client.chat.completions.create( messages=messages, **payload @@ -193,6 +197,9 @@ async def generate_stream( ) -> AsyncIterator[ModelOutput]: messages = request.to_openai_messages() payload = self._build_request(request, True) + logger.info( + f"Send request to openai, payload: {payload}\n\n messages:\n{messages}" + ) try: chat_completion = await self.client.chat.completions.create( messages=messages, **payload @@ -321,7 +328,7 @@ async def _to_openai_stream( chunk = ChatCompletionStreamResponse( id=id, choices=[choice_data], model=model or "" ) - yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + yield f"data: {model_to_json(chunk, exclude_unset=True, ensure_ascii=False)}\n\n" previous_text = "" finish_stream_events = [] @@ -356,7 +363,7 @@ async def _to_openai_stream( if model_output.finish_reason is not None: finish_stream_events.append(chunk) continue - yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + yield f"data: {model_to_json(chunk, exclude_unset=True, ensure_ascii=False)}\n\n" for finish_chunk in finish_stream_events: - yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n" + yield f"data: {model_to_json(finish_chunk, exclude_none=True, ensure_ascii=False)}\n\n" yield "data: [DONE]\n\n" diff --git a/dbgpt/serve/conversation/__init__.py b/dbgpt/serve/conversation/__init__.py new file mode 100644 index 000000000..3bc6a0ef3 --- /dev/null +++ b/dbgpt/serve/conversation/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve conversation` diff --git a/dbgpt/serve/conversation/api/__init__.py b/dbgpt/serve/conversation/api/__init__.py new file mode 100644 index 000000000..3bc6a0ef3 --- /dev/null +++ b/dbgpt/serve/conversation/api/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve conversation` diff --git a/dbgpt/serve/conversation/api/endpoints.py b/dbgpt/serve/conversation/api/endpoints.py new file mode 100644 index 000000000..10398e7e2 --- /dev/null +++ b/dbgpt/serve/conversation/api/endpoints.py @@ -0,0 +1,176 @@ +from typing import Optional, List +from functools import cache +from fastapi import APIRouter, Depends, Query, HTTPException +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer + + +from dbgpt.component import SystemApp +from dbgpt.serve.core import Result +from dbgpt.util import PaginationResult +from .schemas import ServeRequest, ServerResponse +from ..service.service import Service +from ..config import APP_NAME, SERVE_APP_NAME, ServeConfig, SERVE_SERVICE_COMPONENT_NAME + +router = APIRouter() + +# Add your API endpoints here + +global_system_app: Optional[SystemApp] = None + + +def get_service() -> Service: + """Get the service instance""" + return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service) + + +get_bearer_token = HTTPBearer(auto_error=False) + + +@cache +def _parse_api_keys(api_keys: str) -> List[str]: + """Parse the string api keys to a list + + Args: + api_keys (str): The string api keys + + Returns: + List[str]: The list of api keys + """ + if not api_keys: + return [] + return [key.strip() for key in api_keys.split(",")] + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), + service: Service = Depends(get_service), +) -> Optional[str]: + """Check the api key + + If the api key is not set, allow all. + + Your can pass the token in you request header like this: + + .. code-block:: python + + import requests + client_api_key = "your_api_key" + headers = {"Authorization": "Bearer " + client_api_key } + res = requests.get("http://test/hello", headers=headers) + assert res.status_code == 200 + + """ + if service.config.api_keys: + api_keys = _parse_api_keys(service.config.api_keys) + if auth is None or (token := auth.credentials) not in api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +@router.get("/health") +async def health(): + """Health check endpoint""" + return {"status": "ok"} + + +@router.get("/test_auth", dependencies=[Depends(check_api_key)]) +async def test_auth(): + """Test auth endpoint""" + return {"status": "ok"} + + +@router.post( + "/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)] +) +async def create( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[ServerResponse]: + """Create a new Conversation entity + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.create(request)) + + +@router.put( + "/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)] +) +async def update( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[ServerResponse]: + """Update a Conversation entity + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.update(request)) + + +@router.post( + "/query", + response_model=Result[ServerResponse], + dependencies=[Depends(check_api_key)], +) +async def query( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[ServerResponse]: + """Query Conversation entities + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.get(request)) + + +@router.post( + "/query_page", + response_model=Result[PaginationResult[ServerResponse]], + dependencies=[Depends(check_api_key)], +) +async def query_page( + request: ServeRequest, + page: Optional[int] = Query(default=1, description="current page"), + page_size: Optional[int] = Query(default=20, description="page size"), + service: Service = Depends(get_service), +) -> Result[PaginationResult[ServerResponse]]: + """Query Conversation entities + + Args: + request (ServeRequest): The request + page (int): The page number + page_size (int): The page size + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.get_list_by_page(request, page, page_size)) + + +def init_endpoints(system_app: SystemApp) -> None: + """Initialize the endpoints""" + global global_system_app + system_app.register(Service) + global_system_app = system_app diff --git a/dbgpt/serve/conversation/api/schemas.py b/dbgpt/serve/conversation/api/schemas.py new file mode 100644 index 000000000..353be0b13 --- /dev/null +++ b/dbgpt/serve/conversation/api/schemas.py @@ -0,0 +1,20 @@ +# Define your Pydantic schemas here +from dbgpt._private.pydantic import BaseModel, Field +from ..config import SERVE_APP_NAME_HUMP + + +class ServeRequest(BaseModel): + """Conversation request model""" + + # TODO define your own fields here + + class Config: + title = f"ServeRequest for {SERVE_APP_NAME_HUMP}" + + +class ServerResponse(BaseModel): + """Conversation response model""" + + # TODO define your own fields here + class Config: + title = f"ServerResponse for {SERVE_APP_NAME_HUMP}" diff --git a/dbgpt/serve/conversation/config.py b/dbgpt/serve/conversation/config.py new file mode 100644 index 000000000..3809846c5 --- /dev/null +++ b/dbgpt/serve/conversation/config.py @@ -0,0 +1,23 @@ +from typing import Optional +from dataclasses import dataclass, field + +from dbgpt.serve.core import BaseServeConfig + + +APP_NAME = "conversation" +SERVE_APP_NAME = "dbgpt_serve_conversation" +SERVE_APP_NAME_HUMP = "dbgpt_serve_Conversation" +SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.conversation." +SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service" +# Database table name +SERVER_APP_TABLE_NAME = "dbgpt_serve_conversation" + + +@dataclass +class ServeConfig(BaseServeConfig): + """Parameters for the serve command""" + + # TODO: add your own parameters here + api_keys: Optional[str] = field( + default=None, metadata={"help": "API keys for the endpoint, if None, allow all"} + ) diff --git a/dbgpt/serve/conversation/dependencies.py b/dbgpt/serve/conversation/dependencies.py new file mode 100644 index 000000000..8598ecd97 --- /dev/null +++ b/dbgpt/serve/conversation/dependencies.py @@ -0,0 +1 @@ +# Define your dependencies here diff --git a/dbgpt/serve/conversation/models/__init__.py b/dbgpt/serve/conversation/models/__init__.py new file mode 100644 index 000000000..3bc6a0ef3 --- /dev/null +++ b/dbgpt/serve/conversation/models/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve conversation` diff --git a/dbgpt/serve/conversation/models/models.py b/dbgpt/serve/conversation/models/models.py new file mode 100644 index 000000000..6328bafd6 --- /dev/null +++ b/dbgpt/serve/conversation/models/models.py @@ -0,0 +1,68 @@ +"""This is an auto-generated model file +You can define your own models and DAOs here +""" +from typing import Union, Any, Dict +from datetime import datetime +from sqlalchemy import Column, Integer, String, Index, Text, DateTime +from dbgpt.storage.metadata import Model, BaseDao, db +from ..api.schemas import ServeRequest, ServerResponse +from ..config import ServeConfig, SERVER_APP_TABLE_NAME + + +class ServeEntity(Model): + __tablename__ = SERVER_APP_TABLE_NAME + id = Column(Integer, primary_key=True, comment="Auto increment id") + + # TODO: define your own fields here + + gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") + gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") + + def __repr__(self): + return f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + + +class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): + """The DAO class for Conversation""" + + def __init__(self, serve_config: ServeConfig): + super().__init__() + self._serve_config = serve_config + + def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity: + """Convert the request to an entity + + Args: + request (Union[ServeRequest, Dict[str, Any]]): The request + + Returns: + T: The entity + """ + request_dict = request.dict() if isinstance(request, ServeRequest) else request + entity = ServeEntity(**request_dict) + # TODO implement your own logic here, transfer the request_dict to an entity + return entity + + def to_request(self, entity: ServeEntity) -> ServeRequest: + """Convert the entity to a request + + Args: + entity (T): The entity + + Returns: + REQ: The request + """ + # TODO implement your own logic here, transfer the entity to a request + return ServeRequest() + + def to_response(self, entity: ServeEntity) -> ServerResponse: + """Convert the entity to a response + + Args: + entity (T): The entity + + Returns: + RES: The response + """ + # TODO implement your own logic here, transfer the entity to a response + return ServerResponse() diff --git a/dbgpt/serve/conversation/serve.py b/dbgpt/serve/conversation/serve.py new file mode 100644 index 000000000..1578dd44e --- /dev/null +++ b/dbgpt/serve/conversation/serve.py @@ -0,0 +1,99 @@ +from typing import List, Optional, Union +import logging +from dbgpt.component import SystemApp +from sqlalchemy import URL +from dbgpt.core import StorageInterface +from dbgpt.storage.metadata import DatabaseManager +from dbgpt.serve.core import BaseServe + +from .config import ( + SERVE_APP_NAME, + SERVE_APP_NAME_HUMP, + APP_NAME, + SERVE_CONFIG_KEY_PREFIX, + ServeConfig, +) + +logger = logging.getLogger(__name__) + + +class Serve(BaseServe): + """Serve component for DB-GPT + + Message DB-GPT conversation history and provide API for other components to access. + + TODO: Move some Http API in app to this component. + """ + + name = SERVE_APP_NAME + + def __init__( + self, + system_app: SystemApp, + api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}", + api_tags: Optional[List[str]] = None, + db_url_or_db: Union[str, URL, DatabaseManager] = None, + try_create_tables: Optional[bool] = False, + ): + if api_tags is None: + api_tags = [SERVE_APP_NAME_HUMP] + super().__init__( + system_app, api_prefix, api_tags, db_url_or_db, try_create_tables + ) + self._db_manager: Optional[DatabaseManager] = None + self._conv_storage = None + self._message_storage = None + + @property + def conv_storage(self) -> StorageInterface: + return self._conv_storage + + @property + def message_storage(self) -> StorageInterface: + return self._message_storage + + def init_app(self, system_app: SystemApp): + if self._app_has_initiated: + return + self._system_app = system_app + self._app_has_initiated = True + + def on_init(self): + """Called when init the application. + + You can do some initialization here. You can't get other components here because they may be not initialized yet + """ + # Load DB Model + from dbgpt.storage.chat_history.chat_history_db import ( + ChatHistoryEntity, + ChatHistoryMessageEntity, + ) + + def before_start(self): + """Called before the start of the application.""" + # TODO: Your code here + from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage + from dbgpt.util.serialization.json_serialization import JsonSerializer + from dbgpt.storage.chat_history.chat_history_db import ( + ChatHistoryEntity, + ChatHistoryMessageEntity, + ) + from dbgpt.storage.chat_history.storage_adapter import ( + DBStorageConversationItemAdapter, + DBMessageStorageItemAdapter, + ) + + self._db_manager = self.create_or_get_db_manager() + + self._conv_storage = SQLAlchemyStorage( + self._db_manager, + ChatHistoryEntity, + DBStorageConversationItemAdapter(), + JsonSerializer(), + ) + self._message_storage = SQLAlchemyStorage( + self._db_manager, + ChatHistoryMessageEntity, + DBMessageStorageItemAdapter(), + JsonSerializer(), + ) diff --git a/dbgpt/serve/conversation/service/__init__.py b/dbgpt/serve/conversation/service/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/conversation/service/service.py b/dbgpt/serve/conversation/service/service.py new file mode 100644 index 000000000..d346822f6 --- /dev/null +++ b/dbgpt/serve/conversation/service/service.py @@ -0,0 +1,116 @@ +from typing import Optional, List +from dbgpt.component import BaseComponent, SystemApp +from dbgpt.storage.metadata import BaseDao +from dbgpt.util.pagination_utils import PaginationResult +from dbgpt.serve.core import BaseService +from ..models.models import ServeDao, ServeEntity +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVE_SERVICE_COMPONENT_NAME, SERVE_CONFIG_KEY_PREFIX, ServeConfig + + +class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): + """The service class for Conversation""" + + name = SERVE_SERVICE_COMPONENT_NAME + + def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None): + self._system_app = None + self._serve_config: ServeConfig = None + self._dao: ServeDao = dao + super().__init__(system_app) + + def init_app(self, system_app: SystemApp) -> None: + """Initialize the service + + Args: + system_app (SystemApp): The system app + """ + self._serve_config = ServeConfig.from_app_config( + system_app.config, SERVE_CONFIG_KEY_PREFIX + ) + self._dao = self._dao or ServeDao(self._serve_config) + self._system_app = system_app + + @property + def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]: + """Returns the internal DAO.""" + return self._dao + + @property + def config(self) -> ServeConfig: + """Returns the internal ServeConfig.""" + return self._serve_config + + def update(self, request: ServeRequest) -> ServerResponse: + """Update a Conversation entity + + Args: + request (ServeRequest): The request + + Returns: + ServerResponse: The response + """ + # TODO: implement your own logic here + # Build the query request from the request + query_request = { + # "id": request.id + } + return self.dao.update(query_request, update_request=request) + + def get(self, request: ServeRequest) -> Optional[ServerResponse]: + """Get a Conversation entity + + Args: + request (ServeRequest): The request + + Returns: + ServerResponse: The response + """ + # TODO: implement your own logic here + # Build the query request from the request + query_request = request + return self.dao.get_one(query_request) + + def delete(self, request: ServeRequest) -> None: + """Delete a Conversation entity + + Args: + request (ServeRequest): The request + """ + + # TODO: implement your own logic here + # Build the query request from the request + query_request = { + # "id": request.id + } + self.dao.delete(query_request) + + def get_list(self, request: ServeRequest) -> List[ServerResponse]: + """Get a list of Conversation entities + + Args: + request (ServeRequest): The request + + Returns: + List[ServerResponse]: The response + """ + # TODO: implement your own logic here + # Build the query request from the request + query_request = request + return self.dao.get_list(query_request) + + def get_list_by_page( + self, request: ServeRequest, page: int, page_size: int + ) -> PaginationResult[ServerResponse]: + """Get a list of Conversation entities by page + + Args: + request (ServeRequest): The request + page (int): The page number + page_size (int): The page size + + Returns: + List[ServerResponse]: The response + """ + query_request = request + return self.dao.get_list_page(query_request, page, page_size) diff --git a/dbgpt/serve/conversation/tests/__init__.py b/dbgpt/serve/conversation/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/conversation/tests/test_endpoints.py b/dbgpt/serve/conversation/tests/test_endpoints.py new file mode 100644 index 000000000..79f35c7a0 --- /dev/null +++ b/dbgpt/serve/conversation/tests/test_endpoints.py @@ -0,0 +1,124 @@ +import pytest +from httpx import AsyncClient + +from fastapi import FastAPI +from dbgpt.component import SystemApp +from dbgpt.storage.metadata import db +from dbgpt.util import PaginationResult +from ..config import SERVE_CONFIG_KEY_PREFIX +from ..api.endpoints import router, init_endpoints +from ..api.schemas import ServeRequest, ServerResponse + +from dbgpt.serve.core.tests.conftest import client, asystem_app + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + + yield + + +def client_init_caller(app: FastAPI, system_app: SystemApp): + app.include_router(router) + init_endpoints(system_app) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client, asystem_app, has_auth", + [ + ( + { + "app_caller": client_init_caller, + "client_api_key": "test_token1", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + True, + ), + ( + { + "app_caller": client_init_caller, + "client_api_key": "error_token", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + False, + ), + ], + indirect=["client", "asystem_app"], +) +async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool): + response = await client.get("/test_auth") + if has_auth: + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + else: + assert response.status_code == 401 + assert response.json() == { + "detail": { + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + } + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_health(client: AsyncClient): + response = await client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_create(client: AsyncClient): + # TODO: add your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_update(client: AsyncClient): + # TODO: implement your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query(client: AsyncClient): + # TODO: implement your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query_by_page(client: AsyncClient): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/conversation/tests/test_models.py b/dbgpt/serve/conversation/tests/test_models.py new file mode 100644 index 000000000..c065909b2 --- /dev/null +++ b/dbgpt/serve/conversation/tests/test_models.py @@ -0,0 +1,109 @@ +from typing import List +import pytest +from dbgpt.storage.metadata import db +from ..config import ServeConfig +from ..api.schemas import ServeRequest, ServerResponse +from ..models.models import ServeEntity, ServeDao + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + + yield + + +@pytest.fixture +def server_config(): + # TODO : build your server config + return ServeConfig() + + +@pytest.fixture +def dao(server_config): + return ServeDao(server_config) + + +@pytest.fixture +def default_entity_dict(): + # TODO: build your default entity dict + return {} + + +def test_table_exist(): + assert ServeEntity.__tablename__ in db.metadata.tables + + +def test_entity_create(default_entity_dict): + entity: ServeEntity = ServeEntity.create(**default_entity_dict) + # TODO: implement your test case + with db.session() as session: + db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) + assert db_entity.id == entity.id + + +def test_entity_unique_key(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_get(default_entity_dict): + entity: ServeEntity = ServeEntity.create(**default_entity_dict) + db_entity: ServeEntity = ServeEntity.get(entity.id) + assert db_entity.id == entity.id + # TODO: implement your test case + + +def test_entity_update(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_delete(default_entity_dict): + # TODO: implement your test case + entity: ServeEntity = ServeEntity.create(**default_entity_dict) + entity.delete() + db_entity: ServeEntity = ServeEntity.get(entity.id) + assert db_entity is None + + +def test_entity_all(): + # TODO: implement your test case + pass + + +def test_dao_create(dao, default_entity_dict): + # TODO: implement your test case + req = ServeRequest(**default_entity_dict) + res: ServerResponse = dao.create(req) + assert res is not None + + +def test_dao_get_one(dao, default_entity_dict): + # TODO: implement your test case + req = ServeRequest(**default_entity_dict) + res: ServerResponse = dao.create(req) + + +def test_get_dao_get_list(dao): + # TODO: implement your test case + pass + + +def test_dao_update(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_dao_delete(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_dao_get_list_page(dao): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/conversation/tests/test_service.py b/dbgpt/serve/conversation/tests/test_service.py new file mode 100644 index 000000000..003286c82 --- /dev/null +++ b/dbgpt/serve/conversation/tests/test_service.py @@ -0,0 +1,76 @@ +from typing import List +import pytest +from dbgpt.component import SystemApp +from dbgpt.storage.metadata import db +from dbgpt.serve.core.tests.conftest import system_app + +from ..models.models import ServeEntity +from ..api.schemas import ServeRequest, ServerResponse +from ..service.service import Service + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + yield + + +@pytest.fixture +def service(system_app: SystemApp): + instance = Service(system_app) + instance.init_app(system_app) + return instance + + +@pytest.fixture +def default_entity_dict(): + # TODO: build your default entity dict + return {} + + +@pytest.mark.parametrize( + "system_app", + [{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}], + indirect=True, +) +def test_config_exists(service: Service): + system_app: SystemApp = service._system_app + assert system_app.config.get("DEBUG") is True + assert system_app.config.get("dbgpt.serve.test_key") == "hello" + assert service.config is not None + + +def test_service_create(service: Service, default_entity_dict): + # TODO: implement your test case + # eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict)) + # ... + pass + + +def test_service_update(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_delete(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get_list(service: Service): + # TODO: implement your test case + pass + + +def test_service_get_list_by_page(service: Service): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/core/__init__.py b/dbgpt/serve/core/__init__.py index 36a1900e9..23275525c 100644 --- a/dbgpt/serve/core/__init__.py +++ b/dbgpt/serve/core/__init__.py @@ -1,5 +1,6 @@ from dbgpt.serve.core.schemas import Result from dbgpt.serve.core.config import BaseServeConfig from dbgpt.serve.core.service import BaseService +from dbgpt.serve.core.serve import BaseServe -__ALL__ = ["Result", "BaseServeConfig", "BaseService"] +__ALL__ = ["Result", "BaseServeConfig", "BaseService", "BaseServe"] diff --git a/dbgpt/serve/core/serve.py b/dbgpt/serve/core/serve.py new file mode 100644 index 000000000..4d3983f16 --- /dev/null +++ b/dbgpt/serve/core/serve.py @@ -0,0 +1,60 @@ +from abc import ABC +from typing import Optional, Union, List +import logging +from dbgpt.component import BaseComponent, SystemApp, ComponentType +from sqlalchemy import URL +from dbgpt.storage.metadata import DatabaseManager + +logger = logging.getLogger(__name__) + + +class BaseServe(BaseComponent, ABC): + """Base serve component for DB-GPT""" + + name = "dbgpt_serve_base" + + def __init__( + self, + system_app: SystemApp, + api_prefix: str, + api_tags: List[str], + db_url_or_db: Union[str, URL, DatabaseManager] = None, + try_create_tables: Optional[bool] = False, + ): + self._system_app = system_app + self._api_prefix = api_prefix + self._api_tags = api_tags + self._db_url_or_db = db_url_or_db + self._try_create_tables = try_create_tables + self._not_create_table = True + self._app_has_initiated = False + + def create_or_get_db_manager(self) -> DatabaseManager: + """Create or get the database manager. + This method must be called after the application is initialized + + Returns: + DatabaseManager: The database manager + """ + from dbgpt.storage.metadata import Model, db, UnifiedDBManagerFactory + + # If you need to use the database, you can get the database manager here + db_manager_factory: UnifiedDBManagerFactory = self._system_app.get_component( + ComponentType.UNIFIED_METADATA_DB_MANAGER_FACTORY, + UnifiedDBManagerFactory, + default_component=None, + ) + if db_manager_factory is not None and db_manager_factory.create(): + init_db = db_manager_factory.create() + else: + init_db = self._db_url_or_db or db + init_db = DatabaseManager.build_from(init_db, base=Model) + + if self._try_create_tables and self._not_create_table: + try: + init_db.create_all() + except Exception as e: + logger.warning(f"Failed to create tables: {e}") + finally: + self._not_create_table = False + return init_db diff --git a/dbgpt/serve/core/tests/conftest.py b/dbgpt/serve/core/tests/conftest.py index 9c8c7c77e..090f3bf95 100644 --- a/dbgpt/serve/core/tests/conftest.py +++ b/dbgpt/serve/core/tests/conftest.py @@ -8,15 +8,6 @@ from dbgpt.component import SystemApp from dbgpt.util import AppConfig -app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], - allow_headers=["*"], -) - def create_system_app(param: Dict) -> SystemApp: app_config = param.get("app_config", {}) @@ -24,7 +15,17 @@ def create_system_app(param: Dict) -> SystemApp: app_config = AppConfig(configs=app_config) elif not isinstance(app_config, AppConfig): raise RuntimeError("app_config must be AppConfig or dict") - return SystemApp(app, app_config) + + test_app = FastAPI() + test_app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + allow_headers=["*"], + ) + + return SystemApp(test_app, app_config) @pytest_asyncio.fixture @@ -51,9 +52,12 @@ async def client(request, asystem_app: SystemApp): del param["api_keys"] if client_api_key: headers["Authorization"] = "Bearer " + client_api_key - async with AsyncClient(app=app, base_url=base_url, headers=headers) as client: + + test_app = asystem_app.app + + async with AsyncClient(app=test_app, base_url=base_url, headers=headers) as client: for router in routers: - app.include_router(router) + test_app.include_router(router) if app_caller: - app_caller(app, asystem_app) + app_caller(test_app, asystem_app) yield client diff --git a/dbgpt/serve/prompt/api/schemas.py b/dbgpt/serve/prompt/api/schemas.py index 7cdbc7a1b..6d8d67924 100644 --- a/dbgpt/serve/prompt/api/schemas.py +++ b/dbgpt/serve/prompt/api/schemas.py @@ -43,6 +43,15 @@ class Config: "You are a data analysis expert.", ], ) + prompt_desc: Optional[str] = Field( + None, + description="The prompt description.", + examples=[ + "This is a prompt for code assistant.", + "This is a prompt for joker.", + "This is a prompt for data analysis expert.", + ], + ) user_name: Optional[str] = Field( None, diff --git a/dbgpt/serve/prompt/models/models.py b/dbgpt/serve/prompt/models/models.py index 8095a5014..61346db47 100644 --- a/dbgpt/serve/prompt/models/models.py +++ b/dbgpt/serve/prompt/models/models.py @@ -48,6 +48,7 @@ class ServeEntity(Model): default="f-string", comment="Prompt format(eg: f-string, jinja2)", ) + prompt_desc = Column(String(512), nullable=True, comment="Prompt description") user_name = Column(String(128), index=True, nullable=True, comment="User name") sys_code = Column(String(128), index=True, nullable=True, comment="System code") gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") @@ -96,6 +97,7 @@ def to_request(self, entity: ServeEntity) -> ServeRequest: prompt_type=entity.prompt_type, prompt_name=entity.prompt_name, content=entity.content, + prompt_desc=entity.prompt_desc, user_name=entity.user_name, sys_code=entity.sys_code, ) @@ -119,6 +121,7 @@ def to_response(self, entity: ServeEntity) -> ServerResponse: prompt_type=entity.prompt_type, prompt_name=entity.prompt_name, content=entity.content, + prompt_desc=entity.prompt_desc, user_name=entity.user_name, sys_code=entity.sys_code, gmt_created=gmt_created_str, diff --git a/dbgpt/serve/prompt/serve.py b/dbgpt/serve/prompt/serve.py index 015450523..ed50a353f 100644 --- a/dbgpt/serve/prompt/serve.py +++ b/dbgpt/serve/prompt/serve.py @@ -3,10 +3,11 @@ from sqlalchemy import URL -from dbgpt.component import BaseComponent, SystemApp +from dbgpt.component import SystemApp from dbgpt.core import PromptManager -from ...storage.metadata import DatabaseManager +from dbgpt.storage.metadata import DatabaseManager +from dbgpt.serve.core import BaseServe from .api.endpoints import init_endpoints, router from .config import ( APP_NAME, @@ -20,7 +21,7 @@ logger = logging.getLogger(__name__) -class Serve(BaseComponent): +class Serve(BaseServe): """Serve component Examples: @@ -37,6 +38,7 @@ class Serve(BaseComponent): app = FastAPI() system_app = SystemApp(app) system_app.register(Serve, api_prefix="/api/v1/prompt") + system_app.on_init() # Run before start hook system_app.before_start() @@ -61,6 +63,7 @@ class Serve(BaseComponent): app = FastAPI() system_app = SystemApp(app) system_app.register(Serve, api_prefix="/api/v1/prompt", db_url_or_db="sqlite:///:memory:", try_create_tables=True) + system_app.on_init() # Run before start hook system_app.before_start() @@ -81,31 +84,41 @@ def __init__( self, system_app: SystemApp, api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}", - tags: Optional[List[str]] = None, + api_tags: Optional[List[str]] = None, db_url_or_db: Union[str, URL, DatabaseManager] = None, try_create_tables: Optional[bool] = False, ): - if tags is None: - tags = [SERVE_APP_NAME_HUMP] - self._system_app = None - self._api_prefix = api_prefix - self._tags = tags + if api_tags is None: + api_tags = [SERVE_APP_NAME_HUMP] + super().__init__( + system_app, api_prefix, api_tags, db_url_or_db, try_create_tables + ) self._prompt_manager = None - self._db_url_or_db = db_url_or_db - self._try_create_tables = try_create_tables + self._db_manager: Optional[DatabaseManager] = None def init_app(self, system_app: SystemApp): + if self._app_has_initiated: + return self._system_app = system_app self._system_app.app.include_router( - router, prefix=self._api_prefix, tags=self._tags + router, prefix=self._api_prefix, tags=self._api_tags ) init_endpoints(self._system_app) + self._app_has_initiated = True @property def prompt_manager(self) -> PromptManager: """Get the prompt manager of the serve app with db storage""" return self._prompt_manager + def on_init(self): + """Called before the start of the application. + + You can do some initialization here. + """ + # import your own module here to ensure the module is loaded before the application starts + from .models.models import ServeEntity + def before_start(self): """Called before the start of the application. @@ -113,23 +126,16 @@ def before_start(self): """ # import your own module here to ensure the module is loaded before the application starts from dbgpt.core.interface.prompt import PromptManager - from dbgpt.storage.metadata import Model, db from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage from dbgpt.util.serialization.json_serialization import JsonSerializer from .models.models import ServeEntity - init_db = self._db_url_or_db or db - init_db = DatabaseManager.build_from(init_db, base=Model) - if self._try_create_tables: - try: - init_db.create_all() - except Exception as e: - logger.warning(f"Failed to create tables: {e}") + self._db_manager = self.create_or_get_db_manager() storage_adapter = PromptTemplateAdapter() serializer = JsonSerializer() storage = SQLAlchemyStorage( - init_db, + self._db_manager, ServeEntity, storage_adapter, serializer, diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/serve.py b/dbgpt/serve/utils/_template_files/default_serve_template/serve.py index 1f6bb4811..77b9e477b 100644 --- a/dbgpt/serve/utils/_template_files/default_serve_template/serve.py +++ b/dbgpt/serve/utils/_template_files/default_serve_template/serve.py @@ -1,6 +1,9 @@ -from typing import List, Optional -from dbgpt.component import BaseComponent, SystemApp - +from typing import List, Optional, Union +import logging +from dbgpt.component import SystemApp +from sqlalchemy import URL +from dbgpt.storage.metadata import DatabaseManager +from dbgpt.serve.core import BaseServe from .api.endpoints import router, init_endpoints from .config import ( SERVE_APP_NAME, @@ -10,8 +13,10 @@ ServeConfig, ) +logger = logging.getLogger(__name__) + -class Serve(BaseComponent): +class Serve(BaseServe): """Serve component for DB-GPT""" name = SERVE_APP_NAME @@ -20,25 +25,36 @@ def __init__( self, system_app: SystemApp, api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}", - tags: Optional[List[str]] = None, + api_tags: Optional[List[str]] = None, + db_url_or_db: Union[str, URL, DatabaseManager] = None, + try_create_tables: Optional[bool] = False, ): - if tags is None: - tags = [SERVE_APP_NAME_HUMP] - self._system_app = None - self._api_prefix = api_prefix - self._tags = tags + if api_tags is None: + api_tags = [SERVE_APP_NAME_HUMP] + super().__init__( + system_app, api_prefix, api_tags, db_url_or_db, try_create_tables + ) + self._db_manager: Optional[DatabaseManager] = None def init_app(self, system_app: SystemApp): + if self._app_has_initiated: + return self._system_app = system_app self._system_app.app.include_router( - router, prefix=self._api_prefix, tags=self._tags + router, prefix=self._api_prefix, tags=self._api_tags ) init_endpoints(self._system_app) + self._app_has_initiated = True - def before_start(self): - """Called before the start of the application. + def on_init(self): + """Called when init the application. - You can do some initialization here. + You can do some initialization here. You can't get other components here because they may be not initialized yet """ # import your own module here to ensure the module is loaded before the application starts from .models.models import ServeEntity + + def before_start(self): + """Called before the start of the application.""" + # TODO: Your code here + self._db_manager = self.create_or_get_db_manager() diff --git a/dbgpt/storage/metadata/__init__.py b/dbgpt/storage/metadata/__init__.py index 63e58aefe..ee409f6cb 100644 --- a/dbgpt/storage/metadata/__init__.py +++ b/dbgpt/storage/metadata/__init__.py @@ -5,6 +5,7 @@ create_model, BaseModel, ) +from dbgpt.storage.metadata.db_factory import UnifiedDBManagerFactory from dbgpt.storage.metadata._base_dao import BaseDao __ALL__ = [ @@ -14,4 +15,5 @@ "create_model", "BaseModel", "BaseDao", + "UnifiedDBManagerFactory", ] diff --git a/dbgpt/storage/metadata/db_factory.py b/dbgpt/storage/metadata/db_factory.py new file mode 100644 index 000000000..14cf0339a --- /dev/null +++ b/dbgpt/storage/metadata/db_factory.py @@ -0,0 +1,21 @@ +from dbgpt.component import SystemApp, BaseComponent, ComponentType + +from .db_manager import DatabaseManager + + +class UnifiedDBManagerFactory(BaseComponent): + name = ComponentType.UNIFIED_METADATA_DB_MANAGER_FACTORY + + def __init__(self, system_app: SystemApp, db_manager: DatabaseManager): + super().__init__(system_app) + self._db_manager = db_manager + + def init_app(self, system_app: SystemApp): + pass + + def create(self) -> DatabaseManager: + if not self._db_manager: + raise RuntimeError("db_manager is not initialized") + if not self._db_manager.is_initialized: + raise RuntimeError("db_manager is not initialized") + return self._db_manager diff --git a/dbgpt/storage/metadata/db_manager.py b/dbgpt/storage/metadata/db_manager.py index 0876dd491..be8f0b1ec 100644 --- a/dbgpt/storage/metadata/db_manager.py +++ b/dbgpt/storage/metadata/db_manager.py @@ -183,6 +183,11 @@ def engine(self): """Get the engine.""" "" return self._engine + @property + def is_initialized(self) -> bool: + """Whether the database manager is initialized.""" "" + return self._engine is not None and self._session is not None + @contextmanager def session(self) -> Session: """Get the session with context manager. @@ -200,7 +205,7 @@ def session(self) -> Session: RuntimeError: The database manager is not initialized. Exception: Any exception. """ - if not self._session: + if not self.is_initialized: raise RuntimeError("The database manager is not initialized.") session = self._session() try: diff --git a/dbgpt/util/config_utils.py b/dbgpt/util/config_utils.py index fd431f164..cf721f2c8 100644 --- a/dbgpt/util/config_utils.py +++ b/dbgpt/util/config_utils.py @@ -1,3 +1,4 @@ +import os from functools import cache from typing import Any, Dict, Optional @@ -33,3 +34,19 @@ def get_all_by_prefix(self, prefix) -> Dict[str, Any]: prefix (str): The prefix of config """ return {k: v for k, v in self.configs.items() if k.startswith(prefix)} + + def get_current_lang(self, default: Optional[str] = None) -> str: + """Get current language + + Args: + default (Optional[str], optional): The default language if not found. Defaults to None. + + Returns: + str: The language of user running environment + """ + env_lang = ( + "zh" + if os.getenv("LANG") and os.getenv("LANG").startswith("zh") + else default + ) + return self.get("dbgpt.app.global.language", env_lang) diff --git a/dbgpt/util/openai_utils.py b/dbgpt/util/openai_utils.py index 6577d3abf..5e1673f08 100644 --- a/dbgpt/util/openai_utils.py +++ b/dbgpt/util/openai_utils.py @@ -44,11 +44,20 @@ async def _do_chat_completion( decoded_line = line.split("data: ", 1)[1] if decoded_line.lower().strip() != "[DONE]".lower(): obj = json.loads(decoded_line) - if obj["choices"][0]["delta"].get("content") is not None: - text = obj["choices"][0]["delta"].get("content") + if "error_code" in obj and obj["error_code"] != 0: if caller: - await caller(text) - yield text + await caller(obj.get("text")) + yield obj.get("text") + else: + if ( + "choices" in obj + and obj["choices"][0]["delta"].get("content") + is not None + ): + text = obj["choices"][0]["delta"].get("content") + if caller: + await caller(text) + yield text await asyncio.sleep(0.02) diff --git a/dbgpt/util/tracer/span_storage.py b/dbgpt/util/tracer/span_storage.py index f5d560fed..4872336dc 100644 --- a/dbgpt/util/tracer/span_storage.py +++ b/dbgpt/util/tracer/span_storage.py @@ -139,7 +139,7 @@ def _roll_over_if_needed(self): def _write_to_file(self, spans: List[Span]): self._roll_over_if_needed() - with open(self.filename, "a") as file: + with open(self.filename, "a", encoding="utf8") as file: for span in spans: span_data = span.to_dict() try: diff --git a/dbgpt/util/utils.py b/dbgpt/util/utils.py index 0d39bc4b6..4f7c2678d 100644 --- a/dbgpt/util/utils.py +++ b/dbgpt/util/utils.py @@ -10,6 +10,14 @@ from dbgpt.configs.model_config import LOGDIR +try: + from termcolor import colored +except ImportError: + + def colored(x, *args, **kwargs): + return x + + server_error_msg = ( "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" ) diff --git a/examples/awel/data_analyst_assistant.py b/examples/awel/data_analyst_assistant.py new file mode 100644 index 000000000..f40a5be9e --- /dev/null +++ b/examples/awel/data_analyst_assistant.py @@ -0,0 +1,438 @@ +"""AWEL: Data analyst assistant. + + DB-GPT will automatically load and execute the current file after startup. + + Examples: + + .. code-block:: shell + + # Run this file in your terminal with dev mode. + # First terminal + export OPENAI_API_KEY=xxx + export OPENAI_API_BASE=https://api.openai.com/v1 + python examples/awel/simple_chat_history_example.py + + + Code fix command, return no streaming response + + .. code-block:: shell + + # Open a new terminal + # Second terminal + + DBGPT_SERVER="http://127.0.0.1:5555" + MODEL="gpt-3.5-turbo" + # Fist round + curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/data_analyst/copilot \ + -H "Content-Type: application/json" -d '{ + "command": "dbgpt_awel_data_analyst_code_fix", + "model": "gpt-3.5-turbo", + "stream": false, + "context": { + "conv_uid": "uuid_conv_copilot_1234", + "chat_mode": "chat_with_code" + }, + "messages": "SELECT * FRM orders WHERE order_amount > 500;" + }' + +""" +import logging +from functools import cache +from typing import Any, Dict, List, Optional + +from dbgpt._private.pydantic import BaseModel, Field +from dbgpt.core import ( + InMemoryStorage, + LLMClient, + MessageStorageItem, + ModelMessage, + ModelMessageRoleType, + PromptManager, + PromptTemplate, + StorageConversation, + StorageInterface, +) +from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator +from dbgpt.core.operator import ( + BufferedConversationMapperOperator, + LLMBranchOperator, + LLMOperator, + PostConversationOperator, + PostStreamingConversationOperator, + PreConversationOperator, + RequestBuildOperator, + StreamingLLMOperator, +) +from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator +from dbgpt.util.utils import colored + +logger = logging.getLogger(__name__) + +CODE_FIX = "dbgpt_awel_data_analyst_code_fix" +CODE_PERF = "dbgpt_awel_data_analyst_code_perf" +CODE_EXPLAIN = "dbgpt_awel_data_analyst_code_explain" +CODE_COMMENT = "dbgpt_awel_data_analyst_code_comment" +CODE_TRANSLATE = "dbgpt_awel_data_analyst_code_translate" + +CODE_FIX_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师, +这里有一段 {language} 代码。请按照最佳实践检查代码,找出并修复所有错误。请给出修复后的代码,并且提供对您所做的每一行更正的逐行解释,请使用和用户相同的语言进行回答。""" +CODE_FIX_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst, +here is a snippet of code of {language}. Please review the code following best practices to identify and fix all errors. +Provide the corrected code and include a line-by-line explanation of all the fixes you've made, please use the same language as the user.""" + +CODE_PERF_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,这里有一段 {language} 代码。 +请你按照最佳实践来优化这段代码。请在代码中加入注释点明所做的更改,并解释每项优化的原因,以便提高代码的维护性和性能,请使用和用户相同的语言进行回答。""" +CODE_PERF_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst, +you are provided with a snippet of code of {language}. Please optimize the code according to best practices. +Include comments to highlight the changes made and explain the reasons for each optimization for better maintenance and performance, +please use the same language as the user.""" +CODE_EXPLAIN_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师, +现在给你的是一份 {language} 代码。请你逐行解释代码的含义,请使用和用户相同的语言进行回答。""" + +CODE_EXPLAIN_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst, +you are provided with a snippet of code of {language}. Please explain the meaning of the code line by line, +please use the same language as the user.""" + +CODE_COMMENT_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,现在给你的是一份 {language} 代码。 +请你为每一行代码添加注释,解释每个部分的作用,请使用和用户相同的语言进行回答。""" + +CODE_COMMENT_TEMPLATE_EN = """As an experienced Data Warehouse Developer and Data Analyst. +Below is a snippet of code written in {language}. +Please provide line-by-line comments explaining what each section of the code does, please use the same language as the user.""" + +CODE_TRANSLATE_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,现在手头有一份用{source_language}语言编写的代码片段。 +请你将这段代码准确无误地翻译成{target_language}语言,确保语法和功能在翻译后的代码中得到正确体现,请使用和用户相同的语言进行回答。""" +CODE_TRANSLATE_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst, +you're presented with a snippet of code written in {source_language}. +Please translate this code into {target_language} ensuring that the syntax and functionalities are accurately reflected in the translated code, +please use the same language as the user.""" + + +class ReqContext(BaseModel): + user_name: Optional[str] = Field( + None, description="The user name of the model request." + ) + + sys_code: Optional[str] = Field( + None, description="The system code of the model request." + ) + conv_uid: Optional[str] = Field( + None, description="The conversation uid of the model request." + ) + chat_mode: Optional[str] = Field( + "chat_with_code", description="The chat mode of the model request." + ) + + +class TriggerReqBody(BaseModel): + messages: str = Field(..., description="User input messages") + command: Optional[str] = Field(default="fix", description="Command name") + model: Optional[str] = Field(default="gpt-3.5-turbo", description="Model name") + stream: Optional[bool] = Field(default=False, description="Whether return stream") + language: Optional[str] = Field(default="hive", description="Language") + target_language: Optional[str] = Field( + default="hive", description="Target language, use in translate" + ) + context: Optional[ReqContext] = Field( + default=None, description="The context of the model request." + ) + + +@cache +def load_or_save_prompt_template(pm: PromptManager): + ext_params = { + "chat_scene": "chat_with_code", + "sub_chat_scene": "data_analyst", + "prompt_type": "common", + } + pm.query_or_save( + PromptTemplate( + input_variables=["language"], + template=CODE_FIX_TEMPLATE_ZH, + ), + prompt_name=CODE_FIX, + prompt_language="zh", + **ext_params, + ) + pm.query_or_save( + PromptTemplate( + input_variables=["language"], + template=CODE_FIX_TEMPLATE_EN, + ), + prompt_name=CODE_FIX, + prompt_language="en", + **ext_params, + ) + pm.query_or_save( + PromptTemplate( + input_variables=["language"], + template=CODE_PERF_TEMPLATE_ZH, + ), + prompt_name=CODE_PERF, + prompt_language="zh", + **ext_params, + ) + pm.query_or_save( + PromptTemplate( + input_variables=["language"], + template=CODE_PERF_TEMPLATE_EN, + ), + prompt_name=CODE_PERF, + prompt_language="en", + **ext_params, + ) + pm.query_or_save( + PromptTemplate( + input_variables=["language"], + template=CODE_EXPLAIN_TEMPLATE_ZH, + ), + prompt_name=CODE_EXPLAIN, + prompt_language="zh", + **ext_params, + ) + pm.query_or_save( + PromptTemplate( + input_variables=["language"], + template=CODE_EXPLAIN_TEMPLATE_EN, + ), + prompt_name=CODE_EXPLAIN, + prompt_language="en", + **ext_params, + ) + pm.query_or_save( + PromptTemplate( + input_variables=["language"], + template=CODE_COMMENT_TEMPLATE_ZH, + ), + prompt_name=CODE_COMMENT, + prompt_language="zh", + **ext_params, + ) + pm.query_or_save( + PromptTemplate( + input_variables=["language"], + template=CODE_COMMENT_TEMPLATE_EN, + ), + prompt_name=CODE_COMMENT, + prompt_language="en", + **ext_params, + ) + pm.query_or_save( + PromptTemplate( + input_variables=["source_language", "target_language"], + template=CODE_TRANSLATE_TEMPLATE_ZH, + ), + prompt_name=CODE_TRANSLATE, + prompt_language="zh", + **ext_params, + ) + pm.query_or_save( + PromptTemplate( + input_variables=["source_language", "target_language"], + template=CODE_TRANSLATE_TEMPLATE_EN, + ), + prompt_name=CODE_TRANSLATE, + prompt_language="en", + **ext_params, + ) + + +class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._default_prompt_manager = PromptManager() + + async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]: + from dbgpt.serve.prompt.serve import SERVE_APP_NAME as PROMPT_SERVE_APP_NAME + from dbgpt.serve.prompt.serve import Serve as PromptServe + + prompt_serve = self.system_app.get_component( + PROMPT_SERVE_APP_NAME, PromptServe, default_component=None + ) + if prompt_serve: + pm = prompt_serve.prompt_manager + else: + pm = self._default_prompt_manager + load_or_save_prompt_template(pm) + + user_language = self.system_app.config.get_current_lang(default="en") + + prompt_list = pm.prefer_query( + input_value.command, prefer_prompt_language=user_language + ) + if not prompt_list: + error_msg = f"Prompt not found for command {input_value.command}, user_language: {user_language}" + logger.error(error_msg) + raise ValueError(error_msg) + prompt = prompt_list[0].to_prompt_template() + if input_value.command == CODE_TRANSLATE: + format_params = { + "source_language": input_value.language, + "target_language": input_value.target_language, + } + else: + format_params = {"language": input_value.language} + + system_message = prompt.format(**format_params) + messages = [ + ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_message), + ModelMessage(role=ModelMessageRoleType.HUMAN, content=input_value.messages), + ] + context = input_value.context.dict() if input_value.context else {} + return { + "messages": messages, + "stream": input_value.stream, + "model": input_value.model, + "context": context, + } + + +class MyConversationOperator(PreConversationOperator): + def __init__( + self, + storage: Optional[StorageInterface[StorageConversation, Any]] = None, + message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None, + **kwargs, + ): + super().__init__(storage, message_storage, **kwargs) + + def _get_conversion_serve(self): + from dbgpt.serve.conversation.serve import ( + SERVE_APP_NAME as CONVERSATION_SERVE_APP_NAME, + ) + from dbgpt.serve.conversation.serve import Serve as ConversationServe + + conversation_serve: ConversationServe = self.system_app.get_component( + CONVERSATION_SERVE_APP_NAME, ConversationServe, default_component=None + ) + return conversation_serve + + @property + def storage(self): + if self._storage: + return self._storage + conversation_serve = self._get_conversion_serve() + if conversation_serve: + return conversation_serve.conv_storage + else: + logger.info("Conversation storage not found, use InMemoryStorage default") + self._storage = InMemoryStorage() + return self._storage + + @property + def message_storage(self): + if self._message_storage: + return self._message_storage + conversation_serve = self._get_conversion_serve() + if conversation_serve: + return conversation_serve.message_storage + else: + logger.info("Message storage not found, use InMemoryStorage default") + self._message_storage = InMemoryStorage() + return self._message_storage + + +class MyLLMOperator(MixinLLMOperator, LLMOperator): + def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + super().__init__(llm_client) + LLMOperator.__init__(self, llm_client, **kwargs) + + +class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator): + def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + super().__init__(llm_client) + StreamingLLMOperator.__init__(self, llm_client, **kwargs) + + +def history_message_mapper( + messages_by_round: List[List[ModelMessage]], +) -> List[ModelMessage]: + """Mapper for history conversation. + + If there are multi system messages, just keep the first system message. + """ + has_system_message = False + mapper_messages = [] + for messages in messages_by_round: + for message in messages: + if message.role == ModelMessageRoleType.SYSTEM: + if has_system_message: + continue + else: + mapper_messages.append(message) + has_system_message = True + else: + mapper_messages.append(message) + print("history_message_mapper start:" + "=" * 70) + print(colored(ModelMessage.get_printable_message(mapper_messages), "green")) + print("history_message_mapper end:" + "=" * 72) + return mapper_messages + + +with DAG("dbgpt_awel_data_analyst_assistant") as dag: + trigger = HttpTrigger( + "/examples/data_analyst/copilot", + request_body=TriggerReqBody, + methods="POST", + streaming_predict_func=lambda x: x.stream, + ) + + copilot_task = CopilotOperator() + request_handle_task = RequestBuildOperator() + + # Pre-process conversation + pre_conversation_task = MyConversationOperator() + # Keep last k round conversation. + history_conversation_task = BufferedConversationMapperOperator( + last_k_round=5, message_mapper=history_message_mapper + ) + + # Save conversation to storage. + post_conversation_task = PostConversationOperator() + # Save streaming conversation to storage. + post_streaming_conversation_task = PostStreamingConversationOperator() + + # Use LLMOperator to generate response. + llm_task = MyLLMOperator(task_name="llm_task") + streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task") + branch_task = LLMBranchOperator( + stream_task_name="streaming_llm_task", no_stream_task_name="llm_task" + ) + model_parse_task = MapOperator(lambda out: out.to_dict()) + openai_format_stream_task = OpenAIStreamingOperator() + result_join_task = JoinOperator( + combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out + ) + + ( + trigger + >> copilot_task + >> request_handle_task + >> pre_conversation_task + >> history_conversation_task + >> branch_task + ) + # The branch of no streaming response. + ( + branch_task + >> llm_task + >> post_conversation_task + >> model_parse_task + >> result_join_task + ) + # The branch of streaming response. + ( + branch_task + >> streaming_llm_task + >> post_streaming_conversation_task + >> openai_format_stream_task + >> result_join_task + ) + +if __name__ == "__main__": + if dag.leaf_nodes[0].dev_mode: + from dbgpt.core.awel import setup_dev_environment + + setup_dev_environment([dag]) + else: + pass diff --git a/examples/notebook/agent_auto_plan_dialogue_example.ipynb b/examples/notebook/agent_auto_plan_dialogue_example.ipynb index db1c4d16d..201980a99 100644 --- a/examples/notebook/agent_auto_plan_dialogue_example.ipynb +++ b/examples/notebook/agent_auto_plan_dialogue_example.ipynb @@ -33,12 +33,15 @@ }, { "cell_type": "code", - "execution_count": 8, - "metadata": {}, + "execution_count": null, + "metadata": { + "is_executing": true + }, "outputs": [], "source": [ - "os.environ[\"OPENAI_API_KEY\"] = \"sk-ElhG3036tcvECTOYO9nHqAkUIWtqT55JXs1cTjODz1bnUQDz\"\n", - "os.environ[\"OPENAI_API_BASE\"] = \"https://api.chatanywhere.tech/v1\"" + "# Set your api key and api base url\n", + "# os.environ[\"OPENAI_API_KEY\"] = \"Your API\"\n", + "# os.environ[\"OPENAI_API_BASE\"] = \"https://api.openai.com/v1\"" ] }, { @@ -50,12 +53,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[33mUser\u001b[0m (to plan_manager)-[]:\n", + "\u001B[33mUser\u001B[0m (to plan_manager)-[]:\n", "\n", "\"Obtain simple information about issues in the repository 'eosphoros-ai/DB-GPT' in the past three days and analyze the data. Create a Markdown table grouped by day and status.\"\n", "\n", "--------------------------------------------------------------------------------\n", - "\u001b[33mplan_manager\u001b[0m (to Planner)-[]:\n", + "\u001B[33mplan_manager\u001B[0m (to Planner)-[]:\n", "\n", "\"Obtain simple information about issues in the repository 'eosphoros-ai/DB-GPT' in the past three days and analyze the data. Create a Markdown table grouped by day and status.\"\n", "\n", @@ -72,15 +75,15 @@ " \"content\": \"Write a Python script to process the retrieved issues data to group them by creation date and status. Then, format the grouped data into a Markdown table.\",\n", " \"rely\": \"1\"\n", "}]\n", - "\u001b[33mPlanner\u001b[0m (to plan_manager)-[gpt-4-vision-preview]:\n", + "\u001B[33mPlanner\u001B[0m (to plan_manager)-[gpt-4-vision-preview]:\n", "\n", "\"[{\\n \\\"serial_number\\\": \\\"1\\\",\\n \\\"agent\\\": \\\"CodeEngineer\\\",\\n \\\"content\\\": \\\"Write a Python script to use the GitHub API to retrieve issues from the 'eosphoros-ai/DB-GPT' repository that were created in the past three days. The script should extract the issue's creation date and status.\\\",\\n \\\"rely\\\": \\\"\\\"\\n},\\n{\\n \\\"serial_number\\\": \\\"2\\\",\\n \\\"agent\\\": \\\"CodeEngineer\\\",\\n \\\"content\\\": \\\"Write a Python script to process the retrieved issues data to group them by creation date and status. Then, format the grouped data into a Markdown table.\\\",\\n \\\"rely\\\": \\\"1\\\"\\n}]\"\n", - "\u001b[32m>>>>>>>>Planner Review info: \n", - " Pass.None\u001b[0m\n", - "\u001b[34m>>>>>>>>Planner Action report: \n", + "\u001B[32m>>>>>>>>Planner Review info: \n", + " Pass.None\u001B[0m\n", + "\u001B[34m>>>>>>>>Planner Action report: \n", "execution succeeded,\n", "1,Write a Python script to use the GitHub API to retrieve issues from the 'eosphoros-ai/DB-GPT' repository that were created in the past three days. The script should extract the issue's creation date and status.\n", - "2,Write a Python script to process the retrieved issues data to group them by creation date and status. Then, format the grouped data into a Markdown table.\u001b[0m\n", + "2,Write a Python script to process the retrieved issues data to group them by creation date and status. Then, format the grouped data into a Markdown table.\u001B[0m\n", "\n", "--------------------------------------------------------------------------------\n" ] @@ -96,7 +99,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[33mplan_manager\u001b[0m (to CodeEngineer)-[]:\n", + "\u001B[33mplan_manager\u001B[0m (to CodeEngineer)-[]:\n", "\n", "\"Write a Python script to use the GitHub API to retrieve issues from the 'eosphoros-ai/DB-GPT' repository that were created in the past three days. The script should extract the issue's creation date and status.\"\n", "\n", @@ -141,14 +144,14 @@ "```\n", "\n", "When you run this script, it will print out the creation date and status of issues from the 'eosphoros-ai/DB-GPT' repository that were created in the past three days. Make sure to have internet access while running the script as it needs to make requests to the GitHub API.\n", - "\u001b[31m\n", - ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n", - "\u001b[33mCodeEngineer\u001b[0m (to plan_manager)-[gpt-4-vision-preview]:\n", + "\u001B[31m\n", + ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001B[0m\n", + "\u001B[33mCodeEngineer\u001B[0m (to plan_manager)-[gpt-4-vision-preview]:\n", "\n", "\"```python\\n# filename: github_issues.py\\nimport requests\\nfrom datetime import datetime, timedelta\\n\\n# Function to get issues created in the past three days from a GitHub repository\\ndef get_recent_issues(owner, repo):\\n # Calculate the date for three days ago\\n three_days_ago = datetime.utcnow() - timedelta(days=3)\\n # Format the date in ISO 8601 format\\n since = three_days_ago.isoformat()\\n\\n # GitHub API URL for issues of the given repository\\n issues_url = f\\\"https://api.github.com/repos/{owner}/{repo}/issues\\\"\\n \\n # Parameters to get issues created in the past three days\\n params = {\\n 'since': since,\\n 'state': 'all', # Get both open and closed issues\\n }\\n \\n # Send a request to the GitHub API\\n response = requests.get(issues_url, params=params)\\n # Raise an exception if the request was unsuccessful\\n response.raise_for_status()\\n \\n # Parse the JSON response\\n issues = response.json()\\n \\n # Extract and print the creation date and status of each issue\\n for issue in issues:\\n created_at = issue['created_at']\\n state = issue['state']\\n print(f\\\"Issue created at: {created_at}, Status: {state}\\\")\\n\\n# Replace 'owner' and 'repo' with the actual owner and repository name\\nget_recent_issues('eosphoros-ai', 'DB-GPT')\\n```\\n\\nWhen you run this script, it will print out the creation date and status of issues from the 'eosphoros-ai/DB-GPT' repository that were created in the past three days. Make sure to have internet access while running the script as it needs to make requests to the GitHub API.\"\n", - "\u001b[32m>>>>>>>>CodeEngineer Review info: \n", - " Pass.None\u001b[0m\n", - "\u001b[34m>>>>>>>>CodeEngineer Action report: \n", + "\u001B[32m>>>>>>>>CodeEngineer Review info: \n", + " Pass.None\u001B[0m\n", + "\u001B[34m>>>>>>>>CodeEngineer Action report: \n", "execution failed,\n", "exitcode: 1 (execution failed)\n", " \n", @@ -156,7 +159,7 @@ " File \"github_issues.py\", line 2, in \n", " import requests\n", "ModuleNotFoundError: No module named 'requests'\n", - "\u001b[0m\n", + "\u001B[0m\n", "\n", "--------------------------------------------------------------------------------\n" ] @@ -172,7 +175,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[33mplan_manager\u001b[0m (to CodeEngineer)-[]:\n", + "\u001B[33mplan_manager\u001B[0m (to CodeEngineer)-[]:\n", "\n", "\"exitcode: 1 (execution failed)\\n \\nTraceback (most recent call last):\\n File \\\"github_issues.py\\\", line 2, in \\n import requests\\nModuleNotFoundError: No module named 'requests'\\n\"\n", "\n", @@ -225,17 +228,17 @@ "```\n", "\n", "Please run the above Python script after you have installed the `requests` module.\n", - "\u001b[31m\n", - ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is sh)...\u001b[0m\n", - "\u001b[31m\n", - ">>>>>>>> EXECUTING CODE BLOCK 1 (inferred language is python)...\u001b[0m\n", + "\u001B[31m\n", + ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is sh)...\u001B[0m\n", + "\u001B[31m\n", + ">>>>>>>> EXECUTING CODE BLOCK 1 (inferred language is python)...\u001B[0m\n", "un_stream ai response: False\n", - "\u001b[33mCodeEngineer\u001b[0m (to plan_manager)-[gpt-4-vision-preview]:\n", + "\u001B[33mCodeEngineer\u001B[0m (to plan_manager)-[gpt-4-vision-preview]:\n", "\n", "\"It appears that the `requests` module is not installed on your system. You can install it using `pip`, the Python package manager. Here's the command to install the `requests` module:\\n\\n```sh\\npip install requests\\n```\\n\\nAfter installing the `requests` module, you can try running the script again. Here is the complete script including the installation of the `requests` module:\\n\\n```python\\n# filename: github_issues.py\\nimport requests\\nfrom datetime import datetime, timedelta\\n\\n# Function to get issues created in the past three days from a GitHub repository\\ndef get_recent_issues(owner, repo):\\n # Calculate the date for three days ago\\n three_days_ago = datetime.utcnow() - timedelta(days=3)\\n # Format the date in ISO 8601 format\\n since = three_days_ago.isoformat()\\n\\n # GitHub API URL for issues of the given repository\\n issues_url = f\\\"https://api.github.com/repos/{owner}/{repo}/issues\\\"\\n \\n # Parameters to get issues created in the past three days\\n params = {\\n 'since': since,\\n 'state': 'all', # Get both open and closed issues\\n }\\n \\n # Send a request to the GitHub API\\n response = requests.get(issues_url, params=params)\\n # Raise an exception if the request was unsuccessful\\n response.raise_for_status()\\n \\n # Parse the JSON response\\n issues = response.json()\\n \\n # Extract and print the creation date and status of each issue\\n for issue in issues:\\n created_at = issue['created_at']\\n state = issue['state']\\n print(f\\\"Issue created at: {created_at}, Status: {state}\\\")\\n\\n# Replace 'owner' and 'repo' with the actual owner and repository name\\nget_recent_issues('eosphoros-ai', 'DB-GPT')\\n```\\n\\nPlease run the above Python script after you have installed the `requests` module.\"\n", - "\u001b[32m>>>>>>>>CodeEngineer Review info: \n", - " Pass.None\u001b[0m\n", - "\u001b[34m>>>>>>>>CodeEngineer Action report: \n", + "\u001B[32m>>>>>>>>CodeEngineer Review info: \n", + " Pass.None\u001B[0m\n", + "\u001B[34m>>>>>>>>CodeEngineer Action report: \n", "execution succeeded,\n", "\n", "Collecting requests\n", @@ -302,7 +305,7 @@ "Issue created at: 2023-10-26T01:45:37Z, Status: open\n", "Issue created at: 2023-08-30T06:02:32Z, Status: closed\n", "Issue created at: 2023-06-18T01:11:22Z, Status: open\n", - "\u001b[0m\n", + "\u001B[0m\n", "\n", "--------------------------------------------------------------------------------\n" ] @@ -318,7 +321,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[33mplan_manager\u001b[0m (to CodeEngineer)-[]:\n", + "\u001B[33mplan_manager\u001B[0m (to CodeEngineer)-[]:\n", "\n", "\"The execution result of the code you wrote is judged as not answering the task question. Please re-understand and complete the task.\"\n", "\n", @@ -376,15 +379,15 @@ "```\n", "\n", "When you run this script, it will check for the `requests` module and install it if necessary. Then it will retrieve the issues from the 'eosphoros-ai/DB-GPT' repository created in the past three days and print their creation date and status.\n", - "\u001b[31m\n", - ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n", + "\u001B[31m\n", + ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001B[0m\n", "un_stream ai response: False\n", - "\u001b[33mCodeEngineer\u001b[0m (to plan_manager)-[gpt-4-vision-preview]:\n", + "\u001B[33mCodeEngineer\u001B[0m (to plan_manager)-[gpt-4-vision-preview]:\n", "\n", "\"I apologize for the inconvenience. Let me provide you with a Python script that includes a check to ensure the `requests` module is installed. If it's not installed, the script will attempt to install it using `pip`. Please run the following Python script:\\n\\n```python\\n# filename: github_issues.py\\nimport subprocess\\nimport pkg_resources\\nfrom datetime import datetime, timedelta\\n\\n# Ensure the 'requests' package is installed\\nrequired_packages = {'requests'}\\ninstalled_packages = {pkg.key for pkg in pkg_resources.working_set}\\nmissing_packages = required_packages - installed_packages\\n\\nif missing_packages:\\n subprocess.check_call(['pip', 'install', *missing_packages])\\n\\nimport requests\\n\\n# Function to get issues created in the past three days from a GitHub repository\\ndef get_recent_issues(owner, repo):\\n # Calculate the date for three days ago\\n three_days_ago = datetime.utcnow() - timedelta(days=3)\\n # Format the date in ISO 8601 format\\n since = three_days_ago.isoformat()\\n\\n # GitHub API URL for issues of the given repository\\n issues_url = f\\\"https://api.github.com/repos/{owner}/{repo}/issues\\\"\\n \\n # Parameters to get issues created in the past three days\\n params = {\\n 'since': since,\\n 'state': 'all', # Get both open and closed issues\\n }\\n \\n # Send a request to the GitHub API\\n response = requests.get(issues_url, params=params)\\n # Raise an exception if the request was unsuccessful\\n response.raise_for_status()\\n \\n # Parse the JSON response\\n issues = response.json()\\n \\n # Extract and print the creation date and status of each issue\\n for issue in issues:\\n created_at = issue['created_at']\\n state = issue['state']\\n print(f\\\"Issue created at: {created_at}, Status: {state}\\\")\\n\\n# Replace 'owner' and 'repo' with the actual owner and repository name\\nget_recent_issues('eosphoros-ai', 'DB-GPT')\\n```\\n\\nWhen you run this script, it will check for the `requests` module and install it if necessary. Then it will retrieve the issues from the 'eosphoros-ai/DB-GPT' repository created in the past three days and print their creation date and status.\"\n", - "\u001b[32m>>>>>>>>CodeEngineer Review info: \n", - " Pass.None\u001b[0m\n", - "\u001b[34m>>>>>>>>CodeEngineer Action report: \n", + "\u001B[32m>>>>>>>>CodeEngineer Review info: \n", + " Pass.None\u001B[0m\n", + "\u001B[34m>>>>>>>>CodeEngineer Action report: \n", "execution succeeded,\n", "\n", "/workspace/github_issues.py:3: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html\n", @@ -421,7 +424,7 @@ "Issue created at: 2023-10-26T01:45:37Z, Status: open\n", "Issue created at: 2023-08-30T06:02:32Z, Status: closed\n", "Issue created at: 2023-06-18T01:11:22Z, Status: open\n", - "\u001b[0m\n", + "\u001B[0m\n", "\n", "--------------------------------------------------------------------------------\n" ] @@ -437,7 +440,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[33mplan_manager\u001b[0m (to CodeEngineer)-[]:\n", + "\u001B[33mplan_manager\u001B[0m (to CodeEngineer)-[]:\n", "\n", "\"The execution result of the code you wrote is judged as not answering the task question. Please re-understand and complete the task.\"\n", "\n", @@ -484,15 +487,15 @@ "```\n", "\n", "This script retrieves all issues from the specified repository and then filters them based on the creation date to find the ones that were created in the past three days. It then prints the creation date and status of each issue that meets the criteria. Please execute this updated script.\n", - "\u001b[31m\n", - ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n", + "\u001B[31m\n", + ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001B[0m\n", "un_stream ai response: True\n", - "\u001b[33mCodeEngineer\u001b[0m (to plan_manager)-[gpt-4-vision-preview]:\n", + "\u001B[33mCodeEngineer\u001B[0m (to plan_manager)-[gpt-4-vision-preview]:\n", "\n", "\"I apologize for the oversight. Let's try a different approach to ensure we get the issues created in the past three days from the 'eosphoros-ai/DB-GPT' repository. I will provide you with an updated Python script that correctly retrieves and filters the issues based on their creation date.\\n\\n```python\\n# filename: github_issues.py\\nimport requests\\nfrom datetime import datetime, timedelta\\n\\n# Function to get issues created in the past three days from a GitHub repository\\ndef get_recent_issues(owner, repo):\\n # Calculate the date for three days ago\\n three_days_ago = datetime.utcnow() - timedelta(days=3)\\n # Format the date in ISO 8601 format\\n since = three_days_ago.isoformat()\\n\\n # GitHub API URL for issues of the given repository\\n issues_url = f\\\"https://api.github.com/repos/{owner}/{repo}/issues\\\"\\n\\n # Parameters to get all issues, GitHub API does not allow filtering by creation date directly\\n params = {\\n 'state': 'all', # Get both open and closed issues\\n }\\n\\n # Send a request to the GitHub API\\n response = requests.get(issues_url, params=params)\\n # Raise an exception if the request was unsuccessful\\n response.raise_for_status()\\n\\n # Parse the JSON response\\n issues = response.json()\\n\\n # Filter and print the creation date and status of each issue created in the past three days\\n for issue in issues:\\n created_at = datetime.strptime(issue['created_at'], \\\"%Y-%m-%dT%H:%M:%SZ\\\")\\n if created_at > three_days_ago:\\n state = issue['state']\\n print(f\\\"Issue created at: {created_at}, Status: {state}\\\")\\n\\n# Replace 'owner' and 'repo' with the actual owner and repository name\\nget_recent_issues('eosphoros-ai', 'DB-GPT')\\n```\\n\\nThis script retrieves all issues from the specified repository and then filters them based on the creation date to find the ones that were created in the past three days. It then prints the creation date and status of each issue that meets the criteria. Please execute this updated script.\"\n", - "\u001b[32m>>>>>>>>CodeEngineer Review info: \n", - " Pass.None\u001b[0m\n", - "\u001b[34m>>>>>>>>CodeEngineer Action report: \n", + "\u001B[32m>>>>>>>>CodeEngineer Review info: \n", + " Pass.None\u001B[0m\n", + "\u001B[34m>>>>>>>>CodeEngineer Action report: \n", "execution succeeded,\n", "\n", "/workspace/github_issues.py:8: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", @@ -510,7 +513,7 @@ "Issue created at: 2023-12-25 08:35:48, Status: open\n", "Issue created at: 2023-12-25 08:34:13, Status: closed\n", "Issue created at: 2023-12-25 06:33:37, Status: closed\n", - "\u001b[0m\n", + "\u001B[0m\n", "\n", "--------------------------------------------------------------------------------\n" ] @@ -526,7 +529,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[33mplan_manager\u001b[0m (to CodeEngineer)-[]:\n", + "\u001B[33mplan_manager\u001B[0m (to CodeEngineer)-[]:\n", "\n", "\"Read the result data of the dependent steps in the above historical message to complete the current goal:Write a Python script to process the retrieved issues data to group them by creation date and status. Then, format the grouped data into a Markdown table.\"\n", "\n", @@ -573,15 +576,15 @@ "```\n", "\n", "When you run this script, it will print out a Markdown table with the number of open and closed issues grouped by their creation dates.\n", - "\u001b[31m\n", - ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n", + "\u001B[31m\n", + ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001B[0m\n", "un_stream ai response: True\n", - "\u001b[33mCodeEngineer\u001b[0m (to plan_manager)-[gpt-4-vision-preview]:\n", + "\u001B[33mCodeEngineer\u001B[0m (to plan_manager)-[gpt-4-vision-preview]:\n", "\n", "\"```python\\n# Let's process the retrieved issues data to group them by creation date and status.\\n# Then, we will format the grouped data into a Markdown table.\\n\\n# The data from the previous step\\nissues_data = [\\n {\\\"created_at\\\": \\\"2023-12-27 03:20:12\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-26 16:13:42\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-26 10:23:43\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-26 07:25:37\\\", \\\"status\\\": \\\"closed\\\"},\\n {\\\"created_at\\\": \\\"2023-12-26 05:01:26\\\", \\\"status\\\": \\\"closed\\\"},\\n {\\\"created_at\\\": \\\"2023-12-26 03:54:15\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-26 02:19:48\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-25 12:26:40\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-25 12:00:52\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-25 08:36:56\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-25 08:35:48\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-25 08:34:13\\\", \\\"status\\\": \\\"closed\\\"},\\n {\\\"created_at\\\": \\\"2023-12-25 06:33:37\\\", \\\"status\\\": \\\"closed\\\"},\\n]\\n\\n# Grouping the issues by creation date and status\\nfrom collections import defaultdict\\ngrouped_issues = defaultdict(lambda: defaultdict(list))\\n\\nfor issue in issues_data:\\n date = issue[\\\"created_at\\\"].split(' ')[0] # Extract just the date part\\n grouped_issues[date][issue[\\\"status\\\"]].append(issue)\\n\\n# Format the grouped data into a Markdown table\\nmarkdown_table = \\\"Date | Open Issues | Closed Issues\\\\n\\\"\\nmarkdown_table += \\\"--- | --- | ---\\\\n\\\"\\n\\nfor date, statuses in sorted(grouped_issues.items()):\\n open_issues = len(statuses[\\\"open\\\"]) if \\\"open\\\" in statuses else 0\\n closed_issues = len(statuses[\\\"closed\\\"]) if \\\"closed\\\" in statuses else 0\\n markdown_table += f\\\"{date} | {open_issues} | {closed_issues}\\\\n\\\"\\n\\nprint(markdown_table)\\n```\\n\\nWhen you run this script, it will print out a Markdown table with the number of open and closed issues grouped by their creation dates.\"\n", - "\u001b[32m>>>>>>>>CodeEngineer Review info: \n", - " Pass.None\u001b[0m\n", - "\u001b[34m>>>>>>>>CodeEngineer Action report: \n", + "\u001B[32m>>>>>>>>CodeEngineer Review info: \n", + " Pass.None\u001B[0m\n", + "\u001B[34m>>>>>>>>CodeEngineer Action report: \n", "execution succeeded,\n", "\n", "Date | Open Issues | Closed Issues\n", @@ -590,17 +593,17 @@ "2023-12-26 | 4 | 2\n", "2023-12-27 | 1 | 0\n", "\n", - "\u001b[0m\n", + "\u001B[0m\n", "\n", "--------------------------------------------------------------------------------\n", - "\u001b[33mplan_manager\u001b[0m (to User)-[None]:\n", + "\u001B[33mplan_manager\u001B[0m (to User)-[None]:\n", "\n", "\"Obtain simple information about issues in the repository 'eosphoros-ai/DB-GPT' in the past three days and analyze the data. Create a Markdown table grouped by day and status.\"\n", - "\u001b[32m>>>>>>>>plan_manager Review info: \n", - " Pass.None\u001b[0m\n", - "\u001b[34m>>>>>>>>plan_manager Action report: \n", + "\u001B[32m>>>>>>>>plan_manager Review info: \n", + " Pass.None\u001B[0m\n", + "\u001B[34m>>>>>>>>plan_manager Action report: \n", "execution succeeded,\n", - "TERMINATE\u001b[0m\n", + "TERMINATE\u001B[0m\n", "\n", "--------------------------------------------------------------------------------\n", "```agent-messages\n", diff --git a/setup.py b/setup.py index 0d219f497..fafd2da51 100644 --- a/setup.py +++ b/setup.py @@ -369,6 +369,7 @@ def core_requires(): setup_spec.extras["simple_framework"] = setup_spec.extras["core"] + [ "pydantic<2,>=1", "httpx", + "jinja2", "fastapi==0.98.0", "shortuuid", # change from fixed version 2.0.22 to variable version, because other dependencies are >=1.4, such as pydoris is <2