Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat](function): add agent functions #1

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/cardinal/app/kbqa.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import TYPE_CHECKING, Generator, List

from ..core.function_calls.functions import execute_function_call, parse_function_availables
from ..core.model import ChatOpenAI, EmbedOpenAI
from ..core.retriever import BaseRetriever
from ..core.schema import Leaf, LeafIndex, Template
Expand Down Expand Up @@ -37,5 +38,15 @@ def __call__(self, messages: List["BaseMessage"]) -> Generator[str, None, None]:
else:
question = self._plain_template.apply(question=question)

# parse function availables
question, tools = parse_function_availables(question)

# execute function calls
if tools:
function_call = self._chat_model.function_call(messages=messages, tools=tools)
# calling the function_calls, and get the response
response = execute_function_call(function_call)
yield response

messages[-1].content = question
yield from self._chat_model.stream_chat(messages=messages)
91 changes: 91 additions & 0 deletions src/cardinal/core/function_calls/code_interpreter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import logging
import os
import re

import json5
import matplotlib

from ..function_calls.code_kernel import CodeKernel


START_CODE = """
import signal
def _m6_code_interpreter_timeout_handler(signum, frame):
raise TimeoutError("M6_CODE_INTERPRETER_TIMEOUT")
signal.signal(signal.SIGALRM, _m6_code_interpreter_timeout_handler)

def input(*args, **kwargs):
raise NotImplementedError('Python input() function is disabled.')

import math
import re
import json

import seaborn as sns
sns.set_theme()

import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

import numpy as np
import pandas as pd

from sympy import Eq, symbols, solve
"""


class CodeInterpreter:
def __init__(self):
self.code_kernel = CodeKernel()

# ensure Chinese font support, before launching app
def fix_matplotlib_cjk_font_issue(self):
local_ttf = os.path.join(
os.path.abspath(os.path.join(matplotlib.matplotlib_fname(), os.path.pardir)), "fonts", "ttf", "simhei.ttf"
)
if not os.path.exists(local_ttf):
logging.warning(
f"Missing font file `{local_ttf}` for matplotlib. It may cause some error when using matplotlib."
)

# extract code from text
def extract_code(self, text):
# Match triple backtick blocks first
triple_match = re.search(r"```[^\n]*\n(.+?)```", text, re.DOTALL)
# Match single backtick blocks second
single_match = re.search(r"`([^`]*)`", text, re.DOTALL)
if triple_match:
text = triple_match.group(1)
elif single_match:
text = single_match.group(1)
else:
try:
text = json5.loads(text)["code"]
except Exception:
pass
# If no code blocks found, return original text
return text

def code_interpreter(self, code: str, timeout=30, start_code=True):
# convert action_input_list to code
code = self.extract_code(code) + "\n"
# add timeout
if timeout:
code = f"signal.alarm({timeout})\n{code}"
# add start code
if start_code:
code = START_CODE + "\n" + code
res_type, res = self.code_kernel.execute(code)
return res if res_type == "image" else None


if __name__ == "__main__":
ci = CodeInterpreter()
image_b64 = ci.code_interpreter(
[
"```py\nimport matplotlib.pyplot as plt\n\n# 数据\ncountries = ['英军', '美军', '德军']\ntroops = [300000, 700000, 500000]\n\n# 绘制柱状图\nplt.bar(countries, troops)\nplt.xlabel('国家')\nplt.ylabel('军力(万)')\nplt.title('各国军力')\nplt.show()\n```\n"
]
)
print("res:\n", image_b64)
163 changes: 163 additions & 0 deletions src/cardinal/core/function_calls/code_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import os
import queue
import re
from pprint import pprint
from subprocess import PIPE

import jupyter_client
from PIL import Image


# make sure run "ipython kernel install --name code_kernel --user" before using code kernel
IPYKERNEL = os.environ.get("IPYKERNEL", "code_kernel")


class CodeKernel(object):
def __init__(
self,
kernel_name="kernel",
kernel_id=None,
kernel_config_path="",
python_path=None,
ipython_path=None,
init_file_path="./startup.py",
verbose=1,
):
self.kernel_name = kernel_name
self.kernel_id = kernel_id
self.kernel_config_path = kernel_config_path
self.python_path = python_path
self.ipython_path = ipython_path
self.init_file_path = init_file_path
self.verbose = verbose

if python_path is None and ipython_path is None:
env = None
else:
env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path}

# Initialize the backend kernel
self.kernel_manager = jupyter_client.KernelManager(
kernel_name=IPYKERNEL, connection_file=self.kernel_config_path, exec_files=[self.init_file_path], env=env
)
if self.kernel_config_path:
self.kernel_manager.load_connection_file()
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
print("Backend kernel started with the configuration: {}".format(self.kernel_config_path))
else:
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
print("Backend kernel started with the configuration: {}".format(self.kernel_manager.connection_file))

if verbose:
pprint(self.kernel_manager.get_connection_info())

# Initialize the code kernel
self.kernel = self.kernel_manager.blocking_client()
# self.kernel.load_connection_file()
self.kernel.start_channels()
print("Code kernel started.")

def _execute(self, code):
self.kernel.execute(code)
try:
shell_msg = self.kernel.get_shell_msg(timeout=30)
io_msg_content = self.kernel.get_iopub_msg(timeout=30)["content"]
while True:
msg_out = io_msg_content
# Poll the message
try:
io_msg_content = self.kernel.get_iopub_msg(timeout=30)["content"]
if "execution_state" in io_msg_content and io_msg_content["execution_state"] == "idle":
break
except queue.Empty:
break

