[Do Not Merge] Support OpenAI API chatbot #2675

Open · wants to merge 1 commit into base: develop
7 changes: 7 additions & 0 deletions paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml
@@ -10,6 +10,13 @@ SubModules:
ak: "api_key" # Set this to a real API key
sk: "secret_key" # Set this to a real secret key

MLLM_Chat:
module_name: chat_bot
model_name: PP-DocBee
base_url: "http://127.0.0.1/v1/chat/completions"
api_type: openai
api_key: "api_key"

LLM_Retriever:
module_name: retriever
model_name: ernie-3.5
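For reference, a minimal sketch of what the new MLLM_Chat section yields once the pipeline YAML is parsed (a plain PyYAML load is assumed here; the key nesting follows the SubModules context shown in the hunk, and the actual loading path inside PaddleX may differ):

import yaml

# Load the pipeline config and pull out the new chat-bot sub-module section.
with open("paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml") as f:
    pipeline_cfg = yaml.safe_load(f)

mllm_cfg = pipeline_cfg["SubModules"]["MLLM_Chat"]
# mllm_cfg == {
#     "module_name": "chat_bot",
#     "model_name": "PP-DocBee",
#     "base_url": "http://127.0.0.1/v1/chat/completions",
#     "api_type": "openai",
#     "api_key": "api_key",
# }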
8 changes: 4 additions & 4 deletions paddlex/inference/pipelines_new/__init__.py
@@ -151,8 +151,8 @@ def create_chat_bot(config: Dict, *args, **kwargs) -> BaseChat:
Returns:
        BaseChat: An instance of the chat bot class corresponding to the 'api_type' in the config.
"""
model_name = config["model_name"]
chat_bot = BaseChat.get(model_name)(config)
api_type = config["api_type"]
chat_bot = BaseChat.get(api_type)(config)
return chat_bot
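With this change the factory keys the registry lookup off api_type instead of model_name. A hedged usage sketch (config values taken from this diff; the import path is assumed from the file's location):

from paddlex.inference.pipelines_new import create_chat_bot  # import path assumed

# "openai" should dispatch to the new OpenAIBotChat.
mllm_bot = create_chat_bot({
    "api_type": "openai",
    "model_name": "PP-DocBee",
    "base_url": "http://127.0.0.1/v1/chat/completions",
    "api_key": "api_key",
})

# "qianfan" should dispatch to ErnieBotChat, as before.
llm_bot = create_chat_bot({
    "api_type": "qianfan",
    "model_name": "ernie-3.5",
    "ak": "api_key",
    "sk": "secret_key",
})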


@@ -172,8 +172,8 @@ def create_retriever(
Returns:
        BaseRetriever: An instance of a retriever class corresponding to the 'api_type' in the config.
"""
model_name = config["model_name"]
retriever = BaseRetriever.get(model_name)(config)
api_type = config["api_type"]
retriever = BaseRetriever.get(api_type)(config)
return retriever
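The retriever factory follows the same api_type dispatch; a short sketch using the keys documented for ErnieBotRetriever further down in this diff:

# "qianfan" should select ErnieBotRetriever.
retriever = create_retriever({
    "api_type": "qianfan",
    "model_name": "ernie-3.5",
    "ak": "api_key",
    "sk": "secret_key",
})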


1 change: 1 addition & 0 deletions __init__.py (chat bot package)
@@ -13,3 +13,4 @@
# limitations under the License.

from .ernie_bot_chat import ErnieBotChat
from .openai_bot_chat import OpenAIBotChat
ernie_bot_chat.py
@@ -22,6 +22,11 @@ class ErnieBotChat(BaseChat):
"""Ernie Bot Chat"""

entities = [
"aistudio",
"qianfan",
]

MODELS = [
"ernie-4.0",
"ernie-3.5",
"ernie-3.5-8k",
@@ -51,8 +56,8 @@ def __init__(self, config: Dict) -> None:
sk = config.get("sk", None)
access_token = config.get("access_token", None)

if model_name not in self.entities:
raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.")
if model_name not in self.MODELS:
raise ValueError(f"model_name must be in {self.MODELS} of ErnieBotChat.")

if api_type not in ["aistudio", "qianfan"]:
raise ValueError("api_type must be one of ['aistudio', 'qianfan']")
127 changes: 127 additions & 0 deletions openai_bot_chat.py (new file)
@@ -0,0 +1,127 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import Dict, Optional

from .....utils import logging
from .base import BaseChat


class OpenAIBotChat(BaseChat):
"""OpenAI Bot Chat"""

entities = [
"openai",
]

def __init__(self, config: Dict) -> None:
"""Initializes the OpenAIBotChat with given configuration.

Args:
config (Dict): Configuration dictionary containing model_name, api_type, base_url, api_key.

Raises:
            ValueError: If api_type is not 'openai', or if base_url or api_key
                is None when api_type is 'openai'.
"""
super().__init__()
model_name = config.get("model_name", None)
api_type = config.get("api_type", None)
api_key = config.get("api_key", None)
base_url = config.get("base_url", None)

if api_type not in ["openai"]:
raise ValueError("api_type must be one of ['openai']")

if api_type == "openai" and api_key is None:
raise ValueError("api_key cannot be empty when api_type is openai.")

if base_url is None:
raise ValueError("base_url cannot be empty when api_type is openai.")

        try:
            from openai import OpenAI
        except ImportError:
            raise ImportError("openai is not installed, please install it first.")

self.client = OpenAI(base_url=base_url, api_key=api_key)

self.model_name = model_name
self.config = config

def generate_chat_results(
self,
prompt: str,
        image: Optional[str] = None,
temperature: float = 0.001,
max_retries: int = 1,
    ) -> Optional[str]:
"""
Generate chat results using the specified model and configuration.

Args:
prompt (str): The user's input prompt.
            image (str, optional): A base64-encoded image for the MLLM, defaults to None.
            temperature (float, optional): The sampling temperature for the LLM, defaults to 0.001.
            max_retries (int, optional): The maximum number of retries for LLM API calls, defaults to 1.

Returns:
            str: The model's reply text, or None if the request failed.
"""
try:
if image:
chat_completion = self.client.chat.completions.create(
model=self.model_name,
messages=[
{
"role": "system",
                            # XXX: use a generic system prompt by default
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image}"
},
},
],
},
],
stream=False,
)
llm_result = chat_completion.choices[0].message.content
else:
chat_completion = self.client.completions.create(
model=self.model_name,
prompt=prompt,
max_tokens=self.config.get("max_tokens", 1024),
temperature=float(temperature),
stream=False,
)
if isinstance(chat_completion, str):
chat_completion = json.loads(chat_completion)
llm_result = chat_completion["choices"][0]["text"]
else:
llm_result = chat_completion.choices[0].text
return llm_result
except Exception as e:
logging.error(e)
            self.ERROR_MASSAGE = "大模型调用失败"  # i.e., "failed to call the large model"
return None
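A hedged usage sketch for the new class (the endpoint and model name are the placeholder values from the sample config, and the input image file is made up for illustration):

import base64

# OpenAIBotChat as defined above; these values mirror the sample MLLM_Chat
# section and are placeholders, not a real endpoint.
bot = OpenAIBotChat({
    "model_name": "PP-DocBee",
    "api_type": "openai",
    "base_url": "http://127.0.0.1/v1/chat/completions",
    "api_key": "api_key",
})

# Multimodal call: the image is passed as a base64-encoded string and sent
# through the chat/completions message format.
with open("doc_page.jpg", "rb") as f:  # hypothetical input image
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

answer = bot.generate_chat_results(
    prompt="What is the total amount on this invoice?",
    image=image_b64,
)

# Text-only call: falls back to the plain completions endpoint.
summary = bot.generate_chat_results(prompt="Summarize the extracted table.")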
1 change: 1 addition & 0 deletions __init__.py (retriever package)
@@ -13,3 +13,4 @@
# limitations under the License.

from .ernie_bot_retriever import ErnieBotRetriever
from .openai_bot_retriever import OpenAIBotRetriever
ernie_bot_retriever.py
@@ -32,6 +32,11 @@ class ErnieBotRetriever(BaseRetriever):
"""Ernie Bot Retriever"""

entities = [
"aistudio",
"qianfan",
]

MODELS = [
"ernie-4.0",
"ernie-3.5",
"ernie-3.5-8k",
@@ -49,7 +54,7 @@ def __init__(self, config: Dict) -> None:
Args:
config (Dict): A dictionary containing configuration settings.
- model_name (str): The name of the model to use.
- api_type (str): The type of API to use ('aistudio' or 'qianfan').
- api_type (str): The type of API to use ('aistudio', 'qianfan' or 'openai').
- ak (str, optional): The access key for 'qianfan' API.
- sk (str, optional): The secret key for 'qianfan' API.
- access_token (str, optional): The access token for 'aistudio' API.
@@ -68,8 +73,8 @@ def __init__(self, config: Dict) -> None:
sk = config.get("sk", None)
access_token = config.get("access_token", None)

if model_name not in self.entities:
raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.")
if model_name not in self.MODELS:
            raise ValueError(f"model_name must be in {self.MODELS} of ErnieBotRetriever.")

if api_type not in ["aistudio", "qianfan"]:
raise ValueError("api_type must be one of ['aistudio', 'qianfan']")