diff --git a/.mypy.ini b/.mypy.ini
index 91f6d0c19..7ab4d670e 100644
--- a/.mypy.ini
+++ b/.mypy.ini
@@ -100,3 +100,6 @@ ignore_missing_imports = True
 
 [mypy-rich.*]
 ignore_missing_imports = True
+
+[mypy-ollama.*]
+ignore_missing_imports = True
diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py
index a9f75bb98..65809086f 100644
--- a/dbgpt/configs/model_config.py
+++ b/dbgpt/configs/model_config.py
@@ -69,6 +69,7 @@ def get_device() -> str:
     "yi_proxyllm": "yi_proxyllm",
     # https://platform.moonshot.cn/docs/
     "moonshot_proxyllm": "moonshot_proxyllm",
+    "ollama_proxyllm": "ollama_proxyllm",
     "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
     "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
     "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
@@ -198,6 +199,7 @@ def get_device() -> str:
     "proxy_azure": "proxy_azure",
     # Common HTTP embedding model
     "proxy_http_openapi": "proxy_http_openapi",
+    "proxy_ollama": "proxy_ollama",
 }
diff --git a/dbgpt/model/adapter/embeddings_loader.py b/dbgpt/model/adapter/embeddings_loader.py
index 07885c5d2..6b97c9c3d 100644
--- a/dbgpt/model/adapter/embeddings_loader.py
+++ b/dbgpt/model/adapter/embeddings_loader.py
@@ -50,6 +50,16 @@ def load(self, model_name: str, param: BaseEmbeddingModelParameters) -> Embeddin
             if proxy_param.proxy_backend:
                 openapi_param["model_name"] = proxy_param.proxy_backend
             return OpenAPIEmbeddings(**openapi_param)
+        elif model_name in ["proxy_ollama"]:
+            from dbgpt.rag.embedding import OllamaEmbeddings
+
+            proxy_param = cast(ProxyEmbeddingParameters, param)
+            ollama_param = {}
+            if proxy_param.proxy_server_url:
+                ollama_param["api_url"] = proxy_param.proxy_server_url
+            if proxy_param.proxy_backend:
+                ollama_param["model_name"] = proxy_param.proxy_backend
+            return OllamaEmbeddings(**ollama_param)
         else:
             from dbgpt.rag.embedding import HuggingFaceEmbeddings
diff --git a/dbgpt/model/adapter/proxy_adapter.py b/dbgpt/model/adapter/proxy_adapter.py
index 42c4480f6..11d658aeb 100644
--- a/dbgpt/model/adapter/proxy_adapter.py
+++ b/dbgpt/model/adapter/proxy_adapter.py
@@ -114,6 +114,23 @@ def get_generate_stream_function(self, model, model_path: str):
         return tongyi_generate_stream
 
 
+class OllamaLLMModelAdapter(ProxyLLMModelAdapter):
+    def do_match(self, lower_model_name_or_path: Optional[str] = None):
+        return lower_model_name_or_path == "ollama_proxyllm"
+
+    def get_llm_client_class(
+        self, params: ProxyModelParameters
+    ) -> Type[ProxyLLMClient]:
+        from dbgpt.model.proxy.llms.ollama import OllamaLLMClient
+
+        return OllamaLLMClient
+
+    def get_generate_stream_function(self, model, model_path: str):
+        from dbgpt.model.proxy.llms.ollama import ollama_generate_stream
+
+        return ollama_generate_stream
+
+
 class ZhipuProxyLLMModelAdapter(ProxyLLMModelAdapter):
     support_system_message = False
@@ -279,6 +296,7 @@ def get_async_generate_stream_function(self, model, model_path: str):
 register_model_adapter(OpenAIProxyLLMModelAdapter)
 register_model_adapter(TongyiProxyLLMModelAdapter)
+register_model_adapter(OllamaLLMModelAdapter)
 register_model_adapter(ZhipuProxyLLMModelAdapter)
 register_model_adapter(WenxinProxyLLMModelAdapter)
 register_model_adapter(GeminiProxyLLMModelAdapter)
"proxy_openai,proxy_azure,proxy_http_openapi", + ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama", } EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {} diff --git a/dbgpt/model/proxy/__init__.py b/dbgpt/model/proxy/__init__.py index 831456fbd..1c953ba23 100644 --- a/dbgpt/model/proxy/__init__.py +++ b/dbgpt/model/proxy/__init__.py @@ -11,6 +11,7 @@ def __lazy_import(name): "ZhipuLLMClient": "dbgpt.model.proxy.llms.zhipu", "YiLLMClient": "dbgpt.model.proxy.llms.yi", "MoonshotLLMClient": "dbgpt.model.proxy.llms.moonshot", + "OllamaLLMClient": "dbgpt.model.proxy.llms.ollama", } if name in module_path: @@ -33,4 +34,5 @@ def __getattr__(name): "SparkLLMClient", "YiLLMClient", "MoonshotLLMClient", + "OllamaLLMClient", ] diff --git a/dbgpt/model/proxy/llms/ollama.py b/dbgpt/model/proxy/llms/ollama.py new file mode 100644 index 000000000..d48a6b1b2 --- /dev/null +++ b/dbgpt/model/proxy/llms/ollama.py @@ -0,0 +1,101 @@ +import logging +from concurrent.futures import Executor +from typing import Iterator, Optional + +from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext +from dbgpt.model.parameter import ProxyModelParameters +from dbgpt.model.proxy.base import ProxyLLMClient +from dbgpt.model.proxy.llms.proxy_model import ProxyModel + +logger = logging.getLogger(__name__) + + +def ollama_generate_stream( + model: ProxyModel, tokenizer, params, device, context_len=4096 +): + client: OllamaLLMClient = model.proxy_llm_client + context = ModelRequestContext(stream=True, user_name=params.get("user_name")) + request = ModelRequest.build_request( + client.default_model, + messages=params["messages"], + temperature=params.get("temperature"), + context=context, + max_new_tokens=params.get("max_new_tokens"), + ) + for r in client.sync_generate_stream(request): + yield r + + +class OllamaLLMClient(ProxyLLMClient): + def __init__( + self, + model: Optional[str] = None, + host: Optional[str] = None, + model_alias: Optional[str] = "ollama_proxyllm", + context_length: Optional[int] = 4096, + executor: Optional[Executor] = None, + ): + if not model: + model = "llama2" + if not host: + host = "http://localhost:11434" + self._model = model + self._host = host + + super().__init__( + model_names=[model, model_alias], + context_length=context_length, + executor=executor, + ) + + @classmethod + def new_client( + cls, + model_params: ProxyModelParameters, + default_executor: Optional[Executor] = None, + ) -> "OllamaLLMClient": + return cls( + model=model_params.proxyllm_backend, + host=model_params.proxy_server_url, + model_alias=model_params.model_name, + context_length=model_params.max_context_size, + executor=default_executor, + ) + + @property + def default_model(self) -> str: + return self._model + + def sync_generate_stream( + self, + request: ModelRequest, + message_converter: Optional[MessageConverter] = None, + ) -> Iterator[ModelOutput]: + try: + import ollama + from ollama import Client + except ImportError as e: + raise ValueError( + "Could not import python package: ollama " + "Please install ollama by command `pip install ollama" + ) from e + request = self.local_covert_message(request, message_converter) + messages = request.to_common_messages() + + model = request.model or self._model + client = Client(self._host) + try: + stream = client.chat( + model=model, + messages=messages, + stream=True, + ) + content = "" + for chunk in stream: + content = content + chunk["message"]["content"] + yield ModelOutput(text=content, error_code=0) + except 
diff --git a/dbgpt/rag/embedding/__init__.py b/dbgpt/rag/embedding/__init__.py
index 165160b01..c60f01d6e 100644
--- a/dbgpt/rag/embedding/__init__.py
+++ b/dbgpt/rag/embedding/__init__.py
@@ -12,6 +12,7 @@
     HuggingFaceInferenceAPIEmbeddings,
     HuggingFaceInstructEmbeddings,
     JinaEmbeddings,
+    OllamaEmbeddings,
     OpenAPIEmbeddings,
 )
@@ -23,6 +24,7 @@
     "HuggingFaceInstructEmbeddings",
     "JinaEmbeddings",
     "OpenAPIEmbeddings",
+    "OllamaEmbeddings",
     "DefaultEmbeddingFactory",
     "EmbeddingFactory",
     "WrappedEmbeddingFactory",
diff --git a/dbgpt/rag/embedding/embeddings.py b/dbgpt/rag/embedding/embeddings.py
index 59d08b9fd..b51456313 100644
--- a/dbgpt/rag/embedding/embeddings.py
+++ b/dbgpt/rag/embedding/embeddings.py
@@ -736,3 +736,94 @@ async def aembed_query(self, text: str) -> List[float]:
         """Asynchronous Embed query text."""
         embeddings = await self.aembed_documents([text])
         return embeddings[0]
+
+
+class OllamaEmbeddings(BaseModel, Embeddings):
+    """Ollama proxy embeddings.
+
+    This class is used to get embeddings for a list of texts using the Ollama API.
+    It requires a proxy server url `api_url` and a model name `model_name`.
+    The default model name is "llama2".
+    """
+
+    api_url: str = Field(
+        default="http://localhost:11434",
+        description="The URL of the embeddings API.",
+    )
+    model_name: str = Field(
+        default="llama2", description="The name of the model to use."
+    )
+
+    def __init__(self, **kwargs):
+        """Initialize the OllamaEmbeddings."""
+        super().__init__(**kwargs)
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Get the embeddings for a list of texts.
+
+        Args:
+            texts (Documents): A list of texts to get embeddings for.
+
+        Returns:
+            Embedded texts as List[List[float]], where each inner List[float]
+                corresponds to a single input text.
+        """
+        return [self.embed_query(text) for text in texts]
+
+    def embed_query(self, text: str) -> List[float]:
+        """Compute query embeddings using an Ollama embedding model.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        try:
+            import ollama
+            from ollama import Client
+        except ImportError as e:
+            raise ValueError(
+                "Could not import python package: ollama "
+                "Please install ollama by command `pip install ollama`"
+            ) from e
+        try:
+            return (
+                Client(self.api_url).embeddings(model=self.model_name, prompt=text)
+            )["embedding"]
+        except ollama.ResponseError as e:
+            raise ValueError(f"**Ollama Response Error, Please Check Error Info.**: {e}")
+
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Asynchronous Embed search docs.
+
+        Args:
+            texts: A list of texts to get embeddings for.
+
+        Returns:
+            List[List[float]]: Embedded texts as List[List[float]], where each inner
+                List[float] corresponds to a single input text.
+        """
+        embeddings = []
+        for text in texts:
+            embedding = await self.aembed_query(text)
+            embeddings.append(embedding)
+        return embeddings
+
+    async def aembed_query(self, text: str) -> List[float]:
+        """Asynchronous Embed query text."""
+        try:
+            import ollama
+            from ollama import AsyncClient
+        except ImportError:
+            raise ValueError(
+                "The ollama python package is not installed. "
+                "Please install it with `pip install ollama`"
+            )
+        try:
+            embedding = await AsyncClient(host=self.api_url).embeddings(
+                model=self.model_name, prompt=text
+            )
+            return embedding["embedding"]
+        except ollama.ResponseError as e:
+            raise ValueError(f"**Ollama Response Error, Please Check Error Info.**: {e}")
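The embeddings class can be exercised on its own; a small sketch, assuming `pip install ollama`, a local server on the default port, and a pulled `llama2` model:

```python
from dbgpt.rag.embedding import OllamaEmbeddings

embeddings = OllamaEmbeddings(
    api_url="http://localhost:11434",
    model_name="llama2",
)
# embed_documents calls embed_query once per text, so both methods return
# vectors from the same Ollama embeddings endpoint.
doc_vectors = embeddings.embed_documents(["DB-GPT supports Ollama.", "Hello world."])
query_vector = embeddings.embed_query("Which backends does DB-GPT support?")
print(len(doc_vectors), len(query_vector))
```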
" + "Please install it with `pip install ollama`" + ) + try: + embedding = await AsyncClient(host=self.api_url).embeddings( + model=self.model_name, prompt=text + ) + return embedding["embedding"] + except ollama.ResponseError as e: + raise ValueError(f"**Ollama Response Error, Please CheckErrorInfo.**: {e}") diff --git a/setup.py b/setup.py index f851eb821..1dfe77bf2 100644 --- a/setup.py +++ b/setup.py @@ -658,6 +658,7 @@ def default_requires(): "dashscope", "chardet", "sentencepiece", + "ollama", ] setup_spec.extras["default"] += setup_spec.extras["framework"] setup_spec.extras["default"] += setup_spec.extras["rag"]