diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index de23a3cd4..e5641fe7c 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -92,7 +92,7 @@ def __init__(self) -> None: # baichuan proxy self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY") self.bc_model_name = os.getenv("BAICHUN_MODEL_NAME", "Baichuan2-Turbo-192k") - if self.bc_proxy_api_key and self.bc_proxy_api_secret: + if self.bc_proxy_api_key and self.bc_model_name: os.environ["bc_proxyllm_proxy_api_key"] = self.bc_proxy_api_key os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_name diff --git a/dbgpt/model/proxy/llms/baichuan.py b/dbgpt/model/proxy/llms/baichuan.py index 568924f4b..ed641c72f 100644 --- a/dbgpt/model/proxy/llms/baichuan.py +++ b/dbgpt/model/proxy/llms/baichuan.py @@ -8,7 +8,9 @@ BAICHUAN_DEFAULT_MODEL = "Baichuan2-Turbo-192k" -def baichuan_generate_stream(model: ProxyModel, tokenizer=None, params=None, device=None, context_len=4096): +def baichuan_generate_stream( + model: ProxyModel, tokenizer=None, params=None, device=None, context_len=4096 +): url = "https://api.baichuan-ai.com/v1/chat/completions" model_params = model.get_params() @@ -63,22 +65,27 @@ def baichuan_generate_stream(model: ProxyModel, tokenizer=None, params=None, dev text += content yield text + def main(): model_params = ProxyModelParameters( model_name="not-used", model_path="not-used", proxy_server_url="not-used", proxy_api_key="YOUR_BAICHUAN_API_KEY", - proxyllm_backend="Baichuan2-Turbo-192k" + proxyllm_backend="Baichuan2-Turbo-192k", ) final_text = "" for part in baichuan_generate_stream( model=ProxyModel(model_params=model_params), - params={"messages": [ModelMessage( - role=ModelMessageRoleType.HUMAN, - content="背诵《论语》第一章")]}): + params={ + "messages": [ + ModelMessage(role=ModelMessageRoleType.HUMAN, content="背诵《论语》第一章") + ] + }, + ): final_text = part print(final_text) + if __name__ == "__main__": main()