From 02f23108b731c8502b11f0b6abcd11c0ed42c567 Mon Sep 17 00:00:00 2001
From: zhouchangda
Date: Tue, 17 Dec 2024 04:34:27 +0000
Subject: [PATCH] Support OpenAI API chat bot

---
 .../configs/pipelines/PP-ChatOCRv3-doc.yaml       |   7 +
 paddlex/inference/pipelines_new/__init__.py       |   8 +-
 .../components/chat_server/__init__.py            |   1 +
 .../components/chat_server/ernie_bot_chat.py      |   9 +-
 .../components/chat_server/openai_bot_chat.py     | 127 ++++++++++++
 .../components/retriever/__init__.py              |   1 +
 .../retriever/ernie_bot_retriever.py              |  11 +-
 .../retriever/openai_bot_retriever.py             | 181 ++++++++++++++++++
 .../pp_chatocrv3_doc/pipeline.py                  |  36 +++-
 9 files changed, 370 insertions(+), 11 deletions(-)
 create mode 100644 paddlex/inference/pipelines_new/components/chat_server/openai_bot_chat.py
 create mode 100644 paddlex/inference/pipelines_new/components/retriever/openai_bot_retriever.py

diff --git a/paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml b/paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml
index 8c51728ca..384a8cbeb 100644
--- a/paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml
+++ b/paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml
@@ -10,6 +10,13 @@ SubModules:
     ak: "api_key"  # Set this to a real API key
     sk: "secret_key"  # Set this to a real secret key
 
+  MLLM_Chat:
+    module_name: chat_bot
+    model_name: PP-DocBee
+    base_url: "http://127.0.0.1/v1/chat/completions"  # Set this to the URL of a deployed MLLM service
+    api_type: openai
+    api_key: "api_key"  # Set this to a real API key
+
   LLM_Retriever:
     module_name: retriever
     model_name: ernie-3.5
diff --git a/paddlex/inference/pipelines_new/__init__.py b/paddlex/inference/pipelines_new/__init__.py
index 768d236d7..d85386a38 100644
--- a/paddlex/inference/pipelines_new/__init__.py
+++ b/paddlex/inference/pipelines_new/__init__.py
@@ -151,8 +151,8 @@ def create_chat_bot(config: Dict, *args, **kwargs) -> BaseChat:
     Returns:
-        BaseChat: An instance of the chat bot class corresponding to the 'model_name' in the config.
+        BaseChat: An instance of the chat bot class corresponding to the 'api_type' in the config.
     """
-    model_name = config["model_name"]
-    chat_bot = BaseChat.get(model_name)(config)
+    api_type = config["api_type"]
+    chat_bot = BaseChat.get(api_type)(config)
     return chat_bot
 
 
@@ -172,8 +172,8 @@ def create_retriever(
     Returns:
-        BaseRetriever: An instance of a retriever class corresponding to the 'model_name' in the config.
+        BaseRetriever: An instance of a retriever class corresponding to the 'api_type' in the config.
     """
-    model_name = config["model_name"]
-    retriever = BaseRetriever.get(model_name)(config)
+    api_type = config["api_type"]
+    retriever = BaseRetriever.get(api_type)(config)
     return retriever
 
 