return shell_msg, msg_out
except Exception as e:
print(e)
return None, None

def clean_ansi_codes(self, input_string):
ansi_escape = re.compile(r"(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]")
return ansi_escape.sub("", input_string)

def execute(self, code) -> tuple[str, str | Image.Image]:
res = ""
res_type = None
code = code.replace("<|observation|>", "")
code = code.replace("<|assistant|>interpreter", "")
code = code.replace("<|assistant|>", "")
code = code.replace("<|user|>", "")
code = code.replace("<|system|>", "")
msg, output = self._execute(code)
if not msg and not output:
return res_type, "no output"

if msg["metadata"]["status"] == "timeout":
return res_type, "Timed out"
elif msg["metadata"]["status"] == "error":
return res_type, self.clean_ansi_codes("\n".join(self.get_error_msg(msg, verbose=True)))

if "text" in output:
res_type = "text"
res = output["text"]
elif "data" in output:
for key in output["data"]:
if "text/plain" in key:
res_type = "text"
res = output["data"][key]
elif "image/png" in key:
res_type = "image"
res = output["data"][key]
break

if res_type == "text" or res_type == "traceback":
res = res

return res_type, res

def get_error_msg(self, msg, verbose=False) -> str | None:
if msg["content"]["status"] == "error":
try:
error_msg = msg["content"]["traceback"]
except Exception:
try:
error_msg = msg["content"]["traceback"][-1].strip()
except Exception:
error_msg = "Traceback Error"
if verbose:
print("Error: ", error_msg)
return error_msg
return None

def shutdown(self):
# Shutdown the backend kernel
self.kernel_manager.shutdown_kernel()
print("Backend kernel shutdown.")
# Shutdown the code kernel
self.kernel.shutdown()
print("Code kernel shutdown.")

def restart(self):
# Restart the backend kernel
self.kernel_manager.restart_kernel()
# print("Backend kernel restarted.")

def interrupt(self):
# Interrupt the backend kernel
self.kernel_manager.interrupt_kernel()
# print("Backend kernel interrupted.")

def is_alive(self):
return self.kernel.is_alive()


if __name__ == "__main__":
ck = CodeKernel()
# shell_msg, msg_out = ck.execute("print('hello world')")
res_type, res = ck.execute(
"""import matplotlib.pyplot as plt\n\n# 数据\ncountries = ['英军', '美军', '德军']\ntroops = [300000, 700000, 500000]\n\n# 绘制柱状图\nplt.bar(countries, troops)\nplt.xlabel('国家')\nplt.ylabel('军力(万)')\nplt.title('各国军力')\nplt.show()""",
CodeKernel(),
)
print("res_type:\n", res_type, "\n")
print("res:\n", res, "\n")
81 changes: 81 additions & 0 deletions src/cardinal/core/function_calls/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import List

from ..function_calls.code_interpreter import CodeInterpreter
from ..schema import FunctionAvailable, FunctionCall


# function description

CI = FunctionAvailable(
function={
"name": "code_interpreter",
"description": "interpreter for code, which use for executing python code",
"parameters": {
"type": "string",
"properties": {
"code": {
"type": "string",
"description": "the code to be executed, which Enclose the code within triple backticks (```) at the beginning and end of the code.",
}
},
"required": ["code"],
},
}
)

GET_RIVER_ENVIRONMENT = FunctionAvailable(
function={
"name": "get_river_environment",
"description": "get river environment at a specific location",
"parameters": {
"type": "string",
"properties": {
"river": {
"type": "string",
"description": "the river name to be queried",
},
"location": {"type": "string", "description": "the location of the river environment to be queried"},
},
"required": ["river", "location"],
},
}
)

GET_ENVIRONMENT_AIR = FunctionAvailable(
function={
"name": "get_environment_air",
"description": "get environment air at a specific location",
"parameters": {
"type": "string",
"properties": {
"location": {"type": "string", "description": "the location of the environment air to be queried"}
},
"required": ["location"],
},
}
)


def parse_function_availables(question: str) -> List[FunctionAvailable]:
tools = []
if "/code" in question:
question = question.replace("/code", "")
tools.append(CI)
if "/get_river_environment" in question:
question = question.replace("/get_river_environment", "")
tools.append(GET_RIVER_ENVIRONMENT)
if "/get_environment_air" in question:
question = question.replace("/get_environment_air", "")
tools.append(GET_ENVIRONMENT_AIR)
return question, tools


def execute_function_call(function_call: FunctionCall):
response = None
if function_call.name == "code_interpreter":
response = CodeInterpreter().code_interpreter(**function_call.arguments)
if function_call.name == "get_river_environment":
response = str(function_call.arguments)
if function_call.name == "get_environment_air":
response = str(function_call.arguments)
return response
10 changes: 7 additions & 3 deletions src/cardinal/service/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uvicorn
from fastapi import FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse
from sse_starlette import EventSourceResponse, ServerSentEvent

from ..app import KBQA
from .protocol import ChatCompletionRequest, ChatCompletionResponse
Expand All @@ -26,6 +26,10 @@ def predict():

yield "[DONE]"

return EventSourceResponse(predict(), media_type="text/event-stream")
return EventSourceResponse(
predict(),
media_type="text/event-stream",
ping_message_factory=lambda: ServerSentEvent(**{"comment": "You can't see\\r\\nthis ping"}),
)

uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
uvicorn.run(app, host="0.0.0.0", port=8020, workers=1)