diff --git a/src/cardinal/app/kbqa.py b/src/cardinal/app/kbqa.py index c91a5f1..b0b8299 100644 --- a/src/cardinal/app/kbqa.py +++ b/src/cardinal/app/kbqa.py @@ -1,6 +1,7 @@ import os from typing import TYPE_CHECKING, Generator, List +from ..core.function_calls.functions import execute_function_call, parse_function_availables from ..core.model import ChatOpenAI, EmbedOpenAI from ..core.retriever import BaseRetriever from ..core.schema import Leaf, LeafIndex, Template @@ -37,5 +38,15 @@ def __call__(self, messages: List["BaseMessage"]) -> Generator[str, None, None]: else: question = self._plain_template.apply(question=question) + # parse function availables + question, tools = parse_function_availables(question) + + # execute function calls + if tools: + function_call = self._chat_model.function_call(messages=messages, tools=tools) + # calling the function_calls, and get the response + response = execute_function_call(function_call) + yield response + messages[-1].content = question yield from self._chat_model.stream_chat(messages=messages) diff --git a/src/cardinal/core/function_calls/code_interpreter.py b/src/cardinal/core/function_calls/code_interpreter.py new file mode 100644 index 0000000..c68cc89 --- /dev/null +++ b/src/cardinal/core/function_calls/code_interpreter.py @@ -0,0 +1,91 @@ +import logging +import os +import re + +import json5 +import matplotlib + +from ..function_calls.code_kernel import CodeKernel + + +START_CODE = """ +import signal +def _m6_code_interpreter_timeout_handler(signum, frame): + raise TimeoutError("M6_CODE_INTERPRETER_TIMEOUT") +signal.signal(signal.SIGALRM, _m6_code_interpreter_timeout_handler) + +def input(*args, **kwargs): + raise NotImplementedError('Python input() function is disabled.') + +import math +import re +import json + +import seaborn as sns +sns.set_theme() + +import matplotlib +import matplotlib.pyplot as plt +plt.rcParams['font.sans-serif'] = ['SimHei'] +plt.rcParams['axes.unicode_minus'] = False + +import numpy as np +import pandas as pd + +from sympy import Eq, symbols, solve +""" + + +class CodeInterpreter: + def __init__(self): + self.code_kernel = CodeKernel() + + # ensure Chinese font support, before launching app + def fix_matplotlib_cjk_font_issue(self): + local_ttf = os.path.join( + os.path.abspath(os.path.join(matplotlib.matplotlib_fname(), os.path.pardir)), "fonts", "ttf", "simhei.ttf" + ) + if not os.path.exists(local_ttf): + logging.warning( + f"Missing font file `{local_ttf}` for matplotlib. It may cause some error when using matplotlib." + ) + + # extract code from text + def extract_code(self, text): + # Match triple backtick blocks first + triple_match = re.search(r"```[^\n]*\n(.+?)```", text, re.DOTALL) + # Match single backtick blocks second + single_match = re.search(r"`([^`]*)`", text, re.DOTALL) + if triple_match: + text = triple_match.group(1) + elif single_match: + text = single_match.group(1) + else: + try: + text = json5.loads(text)["code"] + except Exception: + pass + # If no code blocks found, return original text + return text + + def code_interpreter(self, code: str, timeout=30, start_code=True): + # convert action_input_list to code + code = self.extract_code(code) + "\n" + # add timeout + if timeout: + code = f"signal.alarm({timeout})\n{code}" + # add start code + if start_code: + code = START_CODE + "\n" + code + res_type, res = self.code_kernel.execute(code) + return res if res_type == "image" else None + + +if __name__ == "__main__": + ci = CodeInterpreter() + image_b64 = ci.code_interpreter( + [ + "```py\nimport matplotlib.pyplot as plt\n\n# 数据\ncountries = ['英军', '美军', '德军']\ntroops = [300000, 700000, 500000]\n\n# 绘制柱状图\nplt.bar(countries, troops)\nplt.xlabel('国家')\nplt.ylabel('军力(万)')\nplt.title('各国军力')\nplt.show()\n```\n" + ] + ) + print("res:\n", image_b64) diff --git a/src/cardinal/core/function_calls/code_kernel.py b/src/cardinal/core/function_calls/code_kernel.py new file mode 100644 index 0000000..57fe28c --- /dev/null +++ b/src/cardinal/core/function_calls/code_kernel.py @@ -0,0 +1,163 @@ +import os +import queue +import re +from pprint import pprint +from subprocess import PIPE + +import jupyter_client +from PIL import Image + + +# make sure run "ipython kernel install --name code_kernel --user" before using code kernel +IPYKERNEL = os.environ.get("IPYKERNEL", "code_kernel") + + +class CodeKernel(object): + def __init__( + self, + kernel_name="kernel", + kernel_id=None, + kernel_config_path="", + python_path=None, + ipython_path=None, + init_file_path="./startup.py", + verbose=1, + ): + self.kernel_name = kernel_name + self.kernel_id = kernel_id + self.kernel_config_path = kernel_config_path + self.python_path = python_path + self.ipython_path = ipython_path + self.init_file_path = init_file_path + self.verbose = verbose + + if python_path is None and ipython_path is None: + env = None + else: + env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path} + + # Initialize the backend kernel + self.kernel_manager = jupyter_client.KernelManager( + kernel_name=IPYKERNEL, connection_file=self.kernel_config_path, exec_files=[self.init_file_path], env=env + ) + if self.kernel_config_path: + self.kernel_manager.load_connection_file() + self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE) + print("Backend kernel started with the configuration: {}".format(self.kernel_config_path)) + else: + self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE) + print("Backend kernel started with the configuration: {}".format(self.kernel_manager.connection_file)) + + if verbose: + pprint(self.kernel_manager.get_connection_info()) + + # Initialize the code kernel + self.kernel = self.kernel_manager.blocking_client() + # self.kernel.load_connection_file() + self.kernel.start_channels() + print("Code kernel started.") + + def _execute(self, code): + self.kernel.execute(code) + try: + shell_msg = self.kernel.get_shell_msg(timeout=30) + io_msg_content = self.kernel.get_iopub_msg(timeout=30)["content"] + while True: + msg_out = io_msg_content + # Poll the message + try: + io_msg_content = self.kernel.get_iopub_msg(timeout=30)["content"] + if "execution_state" in io_msg_content and io_msg_content["execution_state"] == "idle": + break + except queue.Empty: + break + + return shell_msg, msg_out + except Exception as e: + print(e) + return None, None + + def clean_ansi_codes(self, input_string): + ansi_escape = re.compile(r"(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]") + return ansi_escape.sub("", input_string) + + def execute(self, code) -> tuple[str, str | Image.Image]: + res = "" + res_type = None + code = code.replace("<|observation|>", "") + code = code.replace("<|assistant|>interpreter", "") + code = code.replace("<|assistant|>", "") + code = code.replace("<|user|>", "") + code = code.replace("<|system|>", "") + msg, output = self._execute(code) + if not msg and not output: + return res_type, "no output" + + if msg["metadata"]["status"] == "timeout": + return res_type, "Timed out" + elif msg["metadata"]["status"] == "error": + return res_type, self.clean_ansi_codes("\n".join(self.get_error_msg(msg, verbose=True))) + + if "text" in output: + res_type = "text" + res = output["text"] + elif "data" in output: + for key in output["data"]: + if "text/plain" in key: + res_type = "text" + res = output["data"][key] + elif "image/png" in key: + res_type = "image" + res = output["data"][key] + break + + if res_type == "text" or res_type == "traceback": + res = res + + return res_type, res + + def get_error_msg(self, msg, verbose=False) -> str | None: + if msg["content"]["status"] == "error": + try: + error_msg = msg["content"]["traceback"] + except Exception: + try: + error_msg = msg["content"]["traceback"][-1].strip() + except Exception: + error_msg = "Traceback Error" + if verbose: + print("Error: ", error_msg) + return error_msg + return None + + def shutdown(self): + # Shutdown the backend kernel + self.kernel_manager.shutdown_kernel() + print("Backend kernel shutdown.") + # Shutdown the code kernel + self.kernel.shutdown() + print("Code kernel shutdown.") + + def restart(self): + # Restart the backend kernel + self.kernel_manager.restart_kernel() + # print("Backend kernel restarted.") + + def interrupt(self): + # Interrupt the backend kernel + self.kernel_manager.interrupt_kernel() + # print("Backend kernel interrupted.") + + def is_alive(self): + return self.kernel.is_alive() + + +if __name__ == "__main__": + ck = CodeKernel() + # shell_msg, msg_out = ck.execute("print('hello world')") + res_type, res = ck.execute( + """import matplotlib.pyplot as plt\n\n# 数据\ncountries = ['英军', '美军', '德军']\ntroops = [300000, 700000, 500000]\n\n# 绘制柱状图\nplt.bar(countries, troops)\nplt.xlabel('国家')\nplt.ylabel('军力(万)')\nplt.title('各国军力')\nplt.show()""", + CodeKernel(), + ) + print("res_type:\n", res_type, "\n") + print("res:\n", res, "\n") diff --git a/src/cardinal/core/function_calls/functions.py b/src/cardinal/core/function_calls/functions.py new file mode 100644 index 0000000..de08520 --- /dev/null +++ b/src/cardinal/core/function_calls/functions.py @@ -0,0 +1,81 @@ +from typing import List + +from ..function_calls.code_interpreter import CodeInterpreter +from ..schema import FunctionAvailable, FunctionCall + + +# function description + +CI = FunctionAvailable( + function={ + "name": "code_interpreter", + "description": "interpreter for code, which use for executing python code", + "parameters": { + "type": "string", + "properties": { + "code": { + "type": "string", + "description": "the code to be executed, which Enclose the code within triple backticks (```) at the beginning and end of the code.", + } + }, + "required": ["code"], + }, + } +) + +GET_RIVER_ENVIRONMENT = FunctionAvailable( + function={ + "name": "get_river_environment", + "description": "get river environment at a specific location", + "parameters": { + "type": "string", + "properties": { + "river": { + "type": "string", + "description": "the river name to be queried", + }, + "location": {"type": "string", "description": "the location of the river environment to be queried"}, + }, + "required": ["river", "location"], + }, + } +) + +GET_ENVIRONMENT_AIR = FunctionAvailable( + function={ + "name": "get_environment_air", + "description": "get environment air at a specific location", + "parameters": { + "type": "string", + "properties": { + "location": {"type": "string", "description": "the location of the environment air to be queried"} + }, + "required": ["location"], + }, + } +) + + +def parse_function_availables(question: str) -> List[FunctionAvailable]: + tools = [] + if "/code" in question: + question = question.replace("/code", "") + tools.append(CI) + if "/get_river_environment" in question: + question = question.replace("/get_river_environment", "") + tools.append(GET_RIVER_ENVIRONMENT) + if "/get_environment_air" in question: + question = question.replace("/get_environment_air", "") + tools.append(GET_ENVIRONMENT_AIR) + return question, tools + + +def execute_function_call(function_call: FunctionCall): + response = None + if function_call.name == "code_interpreter": + response = CodeInterpreter().code_interpreter(**function_call.arguments) + if function_call.name == "get_river_environment": + response = str(function_call.arguments) + if function_call.name == "get_environment_air": + response = str(function_call.arguments) + return response diff --git a/src/cardinal/service/app.py b/src/cardinal/service/app.py index d513658..ecfa490 100644 --- a/src/cardinal/service/app.py +++ b/src/cardinal/service/app.py @@ -3,7 +3,7 @@ import uvicorn from fastapi import FastAPI, status from fastapi.middleware.cors import CORSMiddleware -from sse_starlette import EventSourceResponse +from sse_starlette import EventSourceResponse, ServerSentEvent from ..app import KBQA from .protocol import ChatCompletionRequest, ChatCompletionResponse @@ -26,6 +26,10 @@ def predict(): yield "[DONE]" - return EventSourceResponse(predict(), media_type="text/event-stream") + return EventSourceResponse( + predict(), + media_type="text/event-stream", + ping_message_factory=lambda: ServerSentEvent(**{"comment": "You can't see\\r\\nthis ping"}), + ) - uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) + uvicorn.run(app, host="0.0.0.0", port=8020, workers=1)