diff --git a/paddlex/inference/pipelines_new/components/chat_server/__init__.py b/paddlex/inference/pipelines_new/components/chat_server/__init__.py
index a08bf1933..5149d5c15 100644
--- a/paddlex/inference/pipelines_new/components/chat_server/__init__.py
+++ b/paddlex/inference/pipelines_new/components/chat_server/__init__.py
@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from .ernie_bot_chat import ErnieBotChat
+from .openai_bot_chat import OpenAIBotChat
diff --git a/paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py b/paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py
index 57c42e54e..3e0d3b96a 100644
--- a/paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py
+++ b/paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py
@@ -22,6 +22,11 @@ class ErnieBotChat(BaseChat):
     """Ernie Bot Chat"""
 
     entities = [
+        "aistudio",
+        "qianfan",
+    ]
+
+    MODELS = [
         "ernie-4.0",
         "ernie-3.5",
         "ernie-3.5-8k",
@@ -51,8 +56,8 @@ def __init__(self, config: Dict) -> None:
         sk = config.get("sk", None)
         access_token = config.get("access_token", None)
 
-        if model_name not in self.entities:
-            raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.")
+        if model_name not in self.MODELS:
+            raise ValueError(f"model_name must be in {self.MODELS} of ErnieBotChat.")
 
         if api_type not in ["aistudio", "qianfan"]:
             raise ValueError("api_type must be one of ['aistudio', 'qianfan']")
diff --git a/paddlex/inference/pipelines_new/components/chat_server/openai_bot_chat.py b/paddlex/inference/pipelines_new/components/chat_server/openai_bot_chat.py
new file mode 100644
index 000000000..8b1dadc05
--- /dev/null
+++ b/paddlex/inference/pipelines_new/components/chat_server/openai_bot_chat.py
@@ -0,0 +1,127 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import json
+from typing import Dict, Optional
+
+from .....utils import logging
+from .base import BaseChat
+
+
+class OpenAIBotChat(BaseChat):
+    """OpenAI Bot Chat"""
+
+    entities = [
+        "openai",
+    ]
+
+    def __init__(self, config: Dict) -> None:
+        """Initializes the OpenAIBotChat with the given configuration.
+
+        Args:
+            config (Dict): Configuration dictionary containing model_name,
+                api_type, base_url and api_key.
+
+        Raises:
+            ValueError: If api_type is not 'openai', or if base_url or
+                api_key is None when api_type is 'openai'.
+        """
+        super().__init__()
+        model_name = config.get("model_name", None)
+        api_type = config.get("api_type", None)
+        api_key = config.get("api_key", None)
+        base_url = config.get("base_url", None)
+
+        if api_type not in ["openai"]:
+            raise ValueError("api_type must be one of ['openai']")
+
+        if api_type == "openai" and api_key is None:
+            raise ValueError("api_key cannot be empty when api_type is openai.")
+
+        if base_url is None:
+            raise ValueError("base_url cannot be empty when api_type is openai.")
+
+        try:
+            from openai import OpenAI
+        except ImportError:
+            raise Exception("openai is not installed, please install it first.")
+
+        self.client = OpenAI(base_url=base_url, api_key=api_key)
+
+        self.model_name = model_name
+        self.config = config
+
+    def generate_chat_results(
+        self,
+        prompt: str,
+        image: Optional[str] = None,
+        temperature: float = 0.001,
+        max_retries: int = 1,
+    ) -> Dict:
+        """
+        Generate chat results using the specified model and configuration.
+
+        Args:
+            prompt (str): The user's input prompt.
+            image (str, optional): A base64-encoded image for the MLLM, defaults to None.
+            temperature (float, optional): The sampling temperature for the LLM, defaults to 0.001.
+            max_retries (int, optional): The maximum number of retries for LLM API calls, defaults to 1.
+
+        Returns:
+            Dict: The chat completion result from the model.
+        """
+        try:
+            if image:
+                chat_completion = self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=[
+                        {
+                            "role": "system",
+                            # XXX: use a generic system prompt by default
+                            "content": "You are a helpful assistant.",
+                        },
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": prompt},
+                                {
+                                    "type": "image_url",
+                                    "image_url": {
+                                        "url": f"data:image/jpeg;base64,{image}"
+                                    },
+                                },
+                            ],
+                        },
+                    ],
+                    stream=False,
+                )
+                llm_result = chat_completion.choices[0].message.content
+            else:
+                # Text-only requests go through the legacy completions endpoint.
+                chat_completion = self.client.completions.create(
+                    model=self.model_name,
+                    prompt=prompt,
+                    max_tokens=self.config.get("max_tokens", 1024),
+                    temperature=float(temperature),
+                    stream=False,
+                )
+                # Some OpenAI-compatible servers return the completion as a JSON string.
+                if isinstance(chat_completion, str):
+                    chat_completion = json.loads(chat_completion)
+                    llm_result = chat_completion["choices"][0]["text"]
+                else:
+                    llm_result = chat_completion.choices[0].text
+            return llm_result
+        except Exception as e:
+            logging.error(e)
+            self.ERROR_MASSAGE = "Failed to call the LLM API."
+            return None
diff --git a/paddlex/inference/pipelines_new/components/retriever/__init__.py b/paddlex/inference/pipelines_new/components/retriever/__init__.py
index f829efe0c..8d3f52e25 100644
--- a/paddlex/inference/pipelines_new/components/retriever/__init__.py
+++ b/paddlex/inference/pipelines_new/components/retriever/__init__.py
@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from .ernie_bot_retriever import ErnieBotRetriever
+from .openai_bot_retriever import OpenAIBotRetriever
diff --git a/paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py b/paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py
index 0eb2719f4..e87b3773e 100644
--- a/paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py
+++ b/paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py
@@ -32,6 +32,11 @@ class ErnieBotRetriever(BaseRetriever):
     """Ernie Bot Retriever"""
 
     entities = [
+        "aistudio",
+        "qianfan",
+    ]
+
+    MODELS = [
         "ernie-4.0",
         "ernie-3.5",
         "ernie-3.5-8k",
@@ -68,8 +73,8 @@ def __init__(self, config: Dict) -> None:
         sk = config.get("sk", None)
         access_token = config.get("access_token", None)
 
-        if model_name not in self.entities:
-            raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.")
+        if model_name not in self.MODELS:
+            raise ValueError(f"model_name must be in {self.MODELS} of ErnieBotRetriever.")
 
         if api_type not in ["aistudio", "qianfan"]:
             raise ValueError("api_type must be one of ['aistudio', 'qianfan']")
diff --git a/paddlex/inference/pipelines_new/components/retriever/openai_bot_retriever.py b/paddlex/inference/pipelines_new/components/retriever/openai_bot_retriever.py
new file mode 100644
index 000000000..69ff797f9
--- /dev/null
+++ b/paddlex/inference/pipelines_new/components/retriever/openai_bot_retriever.py
@@ -0,0 +1,181 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from typing import Dict
+
+from langchain.docstore.document import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import FAISS
+from langchain_community import vectorstores
+
+from .base import BaseRetriever
+
+
+class OpenAIBotRetriever(BaseRetriever):
+    """OpenAI Bot Retriever"""
+
+    entities = [
+        "openai",
+    ]
+
+    def __init__(self, config: Dict) -> None:
+        """
+        Initializes the OpenAIBotRetriever instance with the provided configuration.
+
+        Args:
+            config (Dict): A dictionary containing configuration settings.
+                - model_name (str): The name of the embedding model to use.
+                - api_type (str): The type of API to use (must be 'openai').
+                - api_key (str, optional): The API key for the 'openai' API.
+                - base_url (str, optional): The base URL for the 'openai' API.
+
+        Raises:
+            ValueError: If api_type is not 'openai', or if base_url or
+                api_key is None when api_type is 'openai'.
+        """
+        super().__init__()
+
+        model_name = config.get("model_name", None)
+        api_type = config.get("api_type", None)
+        api_key = config.get("api_key", None)
+        base_url = config.get("base_url", None)
+        tiktoken_enabled = config.get("tiktoken_enabled", False)
+
+        if api_type not in ["openai"]:
+            raise ValueError("api_type must be one of ['openai']")
+
+        if api_type == "openai" and api_key is None:
+            raise ValueError("api_key cannot be empty when api_type is openai.")
+
+        if base_url is None:
+            raise ValueError("base_url cannot be empty when api_type is openai.")
+
+        try:
+            from langchain_openai import OpenAIEmbeddings
+        except ImportError:
+            raise Exception(
+                "langchain-openai is not installed, please install it first."
+            )
+
+        self.embedding = OpenAIEmbeddings(
+            model=model_name,
+            api_key=api_key,
+            base_url=base_url,
+            tiktoken_enabled=tiktoken_enabled,
+        )
+
+        self.model_name = model_name
+        self.config = config
+
+    def generate_vector_database(
+        self,
+        text_list: list[str],
+        block_size: int = 300,
+        separators: list[str] = ["\t", "\n", "。", "\n\n", ""],
+        sleep_time: float = 0.5,
+    ) -> FAISS:
+        """
+        Generates a vector database from a list of texts.
+
+        Args:
+            text_list (list[str]): A list of texts to generate the vector database from.
+            block_size (int): The size of each chunk to split the text into.
+            separators (list[str]): A list of separators to use when splitting the text.
+            sleep_time (float): Reserved for rate limiting between embedding requests; currently unused.
+
+        Returns:
+            FAISS: The generated vector database.
+        """
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=block_size, chunk_overlap=20, separators=separators
+        )
+        texts = text_splitter.split_text("\t".join(text_list))
+        all_splits = [Document(page_content=text) for text in texts]
+
+        vectorstore = FAISS.from_documents(
+            documents=all_splits, embedding=self.embedding
+        )
+
+        return vectorstore
+
+    def encode_vector_store_to_bytes(self, vectorstore: FAISS) -> str:
+        """
+        Serialize the vector store to bytes and encode it as a string.
+
+        Args:
+            vectorstore (FAISS): The vector store to be serialized and encoded.
+
+        Returns:
+            str: The encoded vector store.
+        """
+        vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes())
+        return vectorstore
+
+    def decode_vector_store_from_bytes(self, vectorstore: str) -> FAISS:
+        """
+        Decode a vector store from an encoded string.
+
+        Args:
+            vectorstore (str): The serialized vector store string.
+
+        Returns:
+            FAISS: The deserialized vector store object.
+
+        Raises:
+            ValueError: If the retrieved vector store is not for PaddleX.
+        """
+        if not self.is_vector_store(vectorstore):
+            raise ValueError("The retrieved vectorstore is not for PaddleX.")
+
+        vector = vectorstores.FAISS.deserialize_from_bytes(
+            self.decode_vector_store(vectorstore), self.embedding
+        )
+        return vector
+
+    def similarity_retrieval(
+        self, query_text_list: list[str], vectorstore: FAISS, sleep_time: float = 0.5
+    ) -> str:
+        """
+        Retrieve similar contexts based on a list of query texts.
+
+        Args:
+            query_text_list (list[str]): A list of query texts to search for similar contexts.
+            vectorstore (FAISS): The vector store in which to perform the similarity search.
+            sleep_time (float): The time to sleep between queries, in seconds. Default is 0.5.
+
+        Returns:
+            str: A concatenated string of all unique contexts found.
+ """ + C = [] + for query_text in query_text_list: + QUESTION = query_text + time.sleep(sleep_time) + docs = vectorstore.similarity_search_with_relevance_scores(QUESTION, k=2) + context = [(document.page_content, score) for document, score in docs] + context = sorted(context, key=lambda x: x[1]) + C.extend([x[0] for x in context[::-1]]) + C = list(set(C)) + all_C = " ".join(C) + return all_C diff --git a/paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py b/paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py index 12bca0625..f16410fa0 100644 --- a/paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py +++ b/paddlex/inference/pipelines_new/pp_chatocrv3_doc/pipeline.py @@ -17,7 +17,8 @@ from typing import Any, Dict, Optional # import numpy as np -# import cv2 +import cv2 +import base64 from .result import VisualInfoResult import re @@ -62,10 +63,10 @@ def __init__( super().__init__( device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params ) - self.use_layout_parsing = use_layout_parsing self.inintial_predictor(config) + logging.warning("use_new_pipeline") self.img_reader = ReadImage(format="BGR") @@ -91,6 +92,9 @@ def inintial_predictor(self, config: dict) -> None: chat_bot_config = config["SubModules"]["LLM_Chat"] self.chat_bot = create_chat_bot(chat_bot_config) + mllm_chat_bot_config = config["SubModules"]["MLLM_Chat"] + self.mllm_chat_bot = create_chat_bot(mllm_chat_bot_config) + from .. import create_retriever retriever_config = config["SubModules"]["LLM_Retriever"] @@ -217,6 +221,34 @@ def visual_predict( } yield visual_predict_res + def mllm_pred( + self, + input: str | list[str] | np.ndarray | list[np.ndarray], + prompt, + **kwargs, + ) -> dict: + if not isinstance(input, list): + input_list = [input] + else: + input_list = input + + for input in input_list: + if isinstance(input, str): + image_array = next(self.img_reader(input))[0]["img"] + else: + image_array = input + + assert len(image_array.shape) == 3 + image_string = cv2.imencode(".jpg", image_array)[1].tostring() + image_base64 = base64.b64encode(image_string).decode("utf-8") + + mllm_chat_bot_result = self.mllm_chat_bot.generate_chat_results( + prompt=prompt, image=image_base64 + ) + mllm_result = {"mllm_info": mllm_chat_bot_result} + logging.info(mllm_result) + yield mllm_result + def save_visual_info_list( self, visual_info: VisualInfoResult, save_path: str ) -> None: