From aed1c3fb2ba2892db7efb5109e43e548b61fb0f2 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Fri, 15 Dec 2023 16:35:45 +0800 Subject: [PATCH] refactor: Refactor storage system (#937) --- assets/schema/knowledge_management.sql | 408 +++++++------ dbgpt/_private/config.py | 1 + dbgpt/agent/db/my_plugin_db.py | 32 +- dbgpt/agent/db/plugin_hub_db.py | 32 +- dbgpt/agent/hub/agent_hub.py | 40 +- dbgpt/app/_cli.py | 240 ++++++++ dbgpt/app/base.py | 47 +- dbgpt/app/component_configs.py | 1 - dbgpt/app/knowledge/chunk_db.py | 26 +- dbgpt/app/knowledge/document_db.py | 32 +- dbgpt/app/knowledge/space_db.py | 26 +- .../openapi/api_v1/feedback/feed_back_db.py | 23 +- dbgpt/app/prompt/prompt_manage_db.py | 27 +- dbgpt/app/scene/base_chat.py | 14 +- dbgpt/app/scene/operator/_experimental.py | 4 +- dbgpt/cli/cli_scripts.py | 10 + dbgpt/core/__init__.py | 26 + dbgpt/core/awel/dag/tests/test_dag.py | 26 +- dbgpt/core/interface/message.py | 542 ++++++++++++++++-- dbgpt/core/interface/output_parser.py | 32 +- dbgpt/core/interface/serialization.py | 21 +- dbgpt/core/interface/storage.py | 409 +++++++++++++ dbgpt/core/interface/tests/__init__.py | 0 dbgpt/core/interface/tests/conftest.py | 14 + dbgpt/core/interface/tests/test_message.py | 307 ++++++++++ dbgpt/core/interface/tests/test_storage.py | 129 +++++ dbgpt/datasource/base.py | 4 + dbgpt/datasource/manages/connect_config_db.py | 36 +- dbgpt/datasource/rdbms/base.py | 54 +- dbgpt/datasource/rdbms/conn_sqlite.py | 7 +- .../rdbms/tests/test_conn_sqlite.py | 36 +- .../model/cluster/apiserver/tests/test_api.py | 25 +- dbgpt/storage/cache/llm_cache.py | 22 +- dbgpt/storage/chat_history/chat_history_db.py | 79 +-- dbgpt/storage/chat_history/storage_adapter.py | 116 ++++ .../chat_history/store_type/duckdb_history.py | 2 +- .../store_type/meta_db_history.py | 4 +- dbgpt/storage/chat_history/tests/__init__.py | 0 .../tests/test_storage_adapter.py | 219 +++++++ dbgpt/storage/metadata/__init__.py | 16 + dbgpt/storage/metadata/_base_dao.py | 77 ++- dbgpt/storage/metadata/db_manager.py | 432 ++++++++++++++ dbgpt/storage/metadata/db_storage.py | 128 +++++ dbgpt/storage/metadata/meta_data.py | 94 --- dbgpt/storage/metadata/tests/__init__.py | 0 .../storage/metadata/tests/test_db_manager.py | 129 +++++ .../metadata/tests/test_sqlalchemy_storage.py | 173 ++++++ dbgpt/util/_db_migration_utils.py | 219 +++++++ dbgpt/util/annotations.py | 25 + dbgpt/util/pagination_utils.py | 14 + .../util/serialization/json_serialization.py | 4 +- dbgpt/util/string_utils.py | 12 +- docs/docs/faq/install.md | 63 ++ docs/docs/installation/sourcecode.md | 2 + pilot/meta_data/alembic/env.py | 15 +- 55 files changed, 3788 insertions(+), 688 deletions(-) create mode 100644 dbgpt/core/interface/storage.py create mode 100644 dbgpt/core/interface/tests/__init__.py create mode 100644 dbgpt/core/interface/tests/conftest.py create mode 100644 dbgpt/core/interface/tests/test_message.py create mode 100644 dbgpt/core/interface/tests/test_storage.py create mode 100644 dbgpt/storage/chat_history/storage_adapter.py create mode 100644 dbgpt/storage/chat_history/tests/__init__.py create mode 100644 dbgpt/storage/chat_history/tests/test_storage_adapter.py create mode 100644 dbgpt/storage/metadata/db_manager.py create mode 100644 dbgpt/storage/metadata/db_storage.py delete mode 100644 dbgpt/storage/metadata/meta_data.py create mode 100644 dbgpt/storage/metadata/tests/__init__.py create mode 100644 dbgpt/storage/metadata/tests/test_db_manager.py create mode 100644 dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py create mode 100644 dbgpt/util/_db_migration_utils.py create mode 100644 dbgpt/util/pagination_utils.py diff --git a/assets/schema/knowledge_management.sql b/assets/schema/knowledge_management.sql index eb292b358..53f6747dc 100644 --- a/assets/schema/knowledge_management.sql +++ b/assets/schema/knowledge_management.sql @@ -1,188 +1,238 @@ -- You can change `dbgpt` to your actual metadata database name in your `.env` file -- eg. `LOCAL_DB_NAME=dbgpt` -CREATE DATABASE IF NOT EXISTS dbgpt; +CREATE +DATABASE IF NOT EXISTS dbgpt; use dbgpt; -- For alembic migration tool -CREATE TABLE `alembic_version` ( - version_num VARCHAR(32) NOT NULL, - CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num) +CREATE TABLE IF NOT EXISTS `alembic_version` +( + version_num VARCHAR(32) NOT NULL, + CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num) ); -CREATE TABLE `knowledge_space` ( - `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', - `name` varchar(100) NOT NULL COMMENT 'knowledge space name', - `vector_type` varchar(50) NOT NULL COMMENT 'vector type', - `desc` varchar(500) NOT NULL COMMENT 'description', - `owner` varchar(100) DEFAULT NULL COMMENT 'owner', - `context` TEXT DEFAULT NULL COMMENT 'context argument', - `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', - `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', - PRIMARY KEY (`id`), - KEY `idx_name` (`name`) COMMENT 'index:idx_name' -) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge space table'; - -CREATE TABLE `knowledge_document` ( - `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', - `doc_name` varchar(100) NOT NULL COMMENT 'document path name', - `doc_type` varchar(50) NOT NULL COMMENT 'doc type', - `space` varchar(50) NOT NULL COMMENT 'knowledge space', - `chunk_size` int NOT NULL COMMENT 'chunk size', - `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time', - `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED', - `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result', - `result` TEXT NULL COMMENT 'knowledge content', - `vector_ids` LONGTEXT NULL COMMENT 'vector_ids', - `summary` LONGTEXT NULL COMMENT 'knowledge summary', - `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', - `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', - PRIMARY KEY (`id`), - KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name' -) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document table'; - -CREATE TABLE `document_chunk` ( - `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', - `doc_name` varchar(100) NOT NULL COMMENT 'document path name', - `doc_type` varchar(50) NOT NULL COMMENT 'doc type', - `document_id` int NOT NULL COMMENT 'document parent id', - `content` longtext NOT NULL COMMENT 'chunk content', - `meta_info` varchar(200) NOT NULL COMMENT 'metadata info', - `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', - `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', - PRIMARY KEY (`id`), - KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id' -) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document chunk detail'; - - - -CREATE TABLE `connect_config` ( - `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', - `db_type` varchar(255) NOT NULL COMMENT 'db type', - `db_name` varchar(255) NOT NULL COMMENT 'db name', - `db_path` varchar(255) DEFAULT NULL COMMENT 'file db path', - `db_host` varchar(255) DEFAULT NULL COMMENT 'db connect host(not file db)', - `db_port` varchar(255) DEFAULT NULL COMMENT 'db cnnect port(not file db)', - `db_user` varchar(255) DEFAULT NULL COMMENT 'db user', - `db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password', - `comment` text COMMENT 'db comment', - `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', - PRIMARY KEY (`id`), - UNIQUE KEY `uk_db` (`db_name`), - KEY `idx_q_db_type` (`db_type`) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT 'Connection confi'; - -CREATE TABLE `chat_history` ( - `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', - `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id', - `chat_mode` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation scene mode', - `summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary', - `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor', - `messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details', - `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', - PRIMARY KEY (`id`) -) ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history'; - -CREATE TABLE `chat_feed_back` ( - `id` bigint(20) NOT NULL AUTO_INCREMENT, - `conv_uid` varchar(128) DEFAULT NULL COMMENT 'Conversation ID', - `conv_index` int(4) DEFAULT NULL COMMENT 'Round of conversation', - `score` int(1) DEFAULT NULL COMMENT 'Score of user', - `ques_type` varchar(32) DEFAULT NULL COMMENT 'User question category', - `question` longtext DEFAULT NULL COMMENT 'User question', - `knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name', - `messages` longtext DEFAULT NULL COMMENT 'The details of user feedback', - `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', - `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', - `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', - PRIMARY KEY (`id`), - UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`), - KEY `idx_conv` (`conv_uid`,`conv_index`) -) ENGINE=InnoDB AUTO_INCREMENT=0 DEFAULT CHARSET=utf8mb4 COMMENT='User feedback table'; - - -CREATE TABLE `my_plugin` ( - `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', - `tenant` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user tenant', - `user_code` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'user code', - `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user name', - `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name', - `file_name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin package file name', - `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type', - `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version', - `use_count` int DEFAULT NULL COMMENT 'plugin total use count', - `succ_count` int DEFAULT NULL COMMENT 'plugin total success count', - `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', - `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time', - PRIMARY KEY (`id`), - UNIQUE KEY `name` (`name`) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User plugin table'; - -CREATE TABLE `plugin_hub` ( - `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', - `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name', - `description` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin description', - `author` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author', - `email` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author email', - `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type', - `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version', - `storage_channel` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin storage channel', - `storage_url` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download url', - `download_param` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download param', - `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin upload time', - `installed` int DEFAULT NULL COMMENT 'plugin already installed count', - PRIMARY KEY (`id`), - UNIQUE KEY `name` (`name`) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Plugin Hub table'; - - -CREATE TABLE `prompt_manage` ( - `id` int(11) NOT NULL AUTO_INCREMENT, - `chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene', - `sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene', - `prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private', - `prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name', - `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content', - `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name', - `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', - `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', - `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', - PRIMARY KEY (`id`), - UNIQUE KEY `prompt_name_uiq` (`prompt_name`), - KEY `gmt_created_idx` (`gmt_created`) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table'; - - - -CREATE DATABASE EXAMPLE_1; +CREATE TABLE IF NOT EXISTS `knowledge_space` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `name` varchar(100) NOT NULL COMMENT 'knowledge space name', + `vector_type` varchar(50) NOT NULL COMMENT 'vector type', + `desc` varchar(500) NOT NULL COMMENT 'description', + `owner` varchar(100) DEFAULT NULL COMMENT 'owner', + `context` TEXT DEFAULT NULL COMMENT 'context argument', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_name` (`name`) COMMENT 'index:idx_name' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge space table'; + +CREATE TABLE IF NOT EXISTS `knowledge_document` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `doc_name` varchar(100) NOT NULL COMMENT 'document path name', + `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `space` varchar(50) NOT NULL COMMENT 'knowledge space', + `chunk_size` int NOT NULL COMMENT 'chunk size', + `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time', + `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED', + `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result', + `result` TEXT NULL COMMENT 'knowledge content', + `vector_ids` LONGTEXT NULL COMMENT 'vector_ids', + `summary` LONGTEXT NULL COMMENT 'knowledge summary', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document table'; + +CREATE TABLE IF NOT EXISTS `document_chunk` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `doc_name` varchar(100) NOT NULL COMMENT 'document path name', + `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `document_id` int NOT NULL COMMENT 'document parent id', + `content` longtext NOT NULL COMMENT 'chunk content', + `meta_info` varchar(200) NOT NULL COMMENT 'metadata info', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document chunk detail'; + + + +CREATE TABLE IF NOT EXISTS `connect_config` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `db_type` varchar(255) NOT NULL COMMENT 'db type', + `db_name` varchar(255) NOT NULL COMMENT 'db name', + `db_path` varchar(255) DEFAULT NULL COMMENT 'file db path', + `db_host` varchar(255) DEFAULT NULL COMMENT 'db connect host(not file db)', + `db_port` varchar(255) DEFAULT NULL COMMENT 'db cnnect port(not file db)', + `db_user` varchar(255) DEFAULT NULL COMMENT 'db user', + `db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password', + `comment` text COMMENT 'db comment', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_db` (`db_name`), + KEY `idx_q_db_type` (`db_type`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT 'Connection confi'; + +CREATE TABLE IF NOT EXISTS `chat_history` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id', + `chat_mode` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation scene mode', + `summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary', + `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor', + `messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details', + `message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + UNIQUE KEY `conv_uid` (`conv_uid`), + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history'; + +CREATE TABLE IF NOT EXISTS `chat_history_message` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id', + `index` int NOT NULL COMMENT 'Message index', + `round_index` int NOT NULL COMMENT 'Round of conversation', + `message_detail` text COLLATE utf8mb4_unicode_ci COMMENT 'Message details, json format', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + UNIQUE KEY `message_uid_index` (`conv_uid`, `index`), + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history message'; + +CREATE TABLE IF NOT EXISTS `chat_feed_back` +( + `id` bigint(20) NOT NULL AUTO_INCREMENT, + `conv_uid` varchar(128) DEFAULT NULL COMMENT 'Conversation ID', + `conv_index` int(4) DEFAULT NULL COMMENT 'Round of conversation', + `score` int(1) DEFAULT NULL COMMENT 'Score of user', + `ques_type` varchar(32) DEFAULT NULL COMMENT 'User question category', + `question` longtext DEFAULT NULL COMMENT 'User question', + `knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name', + `messages` longtext DEFAULT NULL COMMENT 'The details of user feedback', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`), + KEY `idx_conv` (`conv_uid`,`conv_index`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='User feedback table'; + + +CREATE TABLE IF NOT EXISTS `my_plugin` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `tenant` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user tenant', + `user_code` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'user code', + `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user name', + `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name', + `file_name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin package file name', + `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type', + `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version', + `use_count` int DEFAULT NULL COMMENT 'plugin total use count', + `succ_count` int DEFAULT NULL COMMENT 'plugin total success count', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time', + PRIMARY KEY (`id`), + UNIQUE KEY `name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User plugin table'; + +CREATE TABLE IF NOT EXISTS `plugin_hub` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name', + `description` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin description', + `author` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author', + `email` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author email', + `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type', + `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version', + `storage_channel` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin storage channel', + `storage_url` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download url', + `download_param` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download param', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin upload time', + `installed` int DEFAULT NULL COMMENT 'plugin already installed count', + PRIMARY KEY (`id`), + UNIQUE KEY `name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Plugin Hub table'; + + +CREATE TABLE IF NOT EXISTS `prompt_manage` +( + `id` int(11) NOT NULL AUTO_INCREMENT, + `chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene', + `sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene', + `prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private', + `prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name', + `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content', + `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + UNIQUE KEY `prompt_name_uiq` (`prompt_name`), + KEY `gmt_created_idx` (`gmt_created`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table'; + + + +CREATE +DATABASE IF NOT EXISTS EXAMPLE_1; use EXAMPLE_1; -CREATE TABLE `users` ( - `id` int NOT NULL AUTO_INCREMENT, - `username` varchar(50) NOT NULL COMMENT '用户名', - `password` varchar(50) NOT NULL COMMENT '密码', - `email` varchar(50) NOT NULL COMMENT '邮箱', - `phone` varchar(20) DEFAULT NULL COMMENT '电话', - PRIMARY KEY (`id`), - KEY `idx_username` (`username`) COMMENT '索引:按用户名查询' -) ENGINE=InnoDB AUTO_INCREMENT=101 DEFAULT CHARSET=utf8mb4 COMMENT='聊天用户表'; - -INSERT INTO users (username, password, email, phone) VALUES ('user_1', 'password_1', 'user_1@example.com', '12345678901'); -INSERT INTO users (username, password, email, phone) VALUES ('user_2', 'password_2', 'user_2@example.com', '12345678902'); -INSERT INTO users (username, password, email, phone) VALUES ('user_3', 'password_3', 'user_3@example.com', '12345678903'); -INSERT INTO users (username, password, email, phone) VALUES ('user_4', 'password_4', 'user_4@example.com', '12345678904'); -INSERT INTO users (username, password, email, phone) VALUES ('user_5', 'password_5', 'user_5@example.com', '12345678905'); -INSERT INTO users (username, password, email, phone) VALUES ('user_6', 'password_6', 'user_6@example.com', '12345678906'); -INSERT INTO users (username, password, email, phone) VALUES ('user_7', 'password_7', 'user_7@example.com', '12345678907'); -INSERT INTO users (username, password, email, phone) VALUES ('user_8', 'password_8', 'user_8@example.com', '12345678908'); -INSERT INTO users (username, password, email, phone) VALUES ('user_9', 'password_9', 'user_9@example.com', '12345678909'); -INSERT INTO users (username, password, email, phone) VALUES ('user_10', 'password_10', 'user_10@example.com', '12345678900'); -INSERT INTO users (username, password, email, phone) VALUES ('user_11', 'password_11', 'user_11@example.com', '12345678901'); -INSERT INTO users (username, password, email, phone) VALUES ('user_12', 'password_12', 'user_12@example.com', '12345678902'); -INSERT INTO users (username, password, email, phone) VALUES ('user_13', 'password_13', 'user_13@example.com', '12345678903'); -INSERT INTO users (username, password, email, phone) VALUES ('user_14', 'password_14', 'user_14@example.com', '12345678904'); -INSERT INTO users (username, password, email, phone) VALUES ('user_15', 'password_15', 'user_15@example.com', '12345678905'); -INSERT INTO users (username, password, email, phone) VALUES ('user_16', 'password_16', 'user_16@example.com', '12345678906'); -INSERT INTO users (username, password, email, phone) VALUES ('user_17', 'password_17', 'user_17@example.com', '12345678907'); -INSERT INTO users (username, password, email, phone) VALUES ('user_18', 'password_18', 'user_18@example.com', '12345678908'); -INSERT INTO users (username, password, email, phone) VALUES ('user_19', 'password_19', 'user_19@example.com', '12345678909'); -INSERT INTO users (username, password, email, phone) VALUES ('user_20', 'password_20', 'user_20@example.com', '12345678900'); \ No newline at end of file +CREATE TABLE IF NOT EXISTS `users` +( + `id` int NOT NULL AUTO_INCREMENT, + `username` varchar(50) NOT NULL COMMENT '用户名', + `password` varchar(50) NOT NULL COMMENT '密码', + `email` varchar(50) NOT NULL COMMENT '邮箱', + `phone` varchar(20) DEFAULT NULL COMMENT '电话', + PRIMARY KEY (`id`), + KEY `idx_username` (`username`) COMMENT '索引:按用户名查询' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='聊天用户表'; + +INSERT INTO users (username, password, email, phone) +VALUES ('user_1', 'password_1', 'user_1@example.com', '12345678901'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_2', 'password_2', 'user_2@example.com', '12345678902'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_3', 'password_3', 'user_3@example.com', '12345678903'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_4', 'password_4', 'user_4@example.com', '12345678904'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_5', 'password_5', 'user_5@example.com', '12345678905'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_6', 'password_6', 'user_6@example.com', '12345678906'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_7', 'password_7', 'user_7@example.com', '12345678907'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_8', 'password_8', 'user_8@example.com', '12345678908'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_9', 'password_9', 'user_9@example.com', '12345678909'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_10', 'password_10', 'user_10@example.com', '12345678900'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_11', 'password_11', 'user_11@example.com', '12345678901'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_12', 'password_12', 'user_12@example.com', '12345678902'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_13', 'password_13', 'user_13@example.com', '12345678903'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_14', 'password_14', 'user_14@example.com', '12345678904'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_15', 'password_15', 'user_15@example.com', '12345678905'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_16', 'password_16', 'user_16@example.com', '12345678906'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_17', 'password_17', 'user_17@example.com', '12345678907'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_18', 'password_18', 'user_18@example.com', '12345678908'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_19', 'password_19', 'user_19@example.com', '12345678909'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_20', 'password_20', 'user_20@example.com', '12345678900'); \ No newline at end of file diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index adab63b1b..4a8fbb8c5 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -182,6 +182,7 @@ def __init__(self) -> None: self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root") self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456") self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10)) + self.LOCAL_DB_POOL_OVERFLOW = int(os.getenv("LOCAL_DB_POOL_OVERFLOW", 20)) self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "db") diff --git a/dbgpt/agent/db/my_plugin_db.py b/dbgpt/agent/db/my_plugin_db.py index e62896245..ed5ab176e 100644 --- a/dbgpt/agent/db/my_plugin_db.py +++ b/dbgpt/agent/db/my_plugin_db.py @@ -2,16 +2,10 @@ from sqlalchemy import Column, Integer, String, DateTime, func from sqlalchemy import UniqueConstraint -from dbgpt.storage.metadata import BaseDao -from dbgpt.storage.metadata.meta_data import ( - Base, - engine, - session, - META_DATA_DATABASE, -) +from dbgpt.storage.metadata import BaseDao, Model -class MyPluginEntity(Base): +class MyPluginEntity(Model): __tablename__ = "my_plugin" __table_args__ = { "mysql_charset": "utf8mb4", @@ -39,16 +33,8 @@ class MyPluginEntity(Base): class MyPluginDao(BaseDao[MyPluginEntity]): - def __init__(self): - super().__init__( - database=META_DATA_DATABASE, - orm_base=Base, - db_engine=engine, - session=session, - ) - def add(self, engity: MyPluginEntity): - session = self.get_session() + session = self.get_raw_session() my_plugin = MyPluginEntity( tenant=engity.tenant, user_code=engity.user_code, @@ -68,13 +54,13 @@ def add(self, engity: MyPluginEntity): return id def update(self, entity: MyPluginEntity): - session = self.get_session() + session = self.get_raw_session() updated = session.merge(entity) session.commit() return updated.id def get_by_user(self, user: str) -> list[MyPluginEntity]: - session = self.get_session() + session = self.get_raw_session() my_plugins = session.query(MyPluginEntity) if user: my_plugins = my_plugins.filter(MyPluginEntity.user_code == user) @@ -83,7 +69,7 @@ def get_by_user(self, user: str) -> list[MyPluginEntity]: return result def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity: - session = self.get_session() + session = self.get_raw_session() my_plugins = session.query(MyPluginEntity) if user: my_plugins = my_plugins.filter(MyPluginEntity.user_code == user) @@ -93,7 +79,7 @@ def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity: return result def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]: - session = self.get_session() + session = self.get_raw_session() my_plugins = session.query(MyPluginEntity) all_count = my_plugins.count() if query.id is not None: @@ -122,7 +108,7 @@ def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEnti return result, total_pages, all_count def count(self, query: MyPluginEntity): - session = self.get_session() + session = self.get_raw_session() my_plugins = session.query(func.count(MyPluginEntity.id)) if query.id is not None: my_plugins = my_plugins.filter(MyPluginEntity.id == query.id) @@ -143,7 +129,7 @@ def count(self, query: MyPluginEntity): return count def delete(self, plugin_id: int): - session = self.get_session() + session = self.get_raw_session() if plugin_id is None: raise Exception("plugin_id is None") query = MyPluginEntity(id=plugin_id) diff --git a/dbgpt/agent/db/plugin_hub_db.py b/dbgpt/agent/db/plugin_hub_db.py index 07acf2564..d374d284d 100644 --- a/dbgpt/agent/db/plugin_hub_db.py +++ b/dbgpt/agent/db/plugin_hub_db.py @@ -3,19 +3,13 @@ from sqlalchemy import Column, Integer, String, Index, DateTime, func, DDL from sqlalchemy import UniqueConstraint -from dbgpt.storage.metadata import BaseDao -from dbgpt.storage.metadata.meta_data import ( - Base, - engine, - session, - META_DATA_DATABASE, -) +from dbgpt.storage.metadata import BaseDao, Model # TODO We should consider that the production environment does not have permission to execute the DDL char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4") -class PluginHubEntity(Base): +class PluginHubEntity(Model): __tablename__ = "plugin_hub" __table_args__ = { "mysql_charset": "utf8mb4", @@ -43,16 +37,8 @@ class PluginHubEntity(Base): class PluginHubDao(BaseDao[PluginHubEntity]): - def __init__(self): - super().__init__( - database=META_DATA_DATABASE, - orm_base=Base, - db_engine=engine, - session=session, - ) - def add(self, engity: PluginHubEntity): - session = self.get_session() + session = self.get_raw_session() timezone = pytz.timezone("Asia/Shanghai") plugin_hub = PluginHubEntity( name=engity.name, @@ -71,7 +57,7 @@ def add(self, engity: PluginHubEntity): return id def update(self, entity: PluginHubEntity): - session = self.get_session() + session = self.get_raw_session() try: updated = session.merge(entity) session.commit() @@ -82,7 +68,7 @@ def update(self, entity: PluginHubEntity): def list( self, query: PluginHubEntity, page=1, page_size=20 ) -> list[PluginHubEntity]: - session = self.get_session() + session = self.get_raw_session() plugin_hubs = session.query(PluginHubEntity) all_count = plugin_hubs.count() @@ -111,7 +97,7 @@ def list( return result, total_pages, all_count def get_by_storage_url(self, storage_url): - session = self.get_session() + session = self.get_raw_session() plugin_hubs = session.query(PluginHubEntity) plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url) result = plugin_hubs.all() @@ -119,7 +105,7 @@ def get_by_storage_url(self, storage_url): return result def get_by_name(self, name: str) -> PluginHubEntity: - session = self.get_session() + session = self.get_raw_session() plugin_hubs = session.query(PluginHubEntity) plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name) result = plugin_hubs.first() @@ -127,7 +113,7 @@ def get_by_name(self, name: str) -> PluginHubEntity: return result def count(self, query: PluginHubEntity): - session = self.get_session() + session = self.get_raw_session() plugin_hubs = session.query(func.count(PluginHubEntity.id)) if query.id is not None: plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id) @@ -146,7 +132,7 @@ def count(self, query: PluginHubEntity): return count def delete(self, plugin_id: int): - session = self.get_session() + session = self.get_raw_session() if plugin_id is None: raise Exception("plugin_id is None") plugin_hubs = session.query(PluginHubEntity) diff --git a/dbgpt/agent/hub/agent_hub.py b/dbgpt/agent/hub/agent_hub.py index 470b4ab71..f95b49305 100644 --- a/dbgpt/agent/hub/agent_hub.py +++ b/dbgpt/agent/hub/agent_hub.py @@ -59,18 +59,12 @@ def install_plugin(self, plugin_name: str, user_name: str = None): else: my_plugin_entity.user_code = Default_User - with self.hub_dao.get_session() as session: - try: - if my_plugin_entity.id is None: - session.add(my_plugin_entity) - else: - session.merge(my_plugin_entity) - session.merge(plugin_entity) - session.commit() - session.close() - except Exception as e: - logger.error("install merge roll back!" + str(e)) - session.rollback() + with self.hub_dao.session() as session: + if my_plugin_entity.id is None: + session.add(my_plugin_entity) + else: + session.merge(my_plugin_entity) + session.merge(plugin_entity) except Exception as e: logger.error("install pluguin exception!", e) raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}") @@ -87,19 +81,15 @@ def uninstall_plugin(self, plugin_name, user): my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name) if plugin_entity is not None: plugin_entity.installed = plugin_entity.installed - 1 - with self.hub_dao.get_session() as session: - try: - my_plugin_q = session.query(MyPluginEntity).filter( - MyPluginEntity.name == plugin_name - ) - if user: - my_plugin_q.filter(MyPluginEntity.user_code == user) - my_plugin_q.delete() - if plugin_entity is not None: - session.merge(plugin_entity) - session.commit() - except: - session.rollback() + with self.hub_dao.session() as session: + my_plugin_q = session.query(MyPluginEntity).filter( + MyPluginEntity.name == plugin_name + ) + if user: + my_plugin_q.filter(MyPluginEntity.user_code == user) + my_plugin_q.delete() + if plugin_entity is not None: + session.merge(plugin_entity) if plugin_entity is not None: # delete package file if not use diff --git a/dbgpt/app/_cli.py b/dbgpt/app/_cli.py index acff70749..02dc471b7 100644 --- a/dbgpt/app/_cli.py +++ b/dbgpt/app/_cli.py @@ -1,5 +1,7 @@ +from typing import Optional import click import os +import functools from dbgpt.app.base import WebServerParameters from dbgpt.configs.model_config import LOGDIR from dbgpt.util.parameter_utils import EnvArgumentParser @@ -34,3 +36,241 @@ def stop_webserver(port: int): def _stop_all_dbgpt_server(): _stop_service("webserver", "WebServer") + + +@click.group("migration") +def migration(): + """Manage database migration""" + pass + + +def add_migration_options(func): + @click.option( + "--alembic_ini_path", + required=False, + type=str, + default=None, + show_default=True, + help="Alembic ini path, if not set, use 'pilot/meta_data/alembic.ini'", + ) + @click.option( + "--script_location", + required=False, + type=str, + default=None, + show_default=True, + help="Alembic script location, if not set, use 'pilot/meta_data/alembic'", + ) + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +@migration.command() +@add_migration_options +@click.option( + "-m", + "--message", + required=False, + type=str, + default="Init migration", + show_default=True, + help="The message for create migration repository", +) +def init(alembic_ini_path: str, script_location: str, message: str): + """Initialize database migration repository""" + from dbgpt.util._db_migration_utils import create_migration_script + + alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location) + create_migration_script(alembic_cfg, db_manager.engine, message) + + +@migration.command() +@add_migration_options +@click.option( + "-m", + "--message", + required=False, + type=str, + default="New migration", + show_default=True, + help="The message for migration script", +) +def migrate(alembic_ini_path: str, script_location: str, message: str): + """Create migration script""" + from dbgpt.util._db_migration_utils import create_migration_script + + alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location) + create_migration_script(alembic_cfg, db_manager.engine, message) + + +@migration.command() +@add_migration_options +def upgrade(alembic_ini_path: str, script_location: str): + """Upgrade database to target version""" + from dbgpt.util._db_migration_utils import upgrade_database + + alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location) + upgrade_database(alembic_cfg, db_manager.engine) + + +@migration.command() +@add_migration_options +@click.option( + "-y", + required=False, + type=bool, + default=False, + is_flag=True, + help="Confirm to downgrade database", +) +@click.option( + "-r", + "--revision", + default="-1", + show_default=True, + help="Revision to downgrade to", +) +def downgrade(alembic_ini_path: str, script_location: str, y: bool, revision: str): + """Downgrade database to target version""" + from dbgpt.util._db_migration_utils import downgrade_database + + if not y: + click.confirm("Are you sure you want to downgrade the database?", abort=True) + alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location) + downgrade_database(alembic_cfg, db_manager.engine, revision) + + +@migration.command() +@add_migration_options +@click.option( + "--drop_all_tables", + required=False, + type=bool, + default=False, + is_flag=True, + help="Drop all tables", +) +@click.option( + "-y", + required=False, + type=bool, + default=False, + is_flag=True, + help="Confirm to clean migration data", +) +@click.option( + "--confirm_drop_all_tables", + required=False, + type=bool, + default=False, + is_flag=True, + help="Confirm to drop all tables", +) +def clean( + alembic_ini_path: str, + script_location: str, + drop_all_tables: bool, + y: bool, + confirm_drop_all_tables: bool, +): + """Clean Alembic migration scripts and history""" + from dbgpt.util._db_migration_utils import clean_alembic_migration + + if not y: + click.confirm( + "Are you sure clean alembic migration scripts and history?", abort=True + ) + alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location) + clean_alembic_migration(alembic_cfg, db_manager.engine) + if drop_all_tables: + if not confirm_drop_all_tables: + click.confirm("\nAre you sure drop all tables?", abort=True) + with db_manager.engine.connect() as connection: + for tbl in reversed(db_manager.Model.metadata.sorted_tables): + print(f"Drop table {tbl.name}") + connection.execute(tbl.delete()) + + +@migration.command() +@add_migration_options +def list(alembic_ini_path: str, script_location: str): + """List all versions in the migration history, marking the current one""" + from alembic.script import ScriptDirectory + from alembic.runtime.migration import MigrationContext + + alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location) + + # Set up Alembic environment and script directory + script = ScriptDirectory.from_config(alembic_cfg) + + # Get current revision + def get_current_revision(): + with db_manager.engine.connect() as connection: + context = MigrationContext.configure(connection) + return context.get_current_revision() + + current_rev = get_current_revision() + + # List all revisions and mark the current one + for revision in script.walk_revisions(): + current_marker = "(current)" if revision.revision == current_rev else "" + print(f"{revision.revision} {current_marker}: {revision.doc}") + + +@migration.command() +@add_migration_options +@click.argument("revision", required=True) +def show(alembic_ini_path: str, script_location: str, revision: str): + """Show the migration script for a specific version.""" + from alembic.script import ScriptDirectory + + alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location) + + script = ScriptDirectory.from_config(alembic_cfg) + + rev = script.get_revision(revision) + if rev is None: + print(f"Revision {revision} not found.") + return + + # Find the migration script file + script_files = os.listdir(os.path.join(script.dir, "versions")) + script_file = next((f for f in script_files if f.startswith(revision)), None) + + if script_file is None: + print(f"Migration script for revision {revision} not found.") + return + # Print the migration script + script_file_path = os.path.join(script.dir, "versions", script_file) + print(f"Migration script for revision {revision}: {script_file_path}") + try: + with open(script_file_path, "r") as file: + print(file.read()) + except FileNotFoundError: + print(f"Migration script {script_file_path} not found.") + + +def _get_migration_config( + alembic_ini_path: Optional[str] = None, script_location: Optional[str] = None +): + from dbgpt.storage.metadata.db_manager import db as db_manager + from dbgpt.util._db_migration_utils import create_alembic_config + + # Must import dbgpt_server for initialize db metadata + from dbgpt.app.dbgpt_server import initialize_app as _ + from dbgpt.app.base import _initialize_db + + # initialize db + default_meta_data_path = _initialize_db() + alembic_cfg = create_alembic_config( + default_meta_data_path, + db_manager.engine, + db_manager.Model, + db_manager.session(), + alembic_ini_path, + script_location, + ) + return alembic_cfg, db_manager diff --git a/dbgpt/app/base.py b/dbgpt/app/base.py index fef8792aa..1583b2d3b 100644 --- a/dbgpt/app/base.py +++ b/dbgpt/app/base.py @@ -8,7 +8,8 @@ from dbgpt._private.config import Config from dbgpt.component import SystemApp from dbgpt.util.parameter_utils import BaseParameters -from dbgpt.storage.metadata.meta_data import ddl_init_and_upgrade + +from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) @@ -36,8 +37,8 @@ def server_init(param: "WebServerParameters", system_app: SystemApp): # init config cfg = Config() cfg.SYSTEM_APP = system_app - - ddl_init_and_upgrade(param.disable_alembic_upgrade) + # Initialize db storage first + _initialize_db_storage(param) # load_native_plugins(cfg) signal.signal(signal.SIGINT, signal_handler) @@ -83,6 +84,46 @@ def startup_event(wh): return startup_event +def _initialize_db_storage(param: "WebServerParameters"): + """Initialize the db storage. + + Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`. + """ + default_meta_data_path = _initialize_db( + try_to_create_db=not param.disable_alembic_upgrade + ) + _ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade) + + +def _initialize_db(try_to_create_db: Optional[bool] = False) -> str: + """Initialize the database + + Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`. + """ + from dbgpt.configs.model_config import PILOT_PATH + from dbgpt.storage.metadata.db_manager import initialize_db + from urllib.parse import quote_plus as urlquote, quote + + CFG = Config() + db_name = CFG.LOCAL_DB_NAME + default_meta_data_path = os.path.join(PILOT_PATH, "meta_data") + os.makedirs(default_meta_data_path, exist_ok=True) + if CFG.LOCAL_DB_TYPE == "mysql": + db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}" + else: + sqlite_db_path = os.path.join(default_meta_data_path, f"{db_name}.db") + db_url = f"sqlite:///{sqlite_db_path}" + engine_args = { + "pool_size": CFG.LOCAL_DB_POOL_SIZE, + "max_overflow": CFG.LOCAL_DB_POOL_OVERFLOW, + "pool_timeout": 30, + "pool_recycle": 3600, + "pool_pre_ping": True, + } + initialize_db(db_url, db_name, engine_args, try_to_create_db=try_to_create_db) + return default_meta_data_path + + @dataclass class WebServerParameters(BaseParameters): host: Optional[str] = field( diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py index cf042d591..dee7a677e 100644 --- a/dbgpt/app/component_configs.py +++ b/dbgpt/app/component_configs.py @@ -13,7 +13,6 @@ if TYPE_CHECKING: from langchain.embeddings.base import Embeddings - logger = logging.getLogger(__name__) CFG = Config() diff --git a/dbgpt/app/knowledge/chunk_db.py b/dbgpt/app/knowledge/chunk_db.py index 209931fec..11dde75b8 100644 --- a/dbgpt/app/knowledge/chunk_db.py +++ b/dbgpt/app/knowledge/chunk_db.py @@ -3,19 +3,13 @@ from sqlalchemy import Column, String, DateTime, Integer, Text, func -from dbgpt.storage.metadata import BaseDao -from dbgpt.storage.metadata.meta_data import ( - Base, - engine, - session, - META_DATA_DATABASE, -) +from dbgpt.storage.metadata import BaseDao, Model from dbgpt._private.config import Config CFG = Config() -class DocumentChunkEntity(Base): +class DocumentChunkEntity(Model): __tablename__ = "document_chunk" __table_args__ = { "mysql_charset": "utf8mb4", @@ -35,16 +29,8 @@ def __repr__(self): class DocumentChunkDao(BaseDao): - def __init__(self): - super().__init__( - database=META_DATA_DATABASE, - orm_base=Base, - db_engine=engine, - session=session, - ) - def create_documents_chunks(self, documents: List): - session = self.get_session() + session = self.get_raw_session() docs = [ DocumentChunkEntity( doc_name=document.doc_name, @@ -64,7 +50,7 @@ def create_documents_chunks(self, documents: List): def get_document_chunks( self, query: DocumentChunkEntity, page=1, page_size=20, document_ids=None ): - session = self.get_session() + session = self.get_raw_session() document_chunks = session.query(DocumentChunkEntity) if query.id is not None: document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id) @@ -102,7 +88,7 @@ def get_document_chunks( return result def get_document_chunks_count(self, query: DocumentChunkEntity): - session = self.get_session() + session = self.get_raw_session() document_chunks = session.query(func.count(DocumentChunkEntity.id)) if query.id is not None: document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id) @@ -127,7 +113,7 @@ def get_document_chunks_count(self, query: DocumentChunkEntity): return count def delete(self, document_id: int): - session = self.get_session() + session = self.get_raw_session() if document_id is None: raise Exception("document_id is None") query = DocumentChunkEntity(document_id=document_id) diff --git a/dbgpt/app/knowledge/document_db.py b/dbgpt/app/knowledge/document_db.py index 3dd8d39fe..4808e7744 100644 --- a/dbgpt/app/knowledge/document_db.py +++ b/dbgpt/app/knowledge/document_db.py @@ -2,19 +2,13 @@ from sqlalchemy import Column, String, DateTime, Integer, Text, func -from dbgpt.storage.metadata import BaseDao -from dbgpt.storage.metadata.meta_data import ( - Base, - engine, - session, - META_DATA_DATABASE, -) +from dbgpt.storage.metadata import BaseDao, Model from dbgpt._private.config import Config CFG = Config() -class KnowledgeDocumentEntity(Base): +class KnowledgeDocumentEntity(Model): __tablename__ = "knowledge_document" __table_args__ = { "mysql_charset": "utf8mb4", @@ -39,16 +33,8 @@ def __repr__(self): class KnowledgeDocumentDao(BaseDao): - def __init__(self): - super().__init__( - database=META_DATA_DATABASE, - orm_base=Base, - db_engine=engine, - session=session, - ) - def create_knowledge_document(self, document: KnowledgeDocumentEntity): - session = self.get_session() + session = self.get_raw_session() knowledge_document = KnowledgeDocumentEntity( doc_name=document.doc_name, doc_type=document.doc_type, @@ -69,7 +55,7 @@ def create_knowledge_document(self, document: KnowledgeDocumentEntity): return doc_id def get_knowledge_documents(self, query, page=1, page_size=20): - session = self.get_session() + session = self.get_raw_session() print(f"current session:{session}") knowledge_documents = session.query(KnowledgeDocumentEntity) if query.id is not None: @@ -104,7 +90,7 @@ def get_knowledge_documents(self, query, page=1, page_size=20): return result def get_documents(self, query): - session = self.get_session() + session = self.get_raw_session() print(f"current session:{session}") knowledge_documents = session.query(KnowledgeDocumentEntity) if query.id is not None: @@ -136,7 +122,7 @@ def get_documents(self, query): return result def get_knowledge_documents_count_bulk(self, space_names): - session = self.get_session() + session = self.get_raw_session() """ Perform a batch query to count the number of documents for each knowledge space. @@ -161,7 +147,7 @@ def get_knowledge_documents_count_bulk(self, space_names): return docs_count def get_knowledge_documents_count(self, query): - session = self.get_session() + session = self.get_raw_session() knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id)) if query.id is not None: knowledge_documents = knowledge_documents.filter( @@ -188,14 +174,14 @@ def get_knowledge_documents_count(self, query): return count def update_knowledge_document(self, document: KnowledgeDocumentEntity): - session = self.get_session() + session = self.get_raw_session() updated_space = session.merge(document) session.commit() return updated_space.id # def delete(self, query: KnowledgeDocumentEntity): - session = self.get_session() + session = self.get_raw_session() knowledge_documents = session.query(KnowledgeDocumentEntity) if query.id is not None: knowledge_documents = knowledge_documents.filter( diff --git a/dbgpt/app/knowledge/space_db.py b/dbgpt/app/knowledge/space_db.py index 5fc3bc796..8dbd904ab 100644 --- a/dbgpt/app/knowledge/space_db.py +++ b/dbgpt/app/knowledge/space_db.py @@ -2,20 +2,14 @@ from sqlalchemy import Column, Integer, Text, String, DateTime -from dbgpt.storage.metadata import BaseDao -from dbgpt.storage.metadata.meta_data import ( - Base, - engine, - session, - META_DATA_DATABASE, -) +from dbgpt.storage.metadata import BaseDao, Model from dbgpt._private.config import Config from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest CFG = Config() -class KnowledgeSpaceEntity(Base): +class KnowledgeSpaceEntity(Model): __tablename__ = "knowledge_space" __table_args__ = { "mysql_charset": "utf8mb4", @@ -35,16 +29,8 @@ def __repr__(self): class KnowledgeSpaceDao(BaseDao): - def __init__(self): - super().__init__( - database=META_DATA_DATABASE, - orm_base=Base, - db_engine=engine, - session=session, - ) - def create_knowledge_space(self, space: KnowledgeSpaceRequest): - session = self.get_session() + session = self.get_raw_session() knowledge_space = KnowledgeSpaceEntity( name=space.name, vector_type=CFG.VECTOR_STORE_TYPE, @@ -58,7 +44,7 @@ def create_knowledge_space(self, space: KnowledgeSpaceRequest): session.close() def get_knowledge_space(self, query: KnowledgeSpaceEntity): - session = self.get_session() + session = self.get_raw_session() knowledge_spaces = session.query(KnowledgeSpaceEntity) if query.id is not None: knowledge_spaces = knowledge_spaces.filter( @@ -97,14 +83,14 @@ def get_knowledge_space(self, query: KnowledgeSpaceEntity): return result def update_knowledge_space(self, space: KnowledgeSpaceEntity): - session = self.get_session() + session = self.get_raw_session() session.merge(space) session.commit() session.close() return True def delete_knowledge_space(self, space: KnowledgeSpaceEntity): - session = self.get_session() + session = self.get_raw_session() if space: session.delete(space) session.commit() diff --git a/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py b/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py index b17e5d6ea..999434924 100644 --- a/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py +++ b/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py @@ -2,17 +2,12 @@ from sqlalchemy import Column, Integer, Text, String, DateTime -from dbgpt.storage.metadata import BaseDao -from dbgpt.storage.metadata.meta_data import ( - Base, - engine, - session, - META_DATA_DATABASE, -) +from dbgpt.storage.metadata import BaseDao, Model + from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody -class ChatFeedBackEntity(Base): +class ChatFeedBackEntity(Model): __tablename__ = "chat_feed_back" __table_args__ = { "mysql_charset": "utf8mb4", @@ -39,18 +34,10 @@ def __repr__(self): class ChatFeedBackDao(BaseDao): - def __init__(self): - super().__init__( - database=META_DATA_DATABASE, - orm_base=Base, - db_engine=engine, - session=session, - ) - def create_or_update_chat_feed_back(self, feed_back: FeedBackBody): # Todo: We need to have user information first. - session = self.get_session() + session = self.get_raw_session() chat_feed_back = ChatFeedBackEntity( conv_uid=feed_back.conv_uid, conv_index=feed_back.conv_index, @@ -84,7 +71,7 @@ def create_or_update_chat_feed_back(self, feed_back: FeedBackBody): session.close() def get_chat_feed_back(self, conv_uid: str, conv_index: int): - session = self.get_session() + session = self.get_raw_session() result = ( session.query(ChatFeedBackEntity) .filter(ChatFeedBackEntity.conv_uid == conv_uid) diff --git a/dbgpt/app/prompt/prompt_manage_db.py b/dbgpt/app/prompt/prompt_manage_db.py index 55ce05258..0529ad4a7 100644 --- a/dbgpt/app/prompt/prompt_manage_db.py +++ b/dbgpt/app/prompt/prompt_manage_db.py @@ -2,13 +2,8 @@ from sqlalchemy import Column, Integer, Text, String, DateTime -from dbgpt.storage.metadata import BaseDao -from dbgpt.storage.metadata.meta_data import ( - Base, - engine, - session, - META_DATA_DATABASE, -) +from dbgpt.storage.metadata import BaseDao, Model + from dbgpt._private.config import Config from dbgpt.app.prompt.request.request import PromptManageRequest @@ -16,7 +11,7 @@ CFG = Config() -class PromptManageEntity(Base): +class PromptManageEntity(Model): __tablename__ = "prompt_manage" __table_args__ = { "mysql_charset": "utf8mb4", @@ -38,16 +33,8 @@ def __repr__(self): class PromptManageDao(BaseDao): - def __init__(self): - super().__init__( - database=META_DATA_DATABASE, - orm_base=Base, - db_engine=engine, - session=session, - ) - def create_prompt(self, prompt: PromptManageRequest): - session = self.get_session() + session = self.get_raw_session() prompt_manage = PromptManageEntity( chat_scene=prompt.chat_scene, sub_chat_scene=prompt.sub_chat_scene, @@ -64,7 +51,7 @@ def create_prompt(self, prompt: PromptManageRequest): session.close() def get_prompts(self, query: PromptManageEntity): - session = self.get_session() + session = self.get_raw_session() prompts = session.query(PromptManageEntity) if query.chat_scene is not None: prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene) @@ -93,13 +80,13 @@ def get_prompts(self, query: PromptManageEntity): return result def update_prompt(self, prompt: PromptManageEntity): - session = self.get_session() + session = self.get_raw_session() session.merge(prompt) session.commit() session.close() def delete_prompt(self, prompt: PromptManageEntity): - session = self.get_session() + session = self.get_raw_session() if prompt: session.delete(prompt) session.commit() diff --git a/dbgpt/app/scene/base_chat.py b/dbgpt/app/scene/base_chat.py index 048b04a2f..1e5cf7070 100644 --- a/dbgpt/app/scene/base_chat.py +++ b/dbgpt/app/scene/base_chat.py @@ -146,7 +146,9 @@ async def __call_base(self): input_values = await self.generate_input_values() ### Chat sequence advance self.current_message.chat_order = len(self.history_message) + 1 - self.current_message.add_user_message(self.current_user_input) + self.current_message.add_user_message( + self.current_user_input, check_duplicate_type=True + ) self.current_message.start_date = datetime.datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) @@ -221,7 +223,7 @@ async def stream_call(self): view_msg = self.stream_plugin_call(msg) view_msg = view_msg.replace("\n", "\\n") yield view_msg - self.current_message.add_ai_message(msg) + self.current_message.add_ai_message(msg, update_if_exist=True) view_msg = self.stream_call_reinforce_fn(view_msg) self.current_message.add_view_message(view_msg) span.end() @@ -257,7 +259,7 @@ async def nostream_call(self): ) ) ### model result deal - self.current_message.add_ai_message(ai_response_text) + self.current_message.add_ai_message(ai_response_text, update_if_exist=True) prompt_define_response = ( self.prompt_template.output_parser.parse_prompt_response( ai_response_text @@ -320,7 +322,7 @@ async def get_llm_response(self): ) ) ### model result deal - self.current_message.add_ai_message(ai_response_text) + self.current_message.add_ai_message(ai_response_text, update_if_exist=True) prompt_define_response = None prompt_define_response = ( self.prompt_template.output_parser.parse_prompt_response( @@ -596,7 +598,7 @@ def _load_system_message( prompt_template: PromptTemplate, str_message: bool = True, ): - system_convs = current_message.get_system_conv() + system_convs = current_message.get_system_messages() system_text = "" system_messages = [] for system_conv in system_convs: @@ -614,7 +616,7 @@ def _load_user_message( prompt_template: PromptTemplate, str_message: bool = True, ): - user_conv = current_message.get_user_conv() + user_conv = current_message.get_latest_user_message() user_messages = [] if user_conv: user_text = user_conv.type + ":" + user_conv.content + prompt_template.sep diff --git a/dbgpt/app/scene/operator/_experimental.py b/dbgpt/app/scene/operator/_experimental.py index 63cc5e194..7ebe2c4c7 100644 --- a/dbgpt/app/scene/operator/_experimental.py +++ b/dbgpt/app/scene/operator/_experimental.py @@ -70,7 +70,9 @@ def __init__( def _new_chat(self, input_values: Dict) -> List[ModelMessage]: self.current_message.chat_order = len(self.history_message) + 1 - self.current_message.add_user_message(self._chat_ctx.current_user_input) + self.current_message.add_user_message( + self._chat_ctx.current_user_input, check_duplicate_type=True + ) self.current_message.start_date = datetime.datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) diff --git a/dbgpt/cli/cli_scripts.py b/dbgpt/cli/cli_scripts.py index d0c23559a..f0ac7bd03 100644 --- a/dbgpt/cli/cli_scripts.py +++ b/dbgpt/cli/cli_scripts.py @@ -51,6 +51,12 @@ def install(): pass +@click.group() +def db(): + """Manage your metadata database and your datasources.""" + pass + + stop_all_func_list = [] @@ -64,6 +70,7 @@ def stop_all(): cli.add_command(start) cli.add_command(stop) cli.add_command(install) +cli.add_command(db) add_command_alias(stop_all, name="all", parent_group=stop) try: @@ -96,10 +103,13 @@ def stop_all(): start_webserver, stop_webserver, _stop_all_dbgpt_server, + migration, ) add_command_alias(start_webserver, name="webserver", parent_group=start) add_command_alias(stop_webserver, name="webserver", parent_group=stop) + # Add migration command + add_command_alias(migration, name="migration", parent_group=db) stop_all_func_list.append(_stop_all_dbgpt_server) except ImportError as e: diff --git a/dbgpt/core/__init__.py b/dbgpt/core/__init__.py index d2f64aafd..c9996d3cc 100644 --- a/dbgpt/core/__init__.py +++ b/dbgpt/core/__init__.py @@ -9,6 +9,10 @@ ModelMessage, ModelMessageRoleType, OnceConversation, + StorageConversation, + MessageStorageItem, + ConversationIdentifier, + MessageIdentifier, ) from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateOperator from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser @@ -20,6 +24,16 @@ CachePolicy, CacheConfig, ) +from dbgpt.core.interface.storage import ( + ResourceIdentifier, + StorageItem, + StorageItemAdapter, + StorageInterface, + InMemoryStorage, + DefaultStorageItemAdapter, + QuerySpec, + StorageError, +) __ALL__ = [ "ModelInferenceMetrics", @@ -30,6 +44,10 @@ "ModelMessage", "ModelMessageRoleType", "OnceConversation", + "StorageConversation", + "MessageStorageItem", + "ConversationIdentifier", + "MessageIdentifier", "PromptTemplate", "PromptTemplateOperator", "BaseOutputParser", @@ -41,4 +59,12 @@ "CacheClient", "CachePolicy", "CacheConfig", + "ResourceIdentifier", + "StorageItem", + "StorageItemAdapter", + "StorageInterface", + "InMemoryStorage", + "DefaultStorageItemAdapter", + "QuerySpec", + "StorageError", ] diff --git a/dbgpt/core/awel/dag/tests/test_dag.py b/dbgpt/core/awel/dag/tests/test_dag.py index c30530dc8..8a058f318 100644 --- a/dbgpt/core/awel/dag/tests/test_dag.py +++ b/dbgpt/core/awel/dag/tests/test_dag.py @@ -1,7 +1,7 @@ import pytest import threading import asyncio -from ..dag import DAG, DAGContext +from ..base import DAG, DAGVar def test_dag_context_sync(): @@ -9,18 +9,18 @@ def test_dag_context_sync(): dag2 = DAG("dag2") with dag1: - assert DAGContext.get_current_dag() == dag1 + assert DAGVar.get_current_dag() == dag1 with dag2: - assert DAGContext.get_current_dag() == dag2 - assert DAGContext.get_current_dag() == dag1 - assert DAGContext.get_current_dag() is None + assert DAGVar.get_current_dag() == dag2 + assert DAGVar.get_current_dag() == dag1 + assert DAGVar.get_current_dag() is None def test_dag_context_threading(): def thread_function(dag): - DAGContext.enter_dag(dag) - assert DAGContext.get_current_dag() == dag - DAGContext.exit_dag() + DAGVar.enter_dag(dag) + assert DAGVar.get_current_dag() == dag + DAGVar.exit_dag() dag1 = DAG("dag1") dag2 = DAG("dag2") @@ -33,19 +33,19 @@ def thread_function(dag): thread1.join() thread2.join() - assert DAGContext.get_current_dag() is None + assert DAGVar.get_current_dag() is None @pytest.mark.asyncio async def test_dag_context_async(): async def async_function(dag): - DAGContext.enter_dag(dag) - assert DAGContext.get_current_dag() == dag - DAGContext.exit_dag() + DAGVar.enter_dag(dag) + assert DAGVar.get_current_dag() == dag + DAGVar.exit_dag() dag1 = DAG("dag1") dag2 = DAG("dag2") await asyncio.gather(async_function(dag1), async_function(dag2)) - assert DAGContext.get_current_dag() is None + assert DAGVar.get_current_dag() is None diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py index c7ae5d5a0..36bba00f4 100644 --- a/dbgpt/core/interface/message.py +++ b/dbgpt/core/interface/message.py @@ -1,16 +1,26 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple, Union, Optional from datetime import datetime from dbgpt._private.pydantic import BaseModel, Field +from dbgpt.core.interface.storage import ( + ResourceIdentifier, + StorageItem, + StorageInterface, + InMemoryStorage, +) + class BaseMessage(BaseModel, ABC): """Message object.""" content: str + index: int = 0 + round_index: int = 0 + """The round index of the message in the conversation""" additional_kwargs: dict = Field(default_factory=dict) @property @@ -18,6 +28,24 @@ class BaseMessage(BaseModel, ABC): def type(self) -> str: """Type of the message, used for serialization.""" + @property + def pass_to_model(self) -> bool: + """Whether the message will be passed to the model""" + return True + + def to_dict(self) -> Dict: + """Convert to dict + + Returns: + Dict: The dict object + """ + return { + "type": self.type, + "data": self.dict(), + "index": self.index, + "round_index": self.round_index, + } + class HumanMessage(BaseMessage): """Type of message that is spoken by the human.""" @@ -51,6 +79,14 @@ def type(self) -> str: """Type of the message, used for serialization.""" return "view" + @property + def pass_to_model(self) -> bool: + """Whether the message will be passed to the model + + The view message will not be passed to the model + """ + return False + class SystemMessage(BaseMessage): """Type of message that is a system message.""" @@ -141,15 +177,15 @@ def build_human_message(content: str) -> "ModelMessage": return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content) -def _message_to_dict(message: BaseMessage) -> dict: - return {"type": message.type, "data": message.dict()} +def _message_to_dict(message: BaseMessage) -> Dict: + return message.to_dict() -def _messages_to_dict(messages: List[BaseMessage]) -> List[dict]: +def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]: return [_message_to_dict(m) for m in messages] -def _message_from_dict(message: dict) -> BaseMessage: +def _message_from_dict(message: Dict) -> BaseMessage: _type = message["type"] if _type == "human": return HumanMessage(**message["data"]) @@ -163,7 +199,7 @@ def _message_from_dict(message: dict) -> BaseMessage: raise ValueError(f"Got unexpected type: {_type}") -def _messages_from_dict(messages: List[dict]) -> List[BaseMessage]: +def _messages_from_dict(messages: List[Dict]) -> List[BaseMessage]: return [_message_from_dict(m) for m in messages] @@ -193,50 +229,119 @@ def _parse_model_messages( history_messages.append([]) if messages[-1].role != "human": raise ValueError("Hi! What do you want to talk about?") - # Keep message pair of [user message, assistant message] + # Keep message a pair of [user message, assistant message] history_messages = list(filter(lambda x: len(x) == 2, history_messages)) user_prompt = messages[-1].content return user_prompt, system_messages, history_messages class OnceConversation: - """ - All the information of a conversation, the current single service in memory, can expand cache and database support distributed services + """All the information of a conversation, the current single service in memory, + can expand cache and database support distributed services. + """ - def __init__(self, chat_mode, user_name: str = None, sys_code: str = None): + def __init__( + self, + chat_mode: str, + user_name: str = None, + sys_code: str = None, + summary: str = None, + **kwargs, + ): self.chat_mode: str = chat_mode - self.messages: List[BaseMessage] = [] - self.start_date: str = "" - self.chat_order: int = 0 - self.model_name: str = "" - self.param_type: str = "" - self.param_value: str = "" - self.cost: int = 0 - self.tokens: int = 0 self.user_name: str = user_name self.sys_code: str = sys_code + self.summary: str = summary + + self.messages: List[BaseMessage] = kwargs.get("messages", []) + self.start_date: str = kwargs.get("start_date", "") + # After each complete round of dialogue, the current value will be increased by 1 + self.chat_order: int = int(kwargs.get("chat_order", 0)) + self.model_name: str = kwargs.get("model_name", "") + self.param_type: str = kwargs.get("param_type", "") + self.param_value: str = kwargs.get("param_value", "") + self.cost: int = int(kwargs.get("cost", 0)) + self.tokens: int = int(kwargs.get("tokens", 0)) + self._message_index: int = int(kwargs.get("message_index", 0)) + + def _append_message(self, message: BaseMessage) -> None: + index = self._message_index + self._message_index += 1 + message.index = index + message.round_index = self.chat_order + self.messages.append(message) + + def start_new_round(self) -> None: + """Start a new round of conversation + + Example: + >>> conversation = OnceConversation() + >>> # The chat order will be 0, then we start a new round of conversation + >>> assert conversation.chat_order == 0 + >>> conversation.start_new_round() + >>> # Now the chat order will be 1 + >>> assert conversation.chat_order == 1 + >>> conversation.add_user_message("hello") + >>> conversation.add_ai_message("hi") + >>> conversation.end_current_round() + >>> # Now the chat order will be 1, then we start a new round of conversation + >>> conversation.start_new_round() + >>> # Now the chat order will be 2 + >>> assert conversation.chat_order == 2 + >>> conversation.add_user_message("hello") + >>> conversation.add_ai_message("hi") + >>> conversation.end_current_round() + >>> assert conversation.chat_order == 2 + """ + self.chat_order += 1 - def add_user_message(self, message: str) -> None: - """Add a user message to the store""" - has_message = any( - isinstance(instance, HumanMessage) for instance in self.messages - ) - if has_message: - raise ValueError("Already Have Human message") - self.messages.append(HumanMessage(content=message)) + def end_current_round(self) -> None: + """End the current round of conversation - def add_ai_message(self, message: str) -> None: - """Add an AI message to the store""" + We do noting here, just for the interface + """ + pass + + def add_user_message( + self, message: str, check_duplicate_type: Optional[bool] = False + ) -> None: + """Add a user message to the conversation + Args: + message (str): The message content + check_duplicate_type (bool): Whether to check the duplicate message type + + Raises: + ValueError: If the message is duplicate and check_duplicate_type is True + """ + if check_duplicate_type: + has_message = any( + isinstance(instance, HumanMessage) for instance in self.messages + ) + if has_message: + raise ValueError("Already Have Human message") + self._append_message(HumanMessage(content=message)) + + def add_ai_message( + self, message: str, update_if_exist: Optional[bool] = False + ) -> None: + """Add an AI message to the conversation + + Args: + message (str): The message content + update_if_exist (bool): Whether to update the message if the message type is duplicate + """ + if not update_if_exist: + self._append_message(AIMessage(content=message)) + return has_message = any(isinstance(instance, AIMessage) for instance in self.messages) if has_message: - self.__update_ai_message(message) + self._update_ai_message(message) else: - self.messages.append(AIMessage(content=message)) - """ """ + self._append_message(AIMessage(content=message)) - def __update_ai_message(self, new_message: str) -> None: + def _update_ai_message(self, new_message: str) -> None: """ stream out message update Args: @@ -252,13 +357,11 @@ def __update_ai_message(self, new_message: str) -> None: def add_view_message(self, message: str) -> None: """Add an AI message to the store""" - - self.messages.append(ViewMessage(content=message)) - """ """ + self._append_message(ViewMessage(content=message)) def add_system_message(self, message: str) -> None: - """Add an AI message to the store""" - self.messages.append(SystemMessage(content=message)) + """Add a system message to the store""" + self._append_message(SystemMessage(content=message)) def set_start_time(self, datatime: datetime): dt_str = datatime.strftime("%Y-%m-%d %H:%M:%S") @@ -267,23 +370,369 @@ def set_start_time(self, datatime: datetime): def clear(self) -> None: """Remove all messages from the store""" self.messages.clear() - self.session_id = None - def get_user_conv(self): - for message in self.messages: + def get_latest_user_message(self) -> Optional[HumanMessage]: + """Get the latest user message""" + for message in self.messages[::-1]: if isinstance(message, HumanMessage): return message return None - def get_system_conv(self): - system_convs = [] + def get_system_messages(self) -> List[SystemMessage]: + """Get the latest user message""" + return list(filter(lambda x: isinstance(x, SystemMessage), self.messages)) + + def _to_dict(self) -> Dict: + return _conversation_to_dict(self) + + def from_conversation(self, conversation: OnceConversation) -> None: + """Load the conversation from the storage""" + self.chat_mode = conversation.chat_mode + self.messages = conversation.messages + self.start_date = conversation.start_date + self.chat_order = conversation.chat_order + self.model_name = conversation.model_name + self.param_type = conversation.param_type + self.param_value = conversation.param_value + self.cost = conversation.cost + self.tokens = conversation.tokens + self.user_name = conversation.user_name + self.sys_code = conversation.sys_code + + def get_messages_by_round(self, round_index: int) -> List[BaseMessage]: + """Get the messages by round index + + Args: + round_index (int): The round index + + Returns: + List[BaseMessage]: The messages + """ + return list(filter(lambda x: x.round_index == round_index, self.messages)) + + def get_latest_round(self) -> List[BaseMessage]: + """Get the latest round messages + + Returns: + List[BaseMessage]: The messages + """ + return self.get_messages_by_round(self.chat_order) + + def get_messages_with_round(self, round_count: int) -> List[BaseMessage]: + """Get the messages with round count + + If the round count is 1, the history messages will not be included. + + Example: + .. code-block:: python + conversation = OnceConversation() + conversation.start_new_round() + conversation.add_user_message("hello, this is the first round") + conversation.add_ai_message("hi") + conversation.end_current_round() + conversation.start_new_round() + conversation.add_user_message("hello, this is the second round") + conversation.add_ai_message("hi") + conversation.end_current_round() + conversation.start_new_round() + conversation.add_user_message("hello, this is the third round") + conversation.add_ai_message("hi") + conversation.end_current_round() + + assert len(conversation.get_messages_with_round(1)) == 2 + assert conversation.get_messages_with_round(1)[0].content == "hello, this is the third round" + assert conversation.get_messages_with_round(1)[1].content == "hi" + + assert len(conversation.get_messages_with_round(2)) == 4 + assert conversation.get_messages_with_round(2)[0].content == "hello, this is the second round" + assert conversation.get_messages_with_round(2)[1].content == "hi" + + Args: + round_count (int): The round count + + Returns: + List[BaseMessage]: The messages + """ + latest_round_index = self.chat_order + start_round_index = max(1, latest_round_index - round_count + 1) + messages = [] + for round_index in range(start_round_index, latest_round_index + 1): + messages.extend(self.get_messages_by_round(round_index)) + return messages + + def get_model_messages(self) -> List[ModelMessage]: + """Get the model messages + + Model messages just include human, ai and system messages. + Model messages maybe include the history messages, The order of the messages is the same as the order of + the messages in the conversation, the last message is the latest message. + + If you want to hand the message with your own logic, you can override this method. + + Examples: + If you not need the history messages, you can override this method like this: + .. code-block:: python + def get_model_messages(self) -> List[ModelMessage]: + messages = [] + for message in self.get_latest_round(): + if message.pass_to_model: + messages.append( + ModelMessage(role=message.type, content=message.content) + ) + return messages + + If you want to add the one round history messages, you can override this method like this: + .. code-block:: python + def get_model_messages(self) -> List[ModelMessage]: + messages = [] + latest_round_index = self.chat_order + round_count = 1 + start_round_index = max(1, latest_round_index - round_count + 1) + for round_index in range(start_round_index, latest_round_index + 1): + for message in self.get_messages_by_round(round_index): + if message.pass_to_model: + messages.append( + ModelMessage(role=message.type, content=message.content) + ) + return messages + + Returns: + List[ModelMessage]: The model messages + """ + messages = [] for message in self.messages: - if isinstance(message, SystemMessage): - system_convs.append(message) - return system_convs + if message.pass_to_model: + messages.append( + ModelMessage(role=message.type, content=message.content) + ) + return messages + + +class ConversationIdentifier(ResourceIdentifier): + """Conversation identifier""" + + def __init__(self, conv_uid: str, identifier_type: str = "conversation"): + self.conv_uid = conv_uid + self.identifier_type = identifier_type + + @property + def str_identifier(self) -> str: + return f"{self.identifier_type}:{self.conv_uid}" + + def to_dict(self) -> Dict: + return {"conv_uid": self.conv_uid, "identifier_type": self.identifier_type} + + +class MessageIdentifier(ResourceIdentifier): + """Message identifier""" + + identifier_split = "___" + + def __init__(self, conv_uid: str, index: int, identifier_type: str = "message"): + self.conv_uid = conv_uid + self.index = index + self.identifier_type = identifier_type + + @property + def str_identifier(self) -> str: + return f"{self.identifier_type}{self.identifier_split}{self.conv_uid}{self.identifier_split}{self.index}" + + @staticmethod + def from_str_identifier(str_identifier: str) -> MessageIdentifier: + """Convert from str identifier + + Args: + str_identifier (str): The str identifier + + Returns: + MessageIdentifier: The message identifier + """ + parts = str_identifier.split(MessageIdentifier.identifier_split) + if len(parts) != 3: + raise ValueError(f"Invalid str identifier: {str_identifier}") + return MessageIdentifier(parts[1], int(parts[2])) + + def to_dict(self) -> Dict: + return { + "conv_uid": self.conv_uid, + "index": self.index, + "identifier_type": self.identifier_type, + } + + +class MessageStorageItem(StorageItem): + @property + def identifier(self) -> MessageIdentifier: + return self._id + + def __init__(self, conv_uid: str, index: int, message_detail: Dict): + self.conv_uid = conv_uid + self.index = index + self.message_detail = message_detail + self._id = MessageIdentifier(conv_uid, index) + + def to_dict(self) -> Dict: + return { + "conv_uid": self.conv_uid, + "index": self.index, + "message_detail": self.message_detail, + } + + def to_message(self) -> BaseMessage: + """Convert to message object + Returns: + BaseMessage: The message object + + Raises: + ValueError: If the message type is not supported + """ + return _message_from_dict(self.message_detail) + + def merge(self, other: "StorageItem") -> None: + """Merge the other message to self + + Args: + other (StorageItem): The other message + """ + if not isinstance(other, MessageStorageItem): + raise ValueError(f"Can not merge {other} to {self}") + self.message_detail = other.message_detail + + +class StorageConversation(OnceConversation, StorageItem): + """All the information of a conversation, the current single service in memory, + can expand cache and database support distributed services. + + """ + + @property + def identifier(self) -> ConversationIdentifier: + return self._id + + def to_dict(self) -> Dict: + dict_data = self._to_dict() + messages: Dict = dict_data.pop("messages") + message_ids = [] + index = 0 + for message in messages: + if "index" in message: + message_idx = message["index"] + else: + message_idx = index + index += 1 + message_ids.append( + MessageIdentifier(self.conv_uid, message_idx).str_identifier + ) + # Replace message with message ids + dict_data["conv_uid"] = self.conv_uid + dict_data["message_ids"] = message_ids + dict_data["save_message_independent"] = self.save_message_independent + return dict_data + + def merge(self, other: "StorageItem") -> None: + """Merge the other conversation to self + + Args: + other (StorageItem): The other conversation + """ + if not isinstance(other, StorageConversation): + raise ValueError(f"Can not merge {other} to {self}") + self.from_conversation(other) + + def __init__( + self, + conv_uid: str, + chat_mode: str = None, + user_name: str = None, + sys_code: str = None, + message_ids: List[str] = None, + summary: str = None, + save_message_independent: Optional[bool] = True, + conv_storage: StorageInterface = None, + message_storage: StorageInterface = None, + **kwargs, + ): + super().__init__(chat_mode, user_name, sys_code, summary, **kwargs) + self.conv_uid = conv_uid + self._message_ids = message_ids + self.save_message_independent = save_message_independent + self._id = ConversationIdentifier(conv_uid) + if conv_storage is None: + conv_storage = InMemoryStorage() + if message_storage is None: + message_storage = InMemoryStorage() + self.conv_storage = conv_storage + self.message_storage = message_storage + # Load from storage + self.load_from_storage(self.conv_storage, self.message_storage) + + @property + def message_ids(self) -> List[str]: + """Get the message ids + + Returns: + List[str]: The message ids + """ + return self._message_ids if self._message_ids else [] + + def end_current_round(self) -> None: + """End the current round of conversation + + Save the conversation to the storage after a round of conversation + """ + self.save_to_storage() + + def _get_message_items(self) -> List[MessageStorageItem]: + return [ + MessageStorageItem(self.conv_uid, message.index, message.to_dict()) + for message in self.messages + ] + + def save_to_storage(self) -> None: + """Save the conversation to the storage""" + # Save messages first + message_list = self._get_message_items() + self._message_ids = [ + message.identifier.str_identifier for message in message_list + ] + self.message_storage.save_list(message_list) + # Save conversation + self.conv_storage.save_or_update(self) + + def load_from_storage( + self, conv_storage: StorageInterface, message_storage: StorageInterface + ) -> None: + """Load the conversation from the storage + + Warning: This will overwrite the current conversation. + + Args: + conv_storage (StorageInterface): The storage interface + message_storage (StorageInterface): The storage interface + """ + # Load conversation first + conversation: StorageConversation = conv_storage.load( + self._id, StorageConversation + ) + if conversation is None: + return + message_ids = conversation._message_ids or [] + + # Load messages + message_list = message_storage.load_list( + [ + MessageIdentifier.from_str_identifier(message_id) + for message_id in message_ids + ], + MessageStorageItem, + ) + messages = [message.to_message() for message in message_list] + conversation.messages = messages + self._message_ids = message_ids + self.from_conversation(conversation) -def _conversation_to_dict(once: OnceConversation) -> dict: +def _conversation_to_dict(once: OnceConversation) -> Dict: start_str: str = "" if hasattr(once, "start_date") and once.start_date: if isinstance(once.start_date, datetime): @@ -303,6 +752,7 @@ def _conversation_to_dict(once: OnceConversation) -> dict: "param_value": once.param_value, "user_name": once.user_name, "sys_code": once.sys_code, + "summary": once.summary if once.summary else "", } diff --git a/dbgpt/core/interface/output_parser.py b/dbgpt/core/interface/output_parser.py index ffea7f094..c77511a06 100644 --- a/dbgpt/core/interface/output_parser.py +++ b/dbgpt/core/interface/output_parser.py @@ -92,7 +92,7 @@ def parse_model_nostream_resp(self, response: ResponseTye, sep: str): f"""Model server error!code={resp_obj_ex["error_code"]}, errmsg is {resp_obj_ex["text"]}""" ) - def __illegal_json_ends(self, s): + def _illegal_json_ends(self, s): temp_json = s illegal_json_ends_1 = [", }", ",}"] illegal_json_ends_2 = ", ]", ",]" @@ -102,25 +102,25 @@ def __illegal_json_ends(self, s): temp_json = temp_json.replace(illegal_json_end, " ]") return temp_json - def __extract_json(self, s): + def _extract_json(self, s): try: # Get the dual-mode analysis first and get the maximum result - temp_json_simple = self.__json_interception(s) - temp_json_array = self.__json_interception(s, True) + temp_json_simple = self._json_interception(s) + temp_json_array = self._json_interception(s, True) if len(temp_json_simple) > len(temp_json_array): temp_json = temp_json_simple else: temp_json = temp_json_array if not temp_json: - temp_json = self.__json_interception(s) + temp_json = self._json_interception(s) - temp_json = self.__illegal_json_ends(temp_json) + temp_json = self._illegal_json_ends(temp_json) return temp_json except Exception as e: raise ValueError("Failed to find a valid json in LLM response!" + temp_json) - def __json_interception(self, s, is_json_array: bool = False): + def _json_interception(self, s, is_json_array: bool = False): try: if is_json_array: i = s.find("[") @@ -176,7 +176,7 @@ def parse_prompt_response(self, model_out_text) -> T: cleaned_output = cleaned_output.strip() if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"): logger.info("illegal json processing:\n" + cleaned_output) - cleaned_output = self.__extract_json(cleaned_output) + cleaned_output = self._extract_json(cleaned_output) if not cleaned_output or len(cleaned_output) <= 0: return model_out_text @@ -188,7 +188,7 @@ def parse_prompt_response(self, model_out_text) -> T: .replace("\\", " ") .replace("\_", "_") ) - cleaned_output = self.__illegal_json_ends(cleaned_output) + cleaned_output = self._illegal_json_ends(cleaned_output) return cleaned_output def parse_view_response( @@ -208,20 +208,6 @@ def get_format_instructions(self) -> str: """Instructions on how the LLM output should be formatted.""" raise NotImplementedError - # @property - # def _type(self) -> str: - # """Return the type key.""" - # raise NotImplementedError( - # f"_type property is not implemented in class {self.__class__.__name__}." - # " This is required for serialization." - # ) - - def dict(self, **kwargs: Any) -> Dict: - """Return dictionary representation of output parser.""" - output_parser_dict = super().dict() - output_parser_dict["_type"] = self._type - return output_parser_dict - async def map(self, input_value: ModelOutput) -> Any: """Parse the output of an LLM call. diff --git a/dbgpt/core/interface/serialization.py b/dbgpt/core/interface/serialization.py index fd26e3883..e26ec5735 100644 --- a/dbgpt/core/interface/serialization.py +++ b/dbgpt/core/interface/serialization.py @@ -1,19 +1,34 @@ +from __future__ import annotations from abc import ABC, abstractmethod from typing import Type, Dict class Serializable(ABC): + serializer: "Serializer" = None + @abstractmethod + def to_dict(self) -> Dict: + """Convert the object's state to a dictionary.""" + def serialize(self) -> bytes: """Convert the object into bytes for storage or transmission. Returns: bytes: The byte array after serialization """ + if self.serializer is None: + raise ValueError( + "Serializer is not set. Please set the serializer before serialization." + ) + return self.serializer.serialize(self) - @abstractmethod - def to_dict(self) -> Dict: - """Convert the object's state to a dictionary.""" + def set_serializer(self, serializer: "Serializer") -> None: + """Set the serializer for current serializable object. + + Args: + serializer (Serializer): The serializer to set + """ + self.serializer = serializer class Serializer(ABC): diff --git a/dbgpt/core/interface/storage.py b/dbgpt/core/interface/storage.py new file mode 100644 index 000000000..f8e722f59 --- /dev/null +++ b/dbgpt/core/interface/storage.py @@ -0,0 +1,409 @@ +from typing import Generic, TypeVar, Type, Optional, Dict, Any, List +from abc import ABC, abstractmethod + +from dbgpt.core.interface.serialization import Serializable, Serializer +from dbgpt.util.serialization.json_serialization import JsonSerializer +from dbgpt.util.annotations import PublicAPI +from dbgpt.util.pagination_utils import PaginationResult + + +@PublicAPI(stability="beta") +class ResourceIdentifier(Serializable, ABC): + """The resource identifier interface for resource identifiers.""" + + @property + @abstractmethod + def str_identifier(self) -> str: + """Get the string identifier of the resource. + + The string identifier is used to uniquely identify the resource. + + Returns: + str: The string identifier of the resource + """ + + def __hash__(self) -> int: + """Return the hash value of the key.""" + return hash(self.str_identifier) + + def __eq__(self, other: Any) -> bool: + """Check equality with another key.""" + if not isinstance(other, ResourceIdentifier): + return False + return self.str_identifier == other.str_identifier + + +@PublicAPI(stability="beta") +class StorageItem(Serializable, ABC): + """The storage item interface for storage items.""" + + @property + @abstractmethod + def identifier(self) -> ResourceIdentifier: + """Get the resource identifier of the storage item. + + Returns: + ResourceIdentifier: The resource identifier of the storage item + """ + + @abstractmethod + def merge(self, other: "StorageItem") -> None: + """Merge the other storage item into the current storage item. + + Args: + other (StorageItem): The other storage item + """ + + +T = TypeVar("T", bound=StorageItem) +TDataRepresentation = TypeVar("TDataRepresentation") + + +class StorageItemAdapter(Generic[T, TDataRepresentation]): + """The storage item adapter for converting storage items to and from the storage format. + + Sometimes, the storage item is not the same as the storage format, + so we need to convert the storage item to the storage format and vice versa. + + In database storage, the storage format is database model, but the StorageItem is the user-defined object. + """ + + @abstractmethod + def to_storage_format(self, item: T) -> TDataRepresentation: + """Convert the storage item to the storage format. + + Args: + item (T): The storage item + + Returns: + TDataRepresentation: The data in the storage format + """ + + @abstractmethod + def from_storage_format(self, data: TDataRepresentation) -> T: + """Convert the storage format to the storage item. + + Args: + data (TDataRepresentation): The data in the storage format + + Returns: + T: The storage item + """ + + @abstractmethod + def get_query_for_identifier( + self, + storage_format: Type[TDataRepresentation], + resource_id: ResourceIdentifier, + **kwargs, + ) -> Any: + """Get the query for the resource identifier. + + Args: + storage_format (Type[TDataRepresentation]): The storage format + resource_id (ResourceIdentifier): The resource identifier + kwargs: The additional arguments + + Returns: + Any: The query for the resource identifier + """ + + +class DefaultStorageItemAdapter(StorageItemAdapter[T, T]): + """The default storage item adapter for converting storage items to and from the storage format. + + The storage item is the same as the storage format, so no conversion is required. + """ + + def to_storage_format(self, item: T) -> T: + return item + + def from_storage_format(self, data: T) -> T: + return data + + def get_query_for_identifier( + self, storage_format: Type[T], resource_id: ResourceIdentifier, **kwargs + ) -> bool: + return True + + +@PublicAPI(stability="beta") +class StorageError(Exception): + """The base exception class for storage errors.""" + + def __init__(self, message: str): + super().__init__(message) + + +@PublicAPI(stability="beta") +class QuerySpec: + """The query specification for querying data from the storage. + + Attributes: + conditions (Dict[str, Any]): The conditions for querying data + limit (int): The maximum number of data to return + offset (int): The offset of the data to return + """ + + def __init__( + self, conditions: Dict[str, Any], limit: int = None, offset: int = 0 + ) -> None: + self.conditions = conditions + self.limit = limit + self.offset = offset + + +@PublicAPI(stability="beta") +class StorageInterface(Generic[T, TDataRepresentation], ABC): + """The storage interface for storing and loading data.""" + + def __init__( + self, + serializer: Optional[Serializer] = None, + adapter: Optional[StorageItemAdapter[T, TDataRepresentation]] = None, + ): + self._serializer = serializer or JsonSerializer() + self._storage_item_adapter = adapter or DefaultStorageItemAdapter() + + @property + def serializer(self) -> Serializer: + """Get the serializer of the storage. + + Returns: + Serializer: The serializer of the storage + """ + return self._serializer + + @property + def adapter(self) -> StorageItemAdapter[T, TDataRepresentation]: + """Get the adapter of the storage. + + Returns: + StorageItemAdapter[T, TDataRepresentation]: The adapter of the storage + """ + return self._storage_item_adapter + + @abstractmethod + def save(self, data: T) -> None: + """Save the data to the storage. + + Args: + data (T): The data to save + + Raises: + StorageError: If the data already exists in the storage or data is None + """ + + @abstractmethod + def update(self, data: T) -> None: + """Update the data to the storage. + + Args: + data (T): The data to save + + Raises: + StorageError: If data is None + """ + + @abstractmethod + def save_or_update(self, data: T) -> None: + """Save or update the data to the storage. + + Args: + data (T): The data to save + + Raises: + StorageError: If data is None + """ + + def save_list(self, data: List[T]) -> None: + """Save the data to the storage. + + Args: + data (T): The data to save + + Raises: + StorageError: If the data already exists in the storage or data is None + """ + for d in data: + self.save(d) + + def save_or_update_list(self, data: List[T]) -> None: + """Save or update the data to the storage. + + Args: + data (T): The data to save + """ + for d in data: + self.save_or_update(d) + + @abstractmethod + def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]: + """Load the data from the storage. + + None will be returned if the data does not exist in the storage. + + Load data with resource_id will be faster than query data with conditions, + so we suggest to use load if possible. + + Args: + resource_id (ResourceIdentifier): The resource identifier of the data + cls (Type[T]): The type of the data + + Returns: + Optional[T]: The loaded data + """ + + def load_list(self, resource_id: List[ResourceIdentifier], cls: Type[T]) -> List[T]: + """Load the data from the storage. + + None will be returned if the data does not exist in the storage. + + Load data with resource_id will be faster than query data with conditions, + so we suggest to use load if possible. + + Args: + resource_id (ResourceIdentifier): The resource identifier of the data + cls (Type[T]): The type of the data + + Returns: + Optional[T]: The loaded data + """ + result = [] + for r in resource_id: + item = self.load(r, cls) + if item is not None: + result.append(item) + return result + + @abstractmethod + def delete(self, resource_id: ResourceIdentifier) -> None: + """Delete the data from the storage. + + Args: + resource_id (ResourceIdentifier): The resource identifier of the data + """ + + @abstractmethod + def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]: + """Query data from the storage. + + Query data with resource_id will be faster than query data with conditions, so please use load if possible. + + Args: + spec (QuerySpec): The query specification + cls (Type[T]): The type of the data + + Returns: + List[T]: The queried data + """ + + @abstractmethod + def count(self, spec: QuerySpec, cls: Type[T]) -> int: + """Count the number of data from the storage. + + Args: + spec (QuerySpec): The query specification + cls (Type[T]): The type of the data + + Returns: + int: The number of data + """ + + def paginate_query( + self, page: int, page_size: int, cls: Type[T], spec: Optional[QuerySpec] = None + ) -> PaginationResult[T]: + """Paginate the query result. + + Args: + page (int): The page number + page_size (int): The number of items per page + cls (Type[T]): The type of the data + spec (Optional[QuerySpec], optional): The query specification. Defaults to None. + + Returns: + PaginationResult[T]: The pagination result + """ + if spec is None: + spec = QuerySpec(conditions={}) + spec.limit = page_size + spec.offset = (page - 1) * page_size + items = self.query(spec, cls) + total = self.count(spec, cls) + return PaginationResult( + items=items, + total_count=total, + total_pages=(total + page_size - 1) // page_size, + page=page, + page_size=page_size, + ) + + +@PublicAPI(stability="alpha") +class InMemoryStorage(StorageInterface[T, T]): + """The in-memory storage for storing and loading data.""" + + def __init__( + self, + serializer: Optional[Serializer] = None, + ): + super().__init__(serializer) + self._data = {} # Key: ResourceIdentifier, Value: Serialized data + + def save(self, data: T) -> None: + if not data: + raise StorageError("Data cannot be None") + if not data.serializer: + data.set_serializer(self.serializer) + + if data.identifier.str_identifier in self._data: + raise StorageError( + f"Data with identifier {data.identifier.str_identifier} already exists" + ) + self._data[data.identifier.str_identifier] = data.serialize() + + def update(self, data: T) -> None: + if not data: + raise StorageError("Data cannot be None") + if not data.serializer: + data.set_serializer(self.serializer) + self._data[data.identifier.str_identifier] = data.serialize() + + def save_or_update(self, data: T) -> None: + self.update(data) + + def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]: + serialized_data = self._data.get(resource_id.str_identifier) + if serialized_data is None: + return None + return self.serializer.deserialize(serialized_data, cls) + + def delete(self, resource_id: ResourceIdentifier) -> None: + if resource_id.str_identifier in self._data: + del self._data[resource_id.str_identifier] + + def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]: + result = [] + for serialized_data in self._data.values(): + data = self._serializer.deserialize(serialized_data, cls) + if all( + getattr(data, key) == value for key, value in spec.conditions.items() + ): + result.append(data) + + # Apply limit and offset + if spec.limit is not None: + result = result[spec.offset : spec.offset + spec.limit] + else: + result = result[spec.offset :] + return result + + def count(self, spec: QuerySpec, cls: Type[T]) -> int: + count = 0 + for serialized_data in self._data.values(): + data = self._serializer.deserialize(serialized_data, cls) + if all( + getattr(data, key) == value for key, value in spec.conditions.items() + ): + count += 1 + return count diff --git a/dbgpt/core/interface/tests/__init__.py b/dbgpt/core/interface/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/core/interface/tests/conftest.py b/dbgpt/core/interface/tests/conftest.py new file mode 100644 index 000000000..c7afff909 --- /dev/null +++ b/dbgpt/core/interface/tests/conftest.py @@ -0,0 +1,14 @@ +import pytest + +from dbgpt.core.interface.storage import InMemoryStorage +from dbgpt.util.serialization.json_serialization import JsonSerializer + + +@pytest.fixture +def serializer(): + return JsonSerializer() + + +@pytest.fixture +def in_memory_storage(serializer): + return InMemoryStorage(serializer) diff --git a/dbgpt/core/interface/tests/test_message.py b/dbgpt/core/interface/tests/test_message.py new file mode 100644 index 000000000..425f268af --- /dev/null +++ b/dbgpt/core/interface/tests/test_message.py @@ -0,0 +1,307 @@ +import pytest + +from dbgpt.core.interface.tests.conftest import in_memory_storage +from dbgpt.core.interface.message import * + + +@pytest.fixture +def basic_conversation(): + return OnceConversation(chat_mode="chat_normal", user_name="user1", sys_code="sys1") + + +@pytest.fixture +def human_message(): + return HumanMessage(content="Hello") + + +@pytest.fixture +def ai_message(): + return AIMessage(content="Hi there") + + +@pytest.fixture +def system_message(): + return SystemMessage(content="System update") + + +@pytest.fixture +def view_message(): + return ViewMessage(content="View this") + + +@pytest.fixture +def conversation_identifier(): + return ConversationIdentifier("conv1") + + +@pytest.fixture +def message_identifier(): + return MessageIdentifier("conv1", 1) + + +@pytest.fixture +def message_storage_item(): + message = HumanMessage(content="Hello", index=1) + message_detail = message.to_dict() + return MessageStorageItem("conv1", 1, message_detail) + + +@pytest.fixture +def storage_conversation(): + return StorageConversation("conv1", chat_mode="chat_normal", user_name="user1") + + +@pytest.fixture +def conversation_with_messages(): + conv = OnceConversation(chat_mode="chat_normal", user_name="user1") + conv.start_new_round() + conv.add_user_message("Hello") + conv.add_ai_message("Hi") + conv.end_current_round() + + conv.start_new_round() + conv.add_user_message("How are you?") + conv.add_ai_message("I'm good, thanks") + conv.end_current_round() + + return conv + + +def test_init(basic_conversation): + assert basic_conversation.chat_mode == "chat_normal" + assert basic_conversation.user_name == "user1" + assert basic_conversation.sys_code == "sys1" + assert basic_conversation.messages == [] + assert basic_conversation.start_date == "" + assert basic_conversation.chat_order == 0 + assert basic_conversation.model_name == "" + assert basic_conversation.param_type == "" + assert basic_conversation.param_value == "" + assert basic_conversation.cost == 0 + assert basic_conversation.tokens == 0 + assert basic_conversation._message_index == 0 + + +def test_add_user_message(basic_conversation, human_message): + basic_conversation.add_user_message(human_message.content) + assert len(basic_conversation.messages) == 1 + assert isinstance(basic_conversation.messages[0], HumanMessage) + + +def test_add_ai_message(basic_conversation, ai_message): + basic_conversation.add_ai_message(ai_message.content) + assert len(basic_conversation.messages) == 1 + assert isinstance(basic_conversation.messages[0], AIMessage) + + +def test_add_system_message(basic_conversation, system_message): + basic_conversation.add_system_message(system_message.content) + assert len(basic_conversation.messages) == 1 + assert isinstance(basic_conversation.messages[0], SystemMessage) + + +def test_add_view_message(basic_conversation, view_message): + basic_conversation.add_view_message(view_message.content) + assert len(basic_conversation.messages) == 1 + assert isinstance(basic_conversation.messages[0], ViewMessage) + + +def test_set_start_time(basic_conversation): + now = datetime.now() + basic_conversation.set_start_time(now) + assert basic_conversation.start_date == now.strftime("%Y-%m-%d %H:%M:%S") + + +def test_clear_messages(basic_conversation, human_message): + basic_conversation.add_user_message(human_message.content) + basic_conversation.clear() + assert len(basic_conversation.messages) == 0 + + +def test_get_latest_user_message(basic_conversation, human_message): + basic_conversation.add_user_message(human_message.content) + latest_message = basic_conversation.get_latest_user_message() + assert latest_message == human_message + + +def test_get_system_messages(basic_conversation, system_message): + basic_conversation.add_system_message(system_message.content) + system_messages = basic_conversation.get_system_messages() + assert len(system_messages) == 1 + assert system_messages[0] == system_message + + +def test_from_conversation(basic_conversation): + new_conversation = OnceConversation(chat_mode="chat_advanced", user_name="user2") + basic_conversation.from_conversation(new_conversation) + assert basic_conversation.chat_mode == "chat_advanced" + assert basic_conversation.user_name == "user2" + + +def test_get_messages_by_round(conversation_with_messages): + # Test first round + round1_messages = conversation_with_messages.get_messages_by_round(1) + assert len(round1_messages) == 2 + assert round1_messages[0].content == "Hello" + assert round1_messages[1].content == "Hi" + + # Test not existing round + no_messages = conversation_with_messages.get_messages_by_round(3) + assert len(no_messages) == 0 + + +def test_get_latest_round(conversation_with_messages): + latest_round_messages = conversation_with_messages.get_latest_round() + assert len(latest_round_messages) == 2 + assert latest_round_messages[0].content == "How are you?" + assert latest_round_messages[1].content == "I'm good, thanks" + + +def test_get_messages_with_round(conversation_with_messages): + # Test last round + last_round_messages = conversation_with_messages.get_messages_with_round(1) + assert len(last_round_messages) == 2 + assert last_round_messages[0].content == "How are you?" + assert last_round_messages[1].content == "I'm good, thanks" + + # Test last two rounds + last_two_rounds_messages = conversation_with_messages.get_messages_with_round(2) + assert len(last_two_rounds_messages) == 4 + assert last_two_rounds_messages[0].content == "Hello" + assert last_two_rounds_messages[1].content == "Hi" + + +def test_get_model_messages(conversation_with_messages): + model_messages = conversation_with_messages.get_model_messages() + assert len(model_messages) == 4 + assert all(isinstance(msg, ModelMessage) for msg in model_messages) + assert model_messages[0].content == "Hello" + assert model_messages[1].content == "Hi" + assert model_messages[2].content == "How are you?" + assert model_messages[3].content == "I'm good, thanks" + + +def test_conversation_identifier(conversation_identifier): + assert conversation_identifier.conv_uid == "conv1" + assert conversation_identifier.identifier_type == "conversation" + assert conversation_identifier.str_identifier == "conversation:conv1" + assert conversation_identifier.to_dict() == { + "conv_uid": "conv1", + "identifier_type": "conversation", + } + + +def test_message_identifier(message_identifier): + assert message_identifier.conv_uid == "conv1" + assert message_identifier.index == 1 + assert message_identifier.identifier_type == "message" + assert message_identifier.str_identifier == "message___conv1___1" + assert message_identifier.to_dict() == { + "conv_uid": "conv1", + "index": 1, + "identifier_type": "message", + } + + +def test_message_storage_item(message_storage_item): + assert message_storage_item.conv_uid == "conv1" + assert message_storage_item.index == 1 + assert message_storage_item.message_detail == { + "type": "human", + "data": { + "content": "Hello", + "index": 1, + "round_index": 0, + "additional_kwargs": {}, + "example": False, + }, + "index": 1, + "round_index": 0, + } + + assert isinstance(message_storage_item.identifier, MessageIdentifier) + assert message_storage_item.to_dict() == { + "conv_uid": "conv1", + "index": 1, + "message_detail": { + "type": "human", + "index": 1, + "data": { + "content": "Hello", + "index": 1, + "round_index": 0, + "additional_kwargs": {}, + "example": False, + }, + "round_index": 0, + }, + } + + assert isinstance(message_storage_item.to_message(), BaseMessage) + + +def test_storage_conversation_init(storage_conversation): + assert storage_conversation.conv_uid == "conv1" + assert storage_conversation.chat_mode == "chat_normal" + assert storage_conversation.user_name == "user1" + + +def test_storage_conversation_add_user_message(storage_conversation): + storage_conversation.add_user_message("Hi") + assert len(storage_conversation.messages) == 1 + assert isinstance(storage_conversation.messages[0], HumanMessage) + + +def test_storage_conversation_add_ai_message(storage_conversation): + storage_conversation.add_ai_message("Hello") + assert len(storage_conversation.messages) == 1 + assert isinstance(storage_conversation.messages[0], AIMessage) + + +def test_save_to_storage(storage_conversation, in_memory_storage): + # Set storage + storage_conversation.conv_storage = in_memory_storage + storage_conversation.message_storage = in_memory_storage + + # Add messages + storage_conversation.add_user_message("User message") + storage_conversation.add_ai_message("AI response") + + # Save to storage + storage_conversation.save_to_storage() + + # Create a new StorageConversation instance to load the data + saved_conversation = StorageConversation( + storage_conversation.conv_uid, + conv_storage=in_memory_storage, + message_storage=in_memory_storage, + ) + + assert saved_conversation.conv_uid == storage_conversation.conv_uid + assert len(saved_conversation.messages) == 2 + assert isinstance(saved_conversation.messages[0], HumanMessage) + assert isinstance(saved_conversation.messages[1], AIMessage) + + +def test_load_from_storage(storage_conversation, in_memory_storage): + # Set storage + storage_conversation.conv_storage = in_memory_storage + storage_conversation.message_storage = in_memory_storage + + # Add messages and save to storage + storage_conversation.add_user_message("User message") + storage_conversation.add_ai_message("AI response") + storage_conversation.save_to_storage() + + # Create a new StorageConversation instance to load the data + new_conversation = StorageConversation( + "conv1", conv_storage=in_memory_storage, message_storage=in_memory_storage + ) + + # Check if the data is loaded correctly + assert new_conversation.conv_uid == storage_conversation.conv_uid + assert len(new_conversation.messages) == 2 + assert new_conversation.messages[0].content == "User message" + assert new_conversation.messages[1].content == "AI response" + assert isinstance(new_conversation.messages[0], HumanMessage) + assert isinstance(new_conversation.messages[1], AIMessage) diff --git a/dbgpt/core/interface/tests/test_storage.py b/dbgpt/core/interface/tests/test_storage.py new file mode 100644 index 000000000..74864e02f --- /dev/null +++ b/dbgpt/core/interface/tests/test_storage.py @@ -0,0 +1,129 @@ +import pytest +from typing import Dict, Type, Union +from dbgpt.core.interface.storage import ( + ResourceIdentifier, + StorageError, + QuerySpec, + InMemoryStorage, + StorageItem, +) +from dbgpt.util.serialization.json_serialization import JsonSerializer + + +class MockResourceIdentifier(ResourceIdentifier): + def __init__(self, identifier: str): + self._identifier = identifier + + @property + def str_identifier(self) -> str: + return self._identifier + + def to_dict(self) -> Dict: + return {"identifier": self._identifier} + + +class MockStorageItem(StorageItem): + def merge(self, other: "StorageItem") -> None: + if not isinstance(other, MockStorageItem): + raise ValueError("other must be a MockStorageItem") + self.data = other.data + + def __init__(self, identifier: Union[str, MockResourceIdentifier], data): + self._identifier_str = ( + identifier if isinstance(identifier, str) else identifier.str_identifier + ) + self.data = data + + def to_dict(self) -> Dict: + return {"identifier": self._identifier_str, "data": self.data} + + @property + def identifier(self) -> ResourceIdentifier: + return MockResourceIdentifier(self._identifier_str) + + +@pytest.fixture +def serializer(): + return JsonSerializer() + + +@pytest.fixture +def in_memory_storage(serializer): + return InMemoryStorage(serializer) + + +def test_save_and_load(in_memory_storage): + resource_id = MockResourceIdentifier("1") + item = MockStorageItem(resource_id, "test_data") + + in_memory_storage.save(item) + + loaded_item = in_memory_storage.load(resource_id, MockStorageItem) + assert loaded_item.data == "test_data" + + +def test_duplicate_save(in_memory_storage): + item = MockStorageItem("1", "test_data") + + in_memory_storage.save(item) + + # Should raise StorageError when saving the same data + with pytest.raises(StorageError): + in_memory_storage.save(item) + + +def test_delete(in_memory_storage): + resource_id = MockResourceIdentifier("1") + item = MockStorageItem(resource_id, "test_data") + + in_memory_storage.save(item) + in_memory_storage.delete(resource_id) + # Storage should not contain the data after deletion + assert in_memory_storage.load(resource_id, MockStorageItem) is None + + +def test_query(in_memory_storage): + resource_id1 = MockResourceIdentifier("1") + item1 = MockStorageItem(resource_id1, "test_data1") + + resource_id2 = MockResourceIdentifier("2") + item2 = MockStorageItem(resource_id2, "test_data2") + + in_memory_storage.save(item1) + in_memory_storage.save(item2) + + query_spec = QuerySpec(conditions={"data": "test_data1"}) + results = in_memory_storage.query(query_spec, MockStorageItem) + assert len(results) == 1 + assert results[0].data == "test_data1" + + +def test_count(in_memory_storage): + item1 = MockStorageItem("1", "test_data1") + + item2 = MockStorageItem("2", "test_data2") + + in_memory_storage.save(item1) + in_memory_storage.save(item2) + + query_spec = QuerySpec(conditions={}) + count = in_memory_storage.count(query_spec, MockStorageItem) + assert count == 2 + + +def test_paginate_query(in_memory_storage): + for i in range(10): + resource_id = MockResourceIdentifier(str(i)) + item = MockStorageItem(resource_id, f"test_data{i}") + in_memory_storage.save(item) + + page_size = 3 + query_spec = QuerySpec(conditions={}) + page_result = in_memory_storage.paginate_query( + 2, page_size, MockStorageItem, query_spec + ) + + assert len(page_result.items) == page_size + assert page_result.total_count == 10 + assert page_result.total_pages == 4 + assert page_result.page == 2 diff --git a/dbgpt/datasource/base.py b/dbgpt/datasource/base.py index 58f5f131c..2d3f7c223 100644 --- a/dbgpt/datasource/base.py +++ b/dbgpt/datasource/base.py @@ -91,6 +91,10 @@ def get_fields(self, table_name): """Get column fields about specified table.""" pass + def get_simple_fields(self, table_name): + """Get column fields about specified table.""" + return self.get_fields(table_name) + def get_show_create_table(self, table_name): """Get the creation table sql about specified table.""" pass diff --git a/dbgpt/datasource/manages/connect_config_db.py b/dbgpt/datasource/manages/connect_config_db.py index 225bfdca9..436f8f21f 100644 --- a/dbgpt/datasource/manages/connect_config_db.py +++ b/dbgpt/datasource/manages/connect_config_db.py @@ -1,16 +1,10 @@ from sqlalchemy import Column, Integer, String, Index, Text, text from sqlalchemy import UniqueConstraint -from dbgpt.storage.metadata import BaseDao -from dbgpt.storage.metadata.meta_data import ( - Base, - engine, - session, - META_DATA_DATABASE, -) +from dbgpt.storage.metadata import BaseDao, Model -class ConnectConfigEntity(Base): +class ConnectConfigEntity(Model): """db connect config entity""" __tablename__ = "connect_config" @@ -38,17 +32,9 @@ class ConnectConfigEntity(Base): class ConnectConfigDao(BaseDao[ConnectConfigEntity]): """db connect config dao""" - def __init__(self): - super().__init__( - database=META_DATA_DATABASE, - orm_base=Base, - db_engine=engine, - session=session, - ) - def update(self, entity: ConnectConfigEntity): """update db connect info""" - session = self.get_session() + session = self.get_raw_session() try: updated = session.merge(entity) session.commit() @@ -58,7 +44,7 @@ def update(self, entity: ConnectConfigEntity): def delete(self, db_name: int): """ "delete db connect info""" - session = self.get_session() + session = self.get_raw_session() if db_name is None: raise Exception("db_name is None") @@ -70,7 +56,7 @@ def delete(self, db_name: int): def get_by_names(self, db_name: str) -> ConnectConfigEntity: """get db connect info by name""" - session = self.get_session() + session = self.get_raw_session() db_connect = session.query(ConnectConfigEntity) db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name) result = db_connect.first() @@ -99,7 +85,7 @@ def add_url_db( comment: comment """ try: - session = self.get_session() + session = self.get_raw_session() from sqlalchemy import text @@ -144,7 +130,7 @@ def update_db_info( old_db_conf = self.get_db_config(db_name) if old_db_conf: try: - session = self.get_session() + session = self.get_raw_session() if not db_path: update_statement = text( f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'" @@ -164,7 +150,7 @@ def update_db_info( def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""): """add file db connect info""" try: - session = self.get_session() + session = self.get_raw_session() insert_statement = text( """ INSERT INTO connect_config( @@ -194,7 +180,7 @@ def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""): def get_db_config(self, db_name): """get db config by name""" - session = self.get_session() + session = self.get_raw_session() if db_name: select_statement = text( """ @@ -221,7 +207,7 @@ def get_db_config(self, db_name): def get_db_list(self): """get db list""" - session = self.get_session() + session = self.get_raw_session() result = session.execute(text("SELECT * FROM connect_config")) fields = [field[0] for field in result.cursor.description] @@ -235,7 +221,7 @@ def get_db_list(self): def delete_db(self, db_name): """delete db connect info""" - session = self.get_session() + session = self.get_raw_session() delete_statement = text("""DELETE FROM connect_config where db_name=:db_name""") params = {"db_name": db_name} session.execute(delete_statement, params) diff --git a/dbgpt/datasource/rdbms/base.py b/dbgpt/datasource/rdbms/base.py index 32d7794a2..0570b3771 100644 --- a/dbgpt/datasource/rdbms/base.py +++ b/dbgpt/datasource/rdbms/base.py @@ -270,7 +270,12 @@ def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> st """Format the error message""" return f"Error: {e}" - def __write(self, write_sql): + def _write(self, write_sql: str): + """Run a SQL write command and return the results as a list of tuples. + + Args: + write_sql (str): SQL write command to run + """ print(f"Write[{write_sql}]") db_cache = self._engine.url.database result = self.session.execute(text(write_sql)) @@ -280,16 +285,12 @@ def __write(self, write_sql): print(f"SQL[{write_sql}], result:{result.rowcount}") return result.rowcount - def __query(self, query, fetch: str = "all"): - """ - only for query - Args: - session: - query: - fetch: - - Returns: + def _query(self, query: str, fetch: str = "all"): + """Run a SQL query and return the results as a list of tuples. + Args: + query (str): SQL query to run + fetch (str): fetch type """ print(f"Query[{query}]") if not query: @@ -308,6 +309,10 @@ def __query(self, query, fetch: str = "all"): result.insert(0, field_names) return result + def query_table_schema(self, table_name): + sql = f"select * from {table_name} limit 1" + return self._query(sql) + def query_ex(self, query, fetch: str = "all"): """ only for query @@ -325,7 +330,7 @@ def query_ex(self, query, fetch: str = "all"): if fetch == "all": result = cursor.fetchall() elif fetch == "one": - result = cursor.fetchone()[0] # type: ignore + result = cursor.fetchone() # type: ignore else: raise ValueError("Fetch parameter must be either 'one' or 'all'") field_names = list(i[0:] for i in cursor.keys()) @@ -342,12 +347,12 @@ def run(self, command: str, fetch: str = "all") -> List: parsed, ttype, sql_type, table_name = self.__sql_parse(command) if ttype == sqlparse.tokens.DML: if sql_type == "SELECT": - return self.__query(command, fetch) + return self._query(command, fetch) else: - self.__write(command) + self._write(command) select_sql = self.convert_sql_write_to_select(command) print(f"write result query:{select_sql}") - return self.__query(select_sql) + return self._query(select_sql) else: print(f"DDL execution determines whether to enable through configuration ") @@ -360,10 +365,11 @@ def run(self, command: str, fetch: str = "all") -> List: result.insert(0, field_names) print("DDL Result:" + str(result)) if not result: - return self.__query(f"SHOW COLUMNS FROM {table_name}") + # return self._query(f"SHOW COLUMNS FROM {table_name}") + return self.get_simple_fields(table_name) return result else: - return self.__query(f"SHOW COLUMNS FROM {table_name}") + return self.get_simple_fields(table_name) def run_to_df(self, command: str, fetch: str = "all"): result_lst = self.run(command, fetch) @@ -451,13 +457,23 @@ def __sql_parse(self, sql): sql = sql.strip() parsed = sqlparse.parse(sql)[0] sql_type = parsed.get_type() - table_name = parsed.get_name() + if sql_type == "CREATE": + table_name = self._extract_table_name_from_ddl(parsed) + else: + table_name = parsed.get_name() first_token = parsed.token_first(skip_ws=True, skip_cm=False) ttype = first_token.ttype print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}") return parsed, ttype, sql_type, table_name + def _extract_table_name_from_ddl(self, parsed): + """Extract table name from CREATE TABLE statement.""" "" + for token in parsed.tokens: + if token.ttype is None and isinstance(token, sqlparse.sql.Identifier): + return token.get_real_name() + return None + def get_indexes(self, table_name): """Get table indexes about specified table.""" session = self._db_sessions() @@ -485,6 +501,10 @@ def get_fields(self, table_name): fields = cursor.fetchall() return [(field[0], field[1], field[2], field[3], field[4]) for field in fields] + def get_simple_fields(self, table_name): + """Get column fields about specified table.""" + return self._query(f"SHOW COLUMNS FROM {table_name}") + def get_charset(self): """Get character_set.""" session = self._db_sessions() diff --git a/dbgpt/datasource/rdbms/conn_sqlite.py b/dbgpt/datasource/rdbms/conn_sqlite.py index cff76df94..b535cd80f 100644 --- a/dbgpt/datasource/rdbms/conn_sqlite.py +++ b/dbgpt/datasource/rdbms/conn_sqlite.py @@ -56,6 +56,10 @@ def get_fields(self, table_name): print(fields) return [(field[1], field[2], field[3], field[4], field[5]) for field in fields] + def get_simple_fields(self, table_name): + """Get column fields about specified table.""" + return self.get_fields(table_name) + def get_users(self): return [] @@ -88,8 +92,9 @@ def _sync_tables_from_db(self) -> Iterable[str]: self._metadata.reflect(bind=self._engine) return self._all_tables - def _write(self, session, write_sql): + def _write(self, write_sql): print(f"Write[{write_sql}]") + session = self.session result = session.execute(text(write_sql)) session.commit() # TODO Subsequent optimization of dynamically specified database submission loss target problem diff --git a/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py b/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py index e779efe7a..f8e43d1e1 100644 --- a/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py +++ b/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py @@ -25,41 +25,41 @@ def test_get_table_info(db): def test_get_table_info_with_table(db): - db.run(db.session, "CREATE TABLE test (id INTEGER);") + db.run("CREATE TABLE test (id INTEGER);") print(db._sync_tables_from_db()) table_info = db.get_table_info() assert "CREATE TABLE test" in table_info def test_run_sql(db): - result = db.run(db.session, "CREATE TABLE test (id INTEGER);") - assert result[0] == ("cid", "name", "type", "notnull", "dflt_value", "pk") + result = db.run("CREATE TABLE test(id INTEGER);") + assert result[0] == ("id", "INTEGER", 0, None, 0) def test_run_no_throw(db): - assert db.run_no_throw(db.session, "this is a error sql").startswith("Error:") + assert db.run_no_throw("this is a error sql").startswith("Error:") def test_get_indexes(db): - db.run(db.session, "CREATE TABLE test (name TEXT);") - db.run(db.session, "CREATE INDEX idx_name ON test(name);") + db.run("CREATE TABLE test (name TEXT);") + db.run("CREATE INDEX idx_name ON test(name);") assert db.get_indexes("test") == [("idx_name", "c")] def test_get_indexes_empty(db): - db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);") + db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);") assert db.get_indexes("test") == [] def test_get_show_create_table(db): - db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);") + db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);") assert ( db.get_show_create_table("test") == "CREATE TABLE test (id INTEGER PRIMARY KEY)" ) def test_get_fields(db): - db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);") + db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);") assert db.get_fields("test") == [("id", "INTEGER", 0, None, 1)] @@ -72,26 +72,26 @@ def test_get_collation(db): def test_table_simple_info(db): - db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);") + db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);") assert db.table_simple_info() == ["test(id);"] def test_get_table_info_no_throw(db): - db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);") + db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);") assert db.get_table_info_no_throw("xxxx_table").startswith("Error:") def test_query_ex(db): - db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);") - db.run(db.session, "insert into test(id) values (1)") - db.run(db.session, "insert into test(id) values (2)") - field_names, result = db.query_ex(db.session, "select * from test") + db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);") + db.run("insert into test(id) values (1)") + db.run("insert into test(id) values (2)") + field_names, result = db.query_ex("select * from test") assert field_names == ["id"] assert result == [(1,), (2,)] - field_names, result = db.query_ex(db.session, "select * from test", fetch="one") + field_names, result = db.query_ex("select * from test", fetch="one") assert field_names == ["id"] - assert result == [(1,)] + assert result == [1] def test_convert_sql_write_to_select(db): @@ -109,7 +109,7 @@ def test_get_users(db): def test_get_table_comments(db): assert db.get_table_comments() == [] - db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);") + db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);") assert db.get_table_comments() == [ ("test", "CREATE TABLE test (id INTEGER PRIMARY KEY)") ] diff --git a/dbgpt/model/cluster/apiserver/tests/test_api.py b/dbgpt/model/cluster/apiserver/tests/test_api.py index d730a8914..681dcfe40 100644 --- a/dbgpt/model/cluster/apiserver/tests/test_api.py +++ b/dbgpt/model/cluster/apiserver/tests/test_api.py @@ -4,6 +4,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from httpx import AsyncClient, HTTPError +import importlib.metadata as metadata from dbgpt.component import SystemApp from dbgpt.util.openai_utils import chat_completion_stream, chat_completion @@ -190,12 +191,26 @@ async def test_chat_completions_with_openai_lib_async_stream( ) stream_stream_resp = "" - async for stream_resp in await openai.ChatCompletion.acreate( - model=model_name, - messages=[{"role": "user", "content": "Hello! What is your name?"}], - stream=True, - ): + if metadata.version("openai") >= "1.0.0": + from openai import OpenAI + + client = OpenAI( + **{"base_url": "http://test/api/v1", "api_key": client_api_key} + ) + res = await client.chat.completions.create( + model=model_name, + messages=[{"role": "user", "content": "Hello! What is your name?"}], + stream=True, + ) + else: + res = openai.ChatCompletion.acreate( + model=model_name, + messages=[{"role": "user", "content": "Hello! What is your name?"}], + stream=True, + ) + async for stream_resp in res: stream_stream_resp = stream_resp.choices[0]["delta"].get("content", "") + assert stream_stream_resp == expected_messages diff --git a/dbgpt/storage/cache/llm_cache.py b/dbgpt/storage/cache/llm_cache.py index d6dd298ee..682276349 100644 --- a/dbgpt/storage/cache/llm_cache.py +++ b/dbgpt/storage/cache/llm_cache.py @@ -75,9 +75,8 @@ def __str__(self) -> str: class LLMCacheKey(CacheKey[LLMCacheKeyData]): - def __init__(self, serializer: Serializer = None, **kwargs) -> None: + def __init__(self, **kwargs) -> None: super().__init__() - self._serializer = serializer self.config = LLMCacheKeyData(**kwargs) def __hash__(self) -> int: @@ -96,30 +95,23 @@ def get_hash_bytes(self) -> bytes: def to_dict(self) -> Dict: return asdict(self.config) - def serialize(self) -> bytes: - return self._serializer.serialize(self) - def get_value(self) -> LLMCacheKeyData: return self.config class LLMCacheValue(CacheValue[LLMCacheValueData]): - def __init__(self, serializer: Serializer = None, **kwargs) -> None: + def __init__(self, **kwargs) -> None: super().__init__() - self._serializer = serializer self.value = LLMCacheValueData.from_dict(**kwargs) def to_dict(self) -> Dict: return self.value.to_dict() - def serialize(self) -> bytes: - return self._serializer.serialize(self) - def get_value(self) -> LLMCacheValueData: return self.value def __str__(self) -> str: - return f"vaue: {str(self.value)}" + return f"value: {str(self.value)}" class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]): @@ -146,7 +138,11 @@ async def exists( return await self.get(key, cache_config) is not None def new_key(self, **kwargs) -> LLMCacheKey: - return LLMCacheKey(serializer=self._cache_manager.serializer, **kwargs) + key = LLMCacheKey(**kwargs) + key.set_serializer(self._cache_manager.serializer) + return key def new_value(self, **kwargs) -> LLMCacheValue: - return LLMCacheValue(serializer=self._cache_manager.serializer, **kwargs) + value = LLMCacheValue(**kwargs) + value.set_serializer(self._cache_manager.serializer) + return value diff --git a/dbgpt/storage/chat_history/chat_history_db.py b/dbgpt/storage/chat_history/chat_history_db.py index fb084d0bd..05002f189 100644 --- a/dbgpt/storage/chat_history/chat_history_db.py +++ b/dbgpt/storage/chat_history/chat_history_db.py @@ -1,17 +1,12 @@ from typing import Optional -from sqlalchemy import Column, Integer, String, Index, Text +from datetime import datetime +from sqlalchemy import Column, Integer, String, Index, Text, DateTime from sqlalchemy import UniqueConstraint -from dbgpt.storage.metadata import BaseDao -from dbgpt.storage.metadata.meta_data import ( - Base, - engine, - session, - META_DATA_DATABASE, -) +from dbgpt.storage.metadata import BaseDao, Model -class ChatHistoryEntity(Base): +class ChatHistoryEntity(Model): __tablename__ = "chat_history" id = Column( Integer, primary_key=True, autoincrement=True, comment="autoincrement id" @@ -22,7 +17,7 @@ class ChatHistoryEntity(Base): } conv_uid = Column( String(255), - unique=False, + unique=True, nullable=False, comment="Conversation record unique id", ) @@ -32,26 +27,48 @@ class ChatHistoryEntity(Base): messages = Column( Text(length=2**31 - 1), nullable=True, comment="Conversation details" ) + message_ids = Column( + Text(length=2**31 - 1), nullable=True, comment="Message ids, split by comma" + ) sys_code = Column(String(128), index=True, nullable=True, comment="System code") - UniqueConstraint("conv_uid", name="uk_conversation") + gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") + gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") + Index("idx_q_user", "user_name") Index("idx_q_mode", "chat_mode") Index("idx_q_conv", "summary") -class ChatHistoryDao(BaseDao[ChatHistoryEntity]): - def __init__(self): - super().__init__( - database=META_DATA_DATABASE, - orm_base=Base, - db_engine=engine, - session=session, - ) +class ChatHistoryMessageEntity(Model): + __tablename__ = "chat_history_message" + id = Column( + Integer, primary_key=True, autoincrement=True, comment="autoincrement id" + ) + __table_args__ = { + "mysql_charset": "utf8mb4", + "mysql_collate": "utf8mb4_unicode_ci", + } + conv_uid = Column( + String(255), + unique=False, + nullable=False, + comment="Conversation record unique id", + ) + index = Column(Integer, nullable=False, comment="Message index") + round_index = Column(Integer, nullable=False, comment="Message round index") + message_detail = Column( + Text(length=2**31 - 1), nullable=True, comment="Message details, json format" + ) + gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") + gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") + UniqueConstraint("conv_uid", "index", name="uk_conversation_message") + +class ChatHistoryDao(BaseDao[ChatHistoryEntity]): def list_last_20( self, user_name: Optional[str] = None, sys_code: Optional[str] = None ): - session = self.get_session() + session = self.get_raw_session() chat_history = session.query(ChatHistoryEntity) if user_name: chat_history = chat_history.filter(ChatHistoryEntity.user_name == user_name) @@ -65,7 +82,7 @@ def list_last_20( return result def update(self, entity: ChatHistoryEntity): - session = self.get_session() + session = self.get_raw_session() try: updated = session.merge(entity) session.commit() @@ -74,7 +91,7 @@ def update(self, entity: ChatHistoryEntity): session.close() def update_message_by_uid(self, message: str, conv_uid: str): - session = self.get_session() + session = self.get_raw_session() try: chat_history = session.query(ChatHistoryEntity) chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid) @@ -85,20 +102,12 @@ def update_message_by_uid(self, message: str, conv_uid: str): session.close() def delete(self, conv_uid: int): - session = self.get_session() if conv_uid is None: raise Exception("conv_uid is None") - - chat_history = session.query(ChatHistoryEntity) - chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid) - chat_history.delete() - session.commit() - session.close() + with self.session() as session: + chat_history = session.query(ChatHistoryEntity) + chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid) + chat_history.delete() def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity: - session = self.get_session() - chat_history = session.query(ChatHistoryEntity) - chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid) - result = chat_history.first() - session.close() - return result + return ChatHistoryEntity.query.filter_by(conv_uid=conv_uid).first() diff --git a/dbgpt/storage/chat_history/storage_adapter.py b/dbgpt/storage/chat_history/storage_adapter.py new file mode 100644 index 000000000..4a7a3b96b --- /dev/null +++ b/dbgpt/storage/chat_history/storage_adapter.py @@ -0,0 +1,116 @@ +from typing import List, Dict, Type +import json +from sqlalchemy.orm import Session +from dbgpt.core.interface.storage import StorageItemAdapter +from dbgpt.core.interface.message import ( + StorageConversation, + ConversationIdentifier, + MessageIdentifier, + MessageStorageItem, + _messages_from_dict, + _conversation_to_dict, + BaseMessage, +) +from .chat_history_db import ChatHistoryEntity, ChatHistoryMessageEntity + + +class DBStorageConversationItemAdapter( + StorageItemAdapter[StorageConversation, ChatHistoryEntity] +): + def to_storage_format(self, item: StorageConversation) -> ChatHistoryEntity: + message_ids = ",".join(item.message_ids) + messages = None + if not item.save_message_independent and item.messages: + messages = _conversation_to_dict(item) + return ChatHistoryEntity( + conv_uid=item.conv_uid, + chat_mode=item.chat_mode, + summary=item.summary or item.get_latest_user_message().content, + user_name=item.user_name, + # We not save messages to chat_history table in new design + messages=messages, + message_ids=message_ids, + sys_code=item.sys_code, + ) + + def from_storage_format(self, model: ChatHistoryEntity) -> StorageConversation: + message_ids = model.message_ids.split(",") if model.message_ids else [] + old_conversations: List[Dict] = ( + json.loads(model.messages) if model.messages else [] + ) + old_messages = [] + save_message_independent = True + if old_conversations: + # Load old messages from old conversations, in old design, we save messages to chat_history table + old_messages_dict = [] + for old_conversation in old_conversations: + old_messages_dict.extend( + old_conversation["messages"] + if "messages" in old_conversation + else [] + ) + save_message_independent = False + old_messages: List[BaseMessage] = _messages_from_dict(old_messages_dict) + return StorageConversation( + conv_uid=model.conv_uid, + chat_mode=model.chat_mode, + summary=model.summary, + user_name=model.user_name, + message_ids=message_ids, + sys_code=model.sys_code, + save_message_independent=save_message_independent, + messages=old_messages, + ) + + def get_query_for_identifier( + self, + storage_format: Type[ChatHistoryEntity], + resource_id: ConversationIdentifier, + **kwargs, + ): + session: Session = kwargs.get("session") + if session is None: + raise Exception("session is None") + return session.query(ChatHistoryEntity).filter( + ChatHistoryEntity.conv_uid == resource_id.conv_uid + ) + + +class DBMessageStorageItemAdapter( + StorageItemAdapter[MessageStorageItem, ChatHistoryMessageEntity] +): + def to_storage_format(self, item: MessageStorageItem) -> ChatHistoryMessageEntity: + round_index = item.message_detail.get("round_index", 0) + message_detail = json.dumps(item.message_detail, ensure_ascii=False) + return ChatHistoryMessageEntity( + conv_uid=item.conv_uid, + index=item.index, + round_index=round_index, + message_detail=message_detail, + ) + + def from_storage_format( + self, model: ChatHistoryMessageEntity + ) -> MessageStorageItem: + message_detail = ( + json.loads(model.message_detail) if model.message_detail else {} + ) + return MessageStorageItem( + conv_uid=model.conv_uid, + index=model.index, + message_detail=message_detail, + ) + + def get_query_for_identifier( + self, + storage_format: Type[ChatHistoryMessageEntity], + resource_id: MessageIdentifier, + **kwargs, + ): + session: Session = kwargs.get("session") + if session is None: + raise Exception("session is None") + return session.query(ChatHistoryMessageEntity).filter( + ChatHistoryMessageEntity.conv_uid == resource_id.conv_uid, + ChatHistoryMessageEntity.index == resource_id.index, + ) diff --git a/dbgpt/storage/chat_history/store_type/duckdb_history.py b/dbgpt/storage/chat_history/store_type/duckdb_history.py index 7127bb885..2e99fc185 100644 --- a/dbgpt/storage/chat_history/store_type/duckdb_history.py +++ b/dbgpt/storage/chat_history/store_type/duckdb_history.py @@ -87,7 +87,7 @@ def append(self, once_message: OnceConversation) -> None: [ self.chat_seesion_id, once_message.chat_mode, - once_message.get_user_conv().content, + once_message.get_latest_user_message().content, once_message.user_name, once_message.sys_code, json.dumps(conversations, ensure_ascii=False), diff --git a/dbgpt/storage/chat_history/store_type/meta_db_history.py b/dbgpt/storage/chat_history/store_type/meta_db_history.py index a92dc4be5..9bc2d0ceb 100644 --- a/dbgpt/storage/chat_history/store_type/meta_db_history.py +++ b/dbgpt/storage/chat_history/store_type/meta_db_history.py @@ -52,14 +52,14 @@ def append(self, once_message: OnceConversation) -> None: if context: conversations = json.loads(context) else: - chat_history.summary = once_message.get_user_conv().content + chat_history.summary = once_message.get_latest_user_message().content else: chat_history: ChatHistoryEntity = ChatHistoryEntity() chat_history.conv_uid = self.chat_seesion_id chat_history.chat_mode = once_message.chat_mode chat_history.user_name = once_message.user_name chat_history.sys_code = once_message.sys_code - chat_history.summary = once_message.get_user_conv().content + chat_history.summary = once_message.get_latest_user_message().content conversations.append(_conversation_to_dict(once_message)) chat_history.messages = json.dumps(conversations, ensure_ascii=False) diff --git a/dbgpt/storage/chat_history/tests/__init__.py b/dbgpt/storage/chat_history/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/storage/chat_history/tests/test_storage_adapter.py b/dbgpt/storage/chat_history/tests/test_storage_adapter.py new file mode 100644 index 000000000..1802a8fd0 --- /dev/null +++ b/dbgpt/storage/chat_history/tests/test_storage_adapter.py @@ -0,0 +1,219 @@ +import pytest +from typing import List + +from dbgpt.util.pagination_utils import PaginationResult +from dbgpt.util.serialization.json_serialization import JsonSerializer +from dbgpt.core.interface.message import StorageConversation, HumanMessage, AIMessage +from dbgpt.core.interface.storage import QuerySpec +from dbgpt.storage.metadata import db +from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage +from dbgpt.storage.chat_history.chat_history_db import ( + ChatHistoryEntity, + ChatHistoryMessageEntity, +) +from dbgpt.storage.chat_history.storage_adapter import ( + DBStorageConversationItemAdapter, + DBMessageStorageItemAdapter, +) + + +@pytest.fixture +def serializer(): + return JsonSerializer() + + +@pytest.fixture +def db_url(): + """Use in-memory SQLite database for testing""" + return "sqlite:///:memory:" + # return "sqlite:///test.db" + + +@pytest.fixture +def db_manager(db_url): + db.init_db(db_url) + db.create_all() + return db + + +@pytest.fixture +def storage_adapter(): + return DBStorageConversationItemAdapter() + + +@pytest.fixture +def storage_message_adapter(): + return DBMessageStorageItemAdapter() + + +@pytest.fixture +def conv_storage(db_manager, serializer, storage_adapter): + storage = SQLAlchemyStorage( + db_manager, + ChatHistoryEntity, + storage_adapter, + serializer, + ) + return storage + + +@pytest.fixture +def message_storage(db_manager, serializer, storage_message_adapter): + storage = SQLAlchemyStorage( + db_manager, + ChatHistoryMessageEntity, + storage_message_adapter, + serializer, + ) + return storage + + +@pytest.fixture +def conversation(conv_storage, message_storage): + return StorageConversation( + "conv1", + chat_mode="chat_normal", + user_name="user1", + conv_storage=conv_storage, + message_storage=message_storage, + ) + + +@pytest.fixture +def four_round_conversation(conv_storage, message_storage): + conversation = StorageConversation( + "conv1", + chat_mode="chat_normal", + user_name="user1", + conv_storage=conv_storage, + message_storage=message_storage, + ) + conversation.start_new_round() + conversation.add_user_message("hello, this is first round") + conversation.add_ai_message("hi") + conversation.end_current_round() + conversation.start_new_round() + conversation.add_user_message("hello, this is second round") + conversation.add_ai_message("hi") + conversation.end_current_round() + conversation.start_new_round() + conversation.add_user_message("hello, this is third round") + conversation.add_ai_message("hi") + conversation.end_current_round() + conversation.start_new_round() + conversation.add_user_message("hello, this is fourth round") + conversation.add_ai_message("hi") + conversation.end_current_round() + return conversation + + +@pytest.fixture +def conversation_list(request, conv_storage, message_storage): + params = request.param if hasattr(request, "param") else {} + conv_count = params.get("conv_count", 4) + result = [] + for i in range(conv_count): + conversation = StorageConversation( + f"conv{i}", + chat_mode="chat_normal", + user_name="user1", + conv_storage=conv_storage, + message_storage=message_storage, + ) + conversation.start_new_round() + conversation.add_user_message("hello, this is first round") + conversation.add_ai_message("hi") + conversation.end_current_round() + conversation.start_new_round() + conversation.add_user_message("hello, this is second round") + conversation.add_ai_message("hi") + conversation.end_current_round() + conversation.start_new_round() + conversation.add_user_message("hello, this is third round") + conversation.add_ai_message("hi") + conversation.end_current_round() + conversation.start_new_round() + conversation.add_user_message("hello, this is fourth round") + conversation.add_ai_message("hi") + conversation.end_current_round() + result.append(conversation) + return result + + +def test_save_and_load( + conversation: StorageConversation, conv_storage, message_storage +): + conversation.start_new_round() + conversation.add_user_message("hello") + conversation.add_ai_message("hi") + conversation.end_current_round() + + saved_conversation = StorageConversation( + conv_uid=conversation.conv_uid, + conv_storage=conv_storage, + message_storage=message_storage, + ) + assert saved_conversation.conv_uid == conversation.conv_uid + assert len(saved_conversation.messages) == 2 + assert isinstance(saved_conversation.messages[0], HumanMessage) + assert isinstance(saved_conversation.messages[1], AIMessage) + assert saved_conversation.messages[0].content == "hello" + assert saved_conversation.messages[0].round_index == 1 + assert saved_conversation.messages[1].content == "hi" + assert saved_conversation.messages[1].round_index == 1 + + +def test_query_message( + conversation: StorageConversation, conv_storage, message_storage +): + conversation.start_new_round() + conversation.add_user_message("hello") + conversation.add_ai_message("hi") + conversation.end_current_round() + + saved_conversation = StorageConversation( + conv_uid=conversation.conv_uid, + conv_storage=conv_storage, + message_storage=message_storage, + ) + assert saved_conversation.conv_uid == conversation.conv_uid + assert len(saved_conversation.messages) == 2 + + query_spec = QuerySpec(conditions={"conv_uid": conversation.conv_uid}) + results = conversation.conv_storage.query(query_spec, StorageConversation) + assert len(results) == 1 + + +def test_complex_query( + conversation_list: List[StorageConversation], conv_storage, message_storage +): + query_spec = QuerySpec(conditions={"user_name": "user1"}) + results = conv_storage.query(query_spec, StorageConversation) + assert len(results) == len(conversation_list) + for i, result in enumerate(results): + assert result.user_name == "user1" + assert result.conv_uid == f"conv{i}" + saved_conversation = StorageConversation( + conv_uid=result.conv_uid, + conv_storage=conv_storage, + message_storage=message_storage, + ) + assert len(saved_conversation.messages) == 8 + assert isinstance(saved_conversation.messages[0], HumanMessage) + assert isinstance(saved_conversation.messages[1], AIMessage) + assert saved_conversation.messages[0].content == "hello, this is first round" + assert saved_conversation.messages[1].content == "hi" + + +def test_query_with_page( + conversation_list: List[StorageConversation], conv_storage, message_storage +): + query_spec = QuerySpec(conditions={"user_name": "user1"}) + page_result: PaginationResult = conv_storage.paginate_query( + page=1, page_size=2, cls=StorageConversation, spec=query_spec + ) + assert page_result.total_count == len(conversation_list) + assert page_result.total_pages == 2 + assert page_result.page_size == 2 + assert len(page_result.items) == 2 + assert page_result.items[0].conv_uid == "conv0" diff --git a/dbgpt/storage/metadata/__init__.py b/dbgpt/storage/metadata/__init__.py index fc4b6447a..63e58aefe 100644 --- a/dbgpt/storage/metadata/__init__.py +++ b/dbgpt/storage/metadata/__init__.py @@ -1 +1,17 @@ +from dbgpt.storage.metadata.db_manager import ( + db, + Model, + DatabaseManager, + create_model, + BaseModel, +) from dbgpt.storage.metadata._base_dao import BaseDao + +__ALL__ = [ + "db", + "Model", + "DatabaseManager", + "create_model", + "BaseModel", + "BaseDao", +] diff --git a/dbgpt/storage/metadata/_base_dao.py b/dbgpt/storage/metadata/_base_dao.py index 6aa7c7c5b..93ff289b4 100644 --- a/dbgpt/storage/metadata/_base_dao.py +++ b/dbgpt/storage/metadata/_base_dao.py @@ -1,25 +1,72 @@ -from typing import TypeVar, Generic, Any -from sqlalchemy.orm import sessionmaker +from contextlib import contextmanager +from typing import TypeVar, Generic, Any, Optional +from sqlalchemy.orm.session import Session T = TypeVar("T") +from .db_manager import db, DatabaseManager + class BaseDao(Generic[T]): + """The base class for all DAOs. + + Examples: + .. code-block:: python + class UserDao(BaseDao[User]): + def get_user_by_name(self, name: str) -> User: + with self.session() as session: + return session.query(User).filter(User.name == name).first() + + def get_user_by_id(self, id: int) -> User: + with self.session() as session: + return User.get(id) + + def create_user(self, name: str) -> User: + return User.create(**{"name": name}) + Args: + db_manager (DatabaseManager, optional): The database manager. Defaults to None. + If None, the default database manager(db) will be used. + """ + def __init__( self, - orm_base=None, - database: str = None, - db_engine: Any = None, - session: Any = None, + db_manager: Optional[DatabaseManager] = None, ) -> None: - """BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist""" - self._orm_base = orm_base - self._database = database + self._db_manager = db_manager or db + + def get_raw_session(self) -> Session: + """Get a raw session object. + + Your should commit or rollback the session manually. + We suggest you use :meth:`session` instead. + + + Example: + .. code-block:: python + user = User(name="Edward Snowden") + session = self.get_raw_session() + session.add(user) + session.commit() + session.close() + """ + return self._db_manager._session() + + @contextmanager + def session(self) -> Session: + """Provide a transactional scope around a series of operations. + + If raise an exception, the session will be roll back automatically, otherwise it will be committed. + + Example: + .. code-block:: python + with self.session() as session: + session.query(User).filter(User.name == 'Edward Snowden').first() - self._db_engine = db_engine - self._session = session + Returns: + Session: A session object. - def get_session(self): - Session = sessionmaker(autocommit=False, autoflush=False, bind=self._db_engine) - session = Session() - return session + Raises: + Exception: Any exception will be raised. + """ + with self._db_manager.session() as session: + yield session diff --git a/dbgpt/storage/metadata/db_manager.py b/dbgpt/storage/metadata/db_manager.py new file mode 100644 index 000000000..782fcf8ab --- /dev/null +++ b/dbgpt/storage/metadata/db_manager.py @@ -0,0 +1,432 @@ +from __future__ import annotations + +import abc +from contextlib import contextmanager +from typing import TypeVar, Generic, Union, Dict, Optional, Type, Iterator, List +import logging +from sqlalchemy import create_engine, URL, Engine +from sqlalchemy import orm, inspect, MetaData +from sqlalchemy.orm import ( + scoped_session, + sessionmaker, + Session, + declarative_base, + DeclarativeMeta, +) +from sqlalchemy.orm.session import _PKIdentityArgument +from sqlalchemy.orm.exc import UnmappedClassError + +from sqlalchemy.pool import QueuePool +from dbgpt.util.string_utils import _to_str +from dbgpt.util.pagination_utils import PaginationResult + +logger = logging.getLogger(__name__) +T = TypeVar("T", bound="BaseModel") + + +class _QueryObject: + """The query object.""" + + def __init__(self, db_manager: "DatabaseManager"): + self._db_manager = db_manager + + def __get__(self, obj, type): + try: + mapper = orm.class_mapper(type) + if mapper: + return type.query_class(mapper, session=self._db_manager._session()) + except UnmappedClassError: + return None + + +class BaseQuery(orm.Query): + def paginate_query( + self, page: Optional[int] = 1, per_page: Optional[int] = 20 + ) -> PaginationResult: + """Paginate the query. + + Example: + .. code-block:: python + from dbgpt.storage.metadata import db, Model + class User(Model): + __tablename__ = "user" + id = Column(Integer, primary_key=True) + name = Column(String(50)) + fullname = Column(String(50)) + + with db.session() as session: + pagination = session.query(User).paginate_query(page=1, page_size=10) + print(pagination) + + # Or you can use the query object + with db.session() as session: + pagination = User.query.paginate_query(page=1, page_size=10) + print(pagination) + + Args: + page (Optional[int], optional): The page number. Defaults to 1. + per_page (Optional[int], optional): The number of items per page. Defaults to 20. + Returns: + PaginationResult: The pagination result. + """ + if page < 1: + raise ValueError("Page must be greater than 0") + if per_page < 0: + raise ValueError("Per page must be greater than 0") + items = self.limit(per_page).offset((page - 1) * per_page).all() + total = self.order_by(None).count() + total_pages = (total - 1) // per_page + 1 + return PaginationResult( + items=items, + total_count=total, + total_pages=total_pages, + page=page, + page_size=per_page, + ) + + +class _Model: + """Base class for SQLAlchemy declarative base model. + + With this class, we can use the query object to query the database. + + Examples: + .. code-block:: python + from dbgpt.storage.metadata import db, Model + class User(Model): + __tablename__ = "user" + id = Column(Integer, primary_key=True) + name = Column(String(50)) + fullname = Column(String(50)) + + with db.session() as session: + # User is an instance of _Model, and we can use the query object to query the database. + User.query.filter(User.name == "test").all() + """ + + query_class = None + query: Optional[BaseQuery] = None + + def __repr__(self): + identity = inspect(self).identity + if identity is None: + pk = "(transient {0})".format(id(self)) + else: + pk = ", ".join(_to_str(value) for value in identity) + return "<{0} {1}>".format(type(self).__name__, pk) + + +class DatabaseManager: + """The database manager. + + Examples: + .. code-block:: python + from urllib.parse import quote_plus as urlquote, quote + from dbgpt.storage.metadata import DatabaseManager, create_model + db = DatabaseManager() + # Use sqlite with memory storage. + url = f"sqlite:///:memory:" + engine_args = {"pool_size": 10, "max_overflow": 20, "pool_timeout": 30, "pool_recycle": 3600, "pool_pre_ping": True} + db.init_db(url, engine_args=engine_args) + + Model = create_model(db) + + class User(Model): + __tablename__ = "user" + id = Column(Integer, primary_key=True) + name = Column(String(50)) + fullname = Column(String(50)) + + with db.session() as session: + session.add(User(name="test", fullname="test")) + # db will commit the session automatically default. + # session.commit() + print(User.query.filter(User.name == "test").all()) + + + # Use CURDMixin APIs to create, update, delete, query the database. + with db.session() as session: + User.create(**{"name": "test1", "fullname": "test1"}) + User.create(**{"name": "test2", "fullname": "test1"}) + users = User.all() + print(users) + user = users[0] + user.update(**{"name": "test1_1111"}) + user2 = users[1] + # Update user2 by save + user2.name = "test2_1111" + user2.save() + # Delete user2 + user2.delete() + """ + + Query = BaseQuery + + def __init__(self): + self._db_url = None + self._base: DeclarativeMeta = self._make_declarative_base(_Model) + self._engine: Optional[Engine] = None + self._session: Optional[scoped_session] = None + + @property + def Model(self) -> _Model: + """Get the declarative base.""" + return self._base + + @property + def metadata(self) -> MetaData: + """Get the metadata.""" + return self.Model.metadata + + @property + def engine(self): + """Get the engine.""" "" + return self._engine + + @contextmanager + def session(self) -> Session: + """Get the session with context manager. + + If raise any exception, the session will roll back automatically, otherwise, the session will commit automatically. + + Example: + >>> with db.session() as session: + >>> session.query(...) + + Returns: + Session: The session. + + Raises: + RuntimeError: The database manager is not initialized. + Exception: Any exception. + """ + if not self._session: + raise RuntimeError("The database manager is not initialized.") + session = self._session() + try: + yield session + session.commit() + except: + session.rollback() + raise + finally: + session.close() + + def _make_declarative_base( + self, model: Union[Type[DeclarativeMeta], Type[_Model]] + ) -> DeclarativeMeta: + """Make the declarative base. + + Args: + base (DeclarativeMeta): The base class. + + Returns: + DeclarativeMeta: The declarative base. + """ + if not isinstance(model, DeclarativeMeta): + model = declarative_base(cls=model, name="Model") + if not getattr(model, "query_class", None): + model.query_class = self.Query + model.query = _QueryObject(self) + return model + + def init_db( + self, + db_url: Union[str, URL], + engine_args: Optional[Dict] = None, + base: Optional[DeclarativeMeta] = None, + query_class=BaseQuery, + ): + """Initialize the database manager. + + Args: + db_url (Union[str, URL]): The database url. + engine_args (Optional[Dict], optional): The engine arguments. Defaults to None. + base (Optional[DeclarativeMeta]): The base class. Defaults to None. + query_class (BaseQuery, optional): The query class. Defaults to BaseQuery. + """ + self._db_url = db_url + if query_class is not None: + self.Query = query_class + if base is not None: + self._base = base + if not hasattr(base, "query"): + base.query = _QueryObject(self) + if not getattr(base, "query_class", None): + base.query_class = self.Query + self._engine = create_engine(db_url, **(engine_args or {})) + session_factory = sessionmaker(bind=self._engine) + self._session = scoped_session(session_factory) + self._base.metadata.bind = self._engine + + def init_default_db( + self, + sqlite_path: str, + engine_args: Optional[Dict] = None, + base: Optional[DeclarativeMeta] = None, + ): + """Initialize the database manager with default config. + + Examples: + >>> db.init_default_db(sqlite_path) + >>> with db.session() as session: + >>> session.query(...) + + Args: + sqlite_path (str): The sqlite path. + engine_args (Optional[Dict], optional): The engine arguments. + Defaults to None, if None, we will use connection pool. + base (Optional[DeclarativeMeta]): The base class. Defaults to None. + """ + if not engine_args: + engine_args = {} + # Pool class + engine_args["poolclass"] = QueuePool + # The number of connections to keep open inside the connection pool. + engine_args["pool_size"] = 10 + # The maximum overflow size of the pool when the number of connections be used in the pool is exceeded( + # pool_size). + engine_args["max_overflow"] = 20 + # The number of seconds to wait before giving up on getting a connection from the pool. + engine_args["pool_timeout"] = 30 + # Recycle the connection if it has been idle for this many seconds. + engine_args["pool_recycle"] = 3600 + # Enable the connection pool “pre-ping” feature that tests connections for liveness upon each checkout. + engine_args["pool_pre_ping"] = True + + self.init_db(f"sqlite:///{sqlite_path}", engine_args, base) + + def create_all(self): + self.Model.metadata.create_all(self._engine) + + +db = DatabaseManager() +"""The global database manager. + +Examples: + >>> from dbgpt.storage.metadata import db + >>> sqlite_path = "/tmp/dbgpt.db" + >>> db.init_default_db(sqlite_path) + >>> with db.session() as session: + >>> session.query(...) + + >>> from dbgpt.storage.metadata import db, Model + >>> from urllib.parse import quote_plus as urlquote, quote + >>> db_name = "dbgpt" + >>> db_host = "localhost" + >>> db_port = 3306 + >>> user = "root" + >>> password = "123456" + >>> url = f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}:{str(db_port)}/{db_name}" + >>> engine_args = {"pool_size": 10, "max_overflow": 20, "pool_timeout": 30, "pool_recycle": 3600, "pool_pre_ping": True} + >>> db.init_db(url, engine_args=engine_args) + >>> class User(Model): + >>> __tablename__ = "user" + >>> id = Column(Integer, primary_key=True) + >>> name = Column(String(50)) + >>> fullname = Column(String(50)) + >>> with db.session() as session: + >>> session.add(User(name="test", fullname="test")) + >>> session.commit() +""" + + +class BaseCRUDMixin(Generic[T]): + """The base CRUD mixin.""" + + __abstract__ = True + + @classmethod + def create(cls: Type[T], **kwargs) -> T: + instance = cls(**kwargs) + return instance.save() + + @classmethod + def all(cls: Type[T]) -> List[T]: + return cls.query.all() + + @classmethod + def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]: + """Get a record by its primary key identifier.""" + + def update(self: T, commit: Optional[bool] = True, **kwargs) -> T: + """Update specific fields of a record.""" + for attr, value in kwargs.items(): + setattr(self, attr, value) + return commit and self.save() or self + + @abc.abstractmethod + def save(self: T, commit: Optional[bool] = True) -> T: + """Save the record.""" + + @abc.abstractmethod + def delete(self: T, commit: Optional[bool] = True) -> None: + """Remove the record from the database.""" + + +class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]): + + """The base model class that includes CRUD convenience methods.""" + + __abstract__ = True + + +def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]: + class CRUDMixin(BaseCRUDMixin[T], Generic[T]): + """Mixin that adds convenience methods for CRUD (create, read, update, delete)""" + + @classmethod + def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]: + """Get a record by its primary key identifier.""" + return db_manager._session().get(cls, ident) + + def save(self: T, commit: Optional[bool] = True) -> T: + """Save the record.""" + session = db_manager._session() + session.add(self) + if commit: + session.commit() + return self + + def delete(self: T, commit: Optional[bool] = True) -> None: + """Remove the record from the database.""" + session = db_manager._session() + session.delete(self) + return commit and session.commit() + + class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]): + """Base model class that includes CRUD convenience methods.""" + + __abstract__ = True + + return _NewModel + + +Model = create_model(db) + + +def initialize_db( + db_url: Union[str, URL], + db_name: str, + engine_args: Optional[Dict] = None, + base: Optional[DeclarativeMeta] = None, + try_to_create_db: Optional[bool] = False, +) -> DatabaseManager: + """Initialize the database manager. + + Args: + db_url (Union[str, URL]): The database url. + db_name (str): The database name. + engine_args (Optional[Dict], optional): The engine arguments. Defaults to None. + base (Optional[DeclarativeMeta]): The base class. Defaults to None. + try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False. + Returns: + DatabaseManager: The database manager. + """ + db.init_db(db_url, engine_args, base) + if try_to_create_db: + try: + db.create_all() + except Exception as e: + logger.error(f"Failed to create database {db_name}: {e}") + return db diff --git a/dbgpt/storage/metadata/db_storage.py b/dbgpt/storage/metadata/db_storage.py new file mode 100644 index 000000000..d85a1578d --- /dev/null +++ b/dbgpt/storage/metadata/db_storage.py @@ -0,0 +1,128 @@ +from contextlib import contextmanager + +from typing import Type, List, Optional, Union, Dict +from dbgpt.core import Serializer +from dbgpt.core.interface.storage import ( + StorageInterface, + QuerySpec, + ResourceIdentifier, + StorageItemAdapter, + T, +) +from sqlalchemy import URL +from sqlalchemy.orm import Session, DeclarativeMeta + +from .db_manager import BaseModel, DatabaseManager, BaseQuery + + +def _copy_public_properties(src: BaseModel, dest: BaseModel): + """Simple copy public properties from src to dest""" + for column in src.__table__.columns: + if column.name != "id": + setattr(dest, column.name, getattr(src, column.name)) + + +class SQLAlchemyStorage(StorageInterface[T, BaseModel]): + def __init__( + self, + db_url_or_db: Union[str, URL, DatabaseManager], + model_class: Type[BaseModel], + adapter: StorageItemAdapter[T, BaseModel], + serializer: Optional[Serializer] = None, + engine_args: Optional[Dict] = None, + base: Optional[DeclarativeMeta] = None, + query_class=BaseQuery, + ): + super().__init__(serializer=serializer, adapter=adapter) + if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL): + db_manager = DatabaseManager() + db_manager.init_db(db_url_or_db, engine_args, base, query_class) + self.db_manager = db_manager + elif isinstance(db_url_or_db, DatabaseManager): + self.db_manager = db_url_or_db + else: + raise ValueError( + f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}" + ) + self._model_class = model_class + + @contextmanager + def session(self) -> Session: + with self.db_manager.session() as session: + yield session + + def save(self, data: T) -> None: + with self.session() as session: + model_instance = self.adapter.to_storage_format(data) + session.add(model_instance) + + def update(self, data: T) -> None: + with self.session() as session: + model_instance = self.adapter.to_storage_format(data) + session.merge(model_instance) + + def save_or_update(self, data: T) -> None: + with self.session() as session: + query = self.adapter.get_query_for_identifier( + self._model_class, data.identifier, session=session + ) + model_instance = query.with_session(session).first() + if model_instance: + new_instance = self.adapter.to_storage_format(data) + _copy_public_properties(new_instance, model_instance) + session.merge(model_instance) + return + self.save(data) + + def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]: + with self.session() as session: + query = self.adapter.get_query_for_identifier( + self._model_class, resource_id, session=session + ) + model_instance = query.with_session(session).first() + if model_instance: + return self.adapter.from_storage_format(model_instance) + return None + + def delete(self, resource_id: ResourceIdentifier) -> None: + with self.session() as session: + query = self.adapter.get_query_for_identifier( + self._model_class, resource_id, session=session + ) + model_instance = query.with_session(session).first() + if model_instance: + session.delete(model_instance) + + def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]: + """Query data from the storage. + + Args: + spec (QuerySpec): The query specification + cls (Type[T]): The type of the data + """ + with self.session() as session: + query = session.query(self._model_class) + for key, value in spec.conditions.items(): + query = query.filter(getattr(self._model_class, key) == value) + if spec.limit is not None: + query = query.limit(spec.limit) + if spec.offset is not None: + query = query.offset(spec.offset) + model_instances = query.all() + return [ + self.adapter.from_storage_format(instance) + for instance in model_instances + ] + + def count(self, spec: QuerySpec, cls: Type[T]) -> int: + """Count the number of data in the storage. + + Args: + spec (QuerySpec): The query specification + cls (Type[T]): The type of the data + """ + with self.session() as session: + query = session.query(self._model_class) + for key, value in spec.conditions.items(): + query = query.filter(getattr(self._model_class, key) == value) + return query.count() diff --git a/dbgpt/storage/metadata/meta_data.py b/dbgpt/storage/metadata/meta_data.py deleted file mode 100644 index a63f0ab36..000000000 --- a/dbgpt/storage/metadata/meta_data.py +++ /dev/null @@ -1,94 +0,0 @@ -import os -import sqlite3 -import logging - -from sqlalchemy import create_engine, DDL -from sqlalchemy.exc import OperationalError -from sqlalchemy.orm import sessionmaker -from sqlalchemy.ext.declarative import declarative_base - -from alembic import command -from alembic.config import Config as AlembicConfig -from urllib.parse import quote -from dbgpt._private.config import Config -from dbgpt.configs.model_config import PILOT_PATH -from urllib.parse import quote_plus as urlquote - - -logger = logging.getLogger(__name__) -# DB-GPT metadata database config, now support mysql and sqlite -CFG = Config() -default_db_path = os.path.join(PILOT_PATH, "meta_data") - -os.makedirs(default_db_path, exist_ok=True) - -# Meta Info -META_DATA_DATABASE = CFG.LOCAL_DB_NAME -db_name = META_DATA_DATABASE -db_path = default_db_path + f"/{db_name}.db" -connection = sqlite3.connect(db_path) - - -if CFG.LOCAL_DB_TYPE == "mysql": - engine_temp = create_engine( - f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}" - ) - # check and auto create mysqldatabase - try: - # try to connect - with engine_temp.connect() as conn: - # TODO We should consider that the production environment does not have permission to execute the DDL - conn.execute(DDL(f"CREATE DATABASE IF NOT EXISTS {db_name}")) - print(f"Already connect '{db_name}'") - - except OperationalError as e: - # if connect failed, create dbgpt database - logger.error(f"{db_name} not connect success!") - - engine = create_engine( - f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}" - ) -else: - engine = create_engine(f"sqlite:///{db_path}") - - -Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) -session = Session() - -Base = declarative_base() - -# Base.metadata.create_all() - -alembic_ini_path = default_db_path + "/alembic.ini" -alembic_cfg = AlembicConfig(alembic_ini_path) - -alembic_cfg.set_main_option("sqlalchemy.url", str(engine.url)) - -os.makedirs(default_db_path + "/alembic", exist_ok=True) -os.makedirs(default_db_path + "/alembic/versions", exist_ok=True) - -alembic_cfg.set_main_option("script_location", default_db_path + "/alembic") - -alembic_cfg.attributes["target_metadata"] = Base.metadata -alembic_cfg.attributes["session"] = session - - -def ddl_init_and_upgrade(disable_alembic_upgrade: bool): - """Initialize and upgrade database metadata - - Args: - disable_alembic_upgrade (bool): Whether to enable alembic to initialize and upgrade database metadata - """ - if disable_alembic_upgrade: - logger.info( - "disable_alembic_upgrade is true, not to initialize and upgrade database metadata with alembic" - ) - return - - with engine.connect() as connection: - alembic_cfg.attributes["connection"] = connection - heads = command.heads(alembic_cfg) - print("heads:" + str(heads)) - - command.revision(alembic_cfg, "dbgpt ddl upate", True) - command.upgrade(alembic_cfg, "head") diff --git a/dbgpt/storage/metadata/tests/__init__.py b/dbgpt/storage/metadata/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/storage/metadata/tests/test_db_manager.py b/dbgpt/storage/metadata/tests/test_db_manager.py new file mode 100644 index 000000000..645f6271a --- /dev/null +++ b/dbgpt/storage/metadata/tests/test_db_manager.py @@ -0,0 +1,129 @@ +from __future__ import annotations +import pytest +from typing import Type +from dbgpt.storage.metadata.db_manager import ( + DatabaseManager, + PaginationResult, + create_model, + BaseModel, +) +from sqlalchemy import Column, Integer, String + + +@pytest.fixture +def db(): + db = DatabaseManager() + db.init_db("sqlite:///:memory:") + return db + + +@pytest.fixture +def Model(db): + return create_model(db) + + +def test_database_initialization(db: DatabaseManager, Model: Type[BaseModel]): + assert db.engine is not None + assert db.session is not None + + with db.session() as session: + assert session is not None + + +def test_model_creation(db: DatabaseManager, Model: Type[BaseModel]): + assert db.metadata.tables == {} + + class User(Model): + __tablename__ = "user" + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + db.create_all() + assert list(db.metadata.tables.keys())[0] == "user" + + +def test_crud_operations(db: DatabaseManager, Model: Type[BaseModel]): + class User(Model): + __tablename__ = "user" + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + db.create_all() + + # Create + with db.session() as session: + user = User.create(name="John Doe") + session.add(user) + session.commit() + + # Read + with db.session() as session: + user = session.query(User).filter_by(name="John Doe").first() + assert user is not None + + # Update + with db.session() as session: + user = session.query(User).filter_by(name="John Doe").first() + user.update(name="Jane Doe") + + # Delete + with db.session() as session: + user = session.query(User).filter_by(name="Jane Doe").first() + user.delete() + + +def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]): + class User(Model): + __tablename__ = "user" + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + db.create_all() + + # Create + user = User.create(name="John Doe") + assert User.get(user.id) is not None + users = User.all() + assert len(users) == 1 + + # Update + user.update(name="Bob Doe") + assert User.get(user.id).name == "Bob Doe" + + user = User.get(user.id) + user.delete() + assert User.get(user.id) is None + + +def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]): + class User(Model): + __tablename__ = "user" + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + db.create_all() + + # 添加数据 + with db.session() as session: + for i in range(30): + user = User(name=f"User {i}") + session.add(user) + session.commit() + + users_page_1 = User.query.paginate_query(page=1, per_page=10) + assert len(users_page_1.items) == 10 + assert users_page_1.total_pages == 3 + + +def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]): + class User(Model): + __tablename__ = "user" + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + db.create_all() + + with pytest.raises(ValueError): + User.query.paginate_query(page=0, per_page=10) + with pytest.raises(ValueError): + User.query.paginate_query(page=1, per_page=-1) diff --git a/dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py b/dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py new file mode 100644 index 000000000..fcae83215 --- /dev/null +++ b/dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py @@ -0,0 +1,173 @@ +from typing import Dict, Type +from sqlalchemy.orm import declarative_base, Session +from sqlalchemy import Column, Integer, String + +import pytest + +from dbgpt.core.interface.storage import ( + StorageItem, + ResourceIdentifier, + StorageItemAdapter, + QuerySpec, +) +from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage + +from dbgpt.core.interface.tests.test_storage import MockResourceIdentifier +from dbgpt.util.serialization.json_serialization import JsonSerializer + + +Base = declarative_base() + + +class MockModel(Base): + """The SQLAlchemy model for the mock data.""" + + __tablename__ = "mock_data" + id = Column(Integer, primary_key=True) + data = Column(String) + + +class MockStorageItem(StorageItem): + """The mock storage item.""" + + def merge(self, other: "StorageItem") -> None: + if not isinstance(other, MockStorageItem): + raise ValueError("other must be a MockStorageItem") + self.data = other.data + + def __init__(self, identifier: ResourceIdentifier, data: str): + self._identifier = identifier + self.data = data + + @property + def identifier(self) -> ResourceIdentifier: + return self._identifier + + def to_dict(self) -> Dict: + return {"identifier": self._identifier, "data": self.data} + + def serialize(self) -> bytes: + return str(self.data).encode() + + +class MockStorageItemAdapter(StorageItemAdapter[MockStorageItem, MockModel]): + """The adapter for the mock storage item.""" + + def to_storage_format(self, item: MockStorageItem) -> MockModel: + return MockModel(id=int(item.identifier.str_identifier), data=item.data) + + def from_storage_format(self, model: MockModel) -> MockStorageItem: + return MockStorageItem(MockResourceIdentifier(str(model.id)), model.data) + + def get_query_for_identifier( + self, + storage_format: Type[MockModel], + resource_id: ResourceIdentifier, + **kwargs, + ): + session: Session = kwargs.get("session") + if session is None: + raise ValueError("session is required for this adapter") + return session.query(storage_format).filter( + storage_format.id == int(resource_id.str_identifier) + ) + + +@pytest.fixture +def serializer(): + return JsonSerializer() + + +@pytest.fixture +def db_url(): + """Use in-memory SQLite database for testing""" + return "sqlite:///:memory:" + + +@pytest.fixture +def sqlalchemy_storage(db_url, serializer): + adapter = MockStorageItemAdapter() + storage = SQLAlchemyStorage(db_url, MockModel, adapter, serializer, base=Base) + Base.metadata.create_all(storage.db_manager.engine) + return storage + + +def test_save_and_load(sqlalchemy_storage): + item = MockStorageItem(MockResourceIdentifier("1"), "test_data") + + sqlalchemy_storage.save(item) + + loaded_item = sqlalchemy_storage.load(MockResourceIdentifier("1"), MockStorageItem) + assert loaded_item.data == "test_data" + + +def test_delete(sqlalchemy_storage): + resource_id = MockResourceIdentifier("1") + + sqlalchemy_storage.delete(resource_id) + # Make sure the item is deleted + assert sqlalchemy_storage.load(resource_id, MockStorageItem) is None + + +def test_query_with_various_conditions(sqlalchemy_storage): + # Add multiple items for testing + for i in range(5): + item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}") + sqlalchemy_storage.save(item) + + # Test query with single condition + query_spec = QuerySpec(conditions={"data": "test_data_2"}) + results = sqlalchemy_storage.query(query_spec, MockStorageItem) + assert len(results) == 1 + assert results[0].data == "test_data_2" + + # Test not existing condition + query_spec = QuerySpec(conditions={"data": "nonexistent"}) + results = sqlalchemy_storage.query(query_spec, MockStorageItem) + assert len(results) == 0 + + # Test query with multiple conditions + query_spec = QuerySpec(conditions={"data": "test_data_2", "id": "2"}) + results = sqlalchemy_storage.query(query_spec, MockStorageItem) + assert len(results) == 1 + + +def test_query_nonexistent_item(sqlalchemy_storage): + query_spec = QuerySpec(conditions={"data": "nonexistent"}) + results = sqlalchemy_storage.query(query_spec, MockStorageItem) + assert len(results) == 0 + + +def test_count_items(sqlalchemy_storage): + for i in range(5): + item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}") + sqlalchemy_storage.save(item) + + # Test count without conditions + query_spec = QuerySpec(conditions={}) + total_count = sqlalchemy_storage.count(query_spec, MockStorageItem) + assert total_count == 5 + + # Test count with conditions + query_spec = QuerySpec(conditions={"data": "test_data_2"}) + total_count = sqlalchemy_storage.count(query_spec, MockStorageItem) + assert total_count == 1 + + +def test_paginate_query(sqlalchemy_storage): + for i in range(10): + item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}") + sqlalchemy_storage.save(item) + + page_size = 3 + page_number = 2 + + query_spec = QuerySpec(conditions={}) + page_result = sqlalchemy_storage.paginate_query( + page_number, page_size, MockStorageItem, query_spec + ) + + assert len(page_result.items) == page_size + assert page_result.page == page_number + assert page_result.total_pages == 4 + assert page_result.total_count == 10 diff --git a/dbgpt/util/_db_migration_utils.py b/dbgpt/util/_db_migration_utils.py new file mode 100644 index 000000000..2d0212467 --- /dev/null +++ b/dbgpt/util/_db_migration_utils.py @@ -0,0 +1,219 @@ +from typing import Optional +import os +import logging +from sqlalchemy import Engine, text +from sqlalchemy.orm import Session, DeclarativeMeta +from alembic import command +from alembic.util.exc import CommandError +from alembic.config import Config as AlembicConfig + + +logger = logging.getLogger(__name__) + + +def create_alembic_config( + alembic_root_path: str, + engine: Engine, + base: DeclarativeMeta, + session: Session, + alembic_ini_path: Optional[str] = None, + script_location: Optional[str] = None, +) -> AlembicConfig: + """Create alembic config. + + Args: + alembic_root_path: alembic root path + engine: sqlalchemy engine + base: sqlalchemy base + session: sqlalchemy session + alembic_ini_path (Optional[str]): alembic ini path + script_location (Optional[str]): alembic script location + + Returns: + alembic config + """ + alembic_ini_path = alembic_ini_path or os.path.join( + alembic_root_path, "alembic.ini" + ) + alembic_cfg = AlembicConfig(alembic_ini_path) + alembic_cfg.set_main_option("sqlalchemy.url", str(engine.url)) + script_location = script_location or os.path.join(alembic_root_path, "alembic") + versions_dir = os.path.join(script_location, "versions") + + os.makedirs(script_location, exist_ok=True) + os.makedirs(versions_dir, exist_ok=True) + + alembic_cfg.set_main_option("script_location", script_location) + + alembic_cfg.attributes["target_metadata"] = base.metadata + alembic_cfg.attributes["session"] = session + return alembic_cfg + + +def create_migration_script( + alembic_cfg: AlembicConfig, engine: Engine, message: str = "New migration" +) -> None: + """Create migration script. + + Args: + alembic_cfg: alembic config + engine: sqlalchemy engine + message: migration message + + """ + with engine.connect() as connection: + alembic_cfg.attributes["connection"] = connection + command.revision(alembic_cfg, message, autogenerate=True) + + +def upgrade_database( + alembic_cfg: AlembicConfig, engine: Engine, target_version: str = "head" +) -> None: + """Upgrade database to target version. + + Args: + alembic_cfg: alembic config + engine: sqlalchemy engine + target_version: target version, default is head(latest version) + """ + with engine.connect() as connection: + alembic_cfg.attributes["connection"] = connection + # Will create tables if not exists + command.upgrade(alembic_cfg, target_version) + + +def downgrade_database( + alembic_cfg: AlembicConfig, engine: Engine, revision: str = "-1" +): + """Downgrade the database by one revision. + + Args: + alembic_cfg: Alembic configuration object. + engine: SQLAlchemy engine instance. + revision: Revision identifier, default is "-1" which means one revision back. + """ + with engine.connect() as connection: + alembic_cfg.attributes["connection"] = connection + command.downgrade(alembic_cfg, revision) + + +def clean_alembic_migration(alembic_cfg: AlembicConfig, engine: Engine) -> None: + """Clean Alembic migration scripts and history. + + Args: + alembic_cfg: Alembic config object + engine: SQLAlchemy engine instance + + """ + import shutil + + # Get migration script location + script_location = alembic_cfg.get_main_option("script_location") + print(f"Delete migration script location: {script_location}") + + # Delete all migration script files + for file in os.listdir(script_location): + if file.startswith("versions"): + filepath = os.path.join(script_location, file) + print(f"Delete migration script file: {filepath}") + if os.path.isfile(filepath): + os.remove(filepath) + else: + shutil.rmtree(filepath, ignore_errors=True) + + # Delete Alembic version table if exists + version_table = alembic_cfg.get_main_option("version_table") or "alembic_version" + if version_table: + with engine.connect() as connection: + print(f"Delete Alembic version table: {version_table}") + connection.execute(text(f"DROP TABLE IF EXISTS {version_table}")) + + print("Cleaned Alembic migration scripts and history") + + +_MIGRATION_SOLUTION = """ +**Solution 1:** + +Run the following command to upgrade the database. +```commandline +dbgpt db migration upgrade +``` + +**Solution 2:** + +Run the following command to clean the migration script and migration history. +```commandline +dbgpt db migration clean -y +``` + +**Solution 3:** + +If you have already run the above command, but the error still exists, +you can try the following command to clean the migration script, migration history and your data. +warning: This command will delete all your data!!! Please use it with caution. + +```commandline +dbgpt db migration clean --drop_all_tables -y --confirm_drop_all_tables +``` +or +```commandline +rm -rf pilot/meta_data/alembic/versions/* +rm -rf pilot/meta_data/alembic/dbgpt.db +``` +""" + + +def _ddl_init_and_upgrade( + default_meta_data_path: str, + disable_alembic_upgrade: bool, + alembic_ini_path: Optional[str] = None, + script_location: Optional[str] = None, +): + """Initialize and upgrade database metadata + + Args: + default_meta_data_path (str): default meta data path + disable_alembic_upgrade (bool): Whether to enable alembic to initialize and upgrade database metadata + alembic_ini_path (Optional[str]): alembic ini path + script_location (Optional[str]): alembic script location + """ + if disable_alembic_upgrade: + logger.info( + "disable_alembic_upgrade is true, not to initialize and upgrade database metadata with alembic" + ) + return + else: + warn_msg = ( + "Initialize and upgrade database metadata with alembic, " + "just run this in your development environment, if you deploy this in production environment, " + "please run webserver with --disable_alembic_upgrade(`python dbgpt/app/dbgpt_server.py " + "--disable_alembic_upgrade`).\n" + "we suggest you to use `dbgpt db migration` to initialize and upgrade database metadata with alembic, " + "your can run `dbgpt db migration --help` to get more information." + ) + logger.warning(warn_msg) + from dbgpt.storage.metadata.db_manager import db + + alembic_cfg = create_alembic_config( + default_meta_data_path, + db.engine, + db.Model, + db.session(), + alembic_ini_path, + script_location, + ) + try: + create_migration_script(alembic_cfg, db.engine) + upgrade_database(alembic_cfg, db.engine) + except CommandError as e: + if "Target database is not up to date" in str(e): + logger.error( + f"Initialize and upgrade database metadata with alembic failed, error detail: {str(e)} " + f"you can try the following solutions:\n{_MIGRATION_SOLUTION}\n" + ) + raise Exception( + "Initialize and upgrade database metadata with alembic failed, " + "you can see the error and solutions above" + ) from e + else: + raise e diff --git a/dbgpt/util/annotations.py b/dbgpt/util/annotations.py index 4b64ee943..f97075339 100644 --- a/dbgpt/util/annotations.py +++ b/dbgpt/util/annotations.py @@ -39,6 +39,31 @@ def decorator(obj): return decorator +def DeveloperAPI(*args, **kwargs): + """Decorator to mark a function or class as a developer API. + + Developer APIs are low-level APIs for advanced users and may change cross major versions. + + Examples: + >>> from dbgpt.util.annotations import DeveloperAPI + >>> @DeveloperAPI + ... def foo(): + ... pass + + """ + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + return DeveloperAPI()(args[0]) + + def decorator(obj): + _modify_docstring( + obj, + "**DeveloperAPI:** This API is for advanced users and may change cross major versions.", + ) + return obj + + return decorator + + def _modify_docstring(obj, message: str = None): if not message: return diff --git a/dbgpt/util/pagination_utils.py b/dbgpt/util/pagination_utils.py new file mode 100644 index 000000000..cbe21dda0 --- /dev/null +++ b/dbgpt/util/pagination_utils.py @@ -0,0 +1,14 @@ +from typing import TypeVar, Generic, List +from dbgpt._private.pydantic import BaseModel, Field + +T = TypeVar("T") + + +class PaginationResult(BaseModel, Generic[T]): + """Pagination result""" + + items: List[T] = Field(..., description="The items in the current page") + total_count: int = Field(..., description="Total number of items") + total_pages: int = Field(..., description="total number of pages") + page: int = Field(..., description="Current page number") + page_size: int = Field(..., description="Number of items per page") diff --git a/dbgpt/util/serialization/json_serialization.py b/dbgpt/util/serialization/json_serialization.py index 4f0f90fd6..58811cae2 100644 --- a/dbgpt/util/serialization/json_serialization.py +++ b/dbgpt/util/serialization/json_serialization.py @@ -41,4 +41,6 @@ def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable: # Convert bytes back to JSON and then to the specified class json_data = json.loads(data.decode(JSON_ENCODING)) # Assume that the cls has an __init__ that accepts a dictionary - return cls(**json_data) + obj = cls(**json_data) + obj.set_serializer(self) + return obj diff --git a/dbgpt/util/string_utils.py b/dbgpt/util/string_utils.py index 170f0519a..d62b10bea 100644 --- a/dbgpt/util/string_utils.py +++ b/dbgpt/util/string_utils.py @@ -73,9 +73,11 @@ def extract_content_open_ending(long_string, s1, s2, is_include: bool = False): return match_map -if __name__ == "__main__": - s = "abcd123efghijkjhhh456xxx123aa456yyy123bb456xx123" - s1 = "123" - s2 = "456" +def _to_str(x, charset="utf8", errors="strict"): + if x is None or isinstance(x, str): + return x - print(extract_content_open_ending(s, s1, s2, True)) + if isinstance(x, bytes): + return x.decode(charset, errors) + + return str(x) diff --git a/docs/docs/faq/install.md b/docs/docs/faq/install.md index 8ca0f6ec1..a4b05e23b 100644 --- a/docs/docs/faq/install.md +++ b/docs/docs/faq/install.md @@ -71,3 +71,66 @@ Download and install `Microsoft C++ Build Tools` from [visual-cpp-build-tools](h 1. update your mysql username and password in docker/examples/metadata/duckdb2mysql.py 2. python docker/examples/metadata/duckdb2mysql.py ``` + +##### Q8: `How to manage and migrate my database` + +You can use the command of `dbgpt db migration` to manage and migrate your database. + +See the following command for details. +```commandline +dbgpt db migration --help +``` + +First, you need to create a migration script(just once unless you clean it). +This command with create a `alembic` directory in your `pilot/meta_data` directory and a initial migration script in it. +```commandline +dbgpt db migration init +``` + +Then you can upgrade your database with the following command. +```commandline +dbgpt db migration upgrade +``` + +Every time you change the model or pull the latest code from DB-GPT repository, you need to create a new migration script. +```commandline + +dbgpt db migration migrate -m "your message" +``` + +Then you can upgrade your database with the following command. +```commandline +dbgpt db migration upgrade +``` + + +##### Q9: `alembic.util.exc.CommandError: Target database is not up to date.` + +**Solution 1:** + +Run the following command to upgrade the database. +```commandline +dbgpt db migration upgrade +``` + +**Solution 2:** + +Run the following command to clean the migration script and migration history. +```commandline +dbgpt db migration clean -y +``` + +**Solution 3:** + +If you have already run the above command, but the error still exists, +you can try the following command to clean the migration script, migration history and your data. +warning: This command will delete all your data!!! Please use it with caution. + +```commandline +dbgpt db migration clean --drop_all_tables -y --confirm_drop_all_tables +``` +or +```commandline +rm -rf pilot/meta_data/alembic/versions/* +rm -rf pilot/meta_data/alembic/dbgpt.db +``` \ No newline at end of file diff --git a/docs/docs/installation/sourcecode.md b/docs/docs/installation/sourcecode.md index f8eb42b65..6a08eb3ea 100644 --- a/docs/docs/installation/sourcecode.md +++ b/docs/docs/installation/sourcecode.md @@ -97,6 +97,8 @@ Configure the proxy and modify LLM_MODEL, PROXY_API_URL and API_KEY in the `.env LLM_MODEL=chatgpt_proxyllm PROXY_API_KEY={your-openai-sk} PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions +# If you use gpt-4 +# PROXYLLM_BACKEND=gpt-4 ``` diff --git a/pilot/meta_data/alembic/env.py b/pilot/meta_data/alembic/env.py index ef2e26a75..5ed4384a9 100644 --- a/pilot/meta_data/alembic/env.py +++ b/pilot/meta_data/alembic/env.py @@ -3,7 +3,7 @@ from alembic import context -from dbgpt.storage.metadata.meta_data import Base, engine +from dbgpt.storage.metadata.db_manager import db # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -13,8 +13,7 @@ # add your model's MetaData object here # for 'autogenerate' support # from myapp import mymodel -# target_metadata = mymodel.Base.metadata -target_metadata = Base.metadata + # other values from the config, defined by the needs of env.py, # can be acquired: @@ -34,6 +33,8 @@ def run_migrations_offline() -> None: script output. """ + engine = db.engine + target_metadata = db.metadata url = config.get_main_option(engine.url) context.configure( url=url, @@ -53,12 +54,8 @@ def run_migrations_online() -> None: and associate a connection with the context. """ - connectable = engine_from_config( - config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) - + engine = db.engine + target_metadata = db.metadata with engine.connect() as connection: if engine.dialect.name == "sqlite": context.configure(