Skip to content

Commit

Permalink
🎨 Apply black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
snowykami committed Dec 13, 2024
1 parent 59e0871 commit a9938d3
Show file tree
Hide file tree
Showing 16 changed files with 362 additions and 83 deletions.
11 changes: 11 additions & 0 deletions nonebot_plugin_marshoai/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from .metadata import metadata
from .models import MarshoContext, MarshoTools
from .plugin import _plugins, load_plugins
from .util import *


Expand Down Expand Up @@ -85,6 +86,7 @@ async def at_enable():

@driver.on_startup
async def _preload_tools():
"""启动钩子加载工具"""
tools_dir = store.get_plugin_data_dir() / "tools"
os.makedirs(tools_dir, exist_ok=True)
if config.marshoai_enable_tools:
Expand All @@ -98,6 +100,15 @@ async def _preload_tools():
)


@driver.on_startup
async def _preload_plugins():
"""启动钩子加载插件"""
marshoai_plugin_dirs = config.marshoai_plugin_dirs
marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins")
load_plugins(*marshoai_plugin_dirs)
logger.opt(colors=True).info(f"已加载 <c>{len(_plugins)}</c> 个小棉插件")


@add_usermsg_cmd.handle()
async def add_usermsg(target: MsgTarget, arg: Message = CommandArg()):
if msg := arg.extract_plain_text():
Expand Down
2 changes: 2 additions & 0 deletions nonebot_plugin_marshoai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class ConfigModel(BaseModel):
marshoai_tencent_secretid: str | None = None
marshoai_tencent_secretkey: str | None = None

marshoai_plugin_dirs: list[str] = []


yaml = YAML()

Expand Down
80 changes: 47 additions & 33 deletions nonebot_plugin_marshoai/deal_latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
See the Mulan PSL v2 for more details.
"""

import asyncio
import time
from typing import Literal, Optional, Tuple

Expand All @@ -35,7 +36,7 @@ async def get_to_convert(
return False, "请勿直接调用母类"

@staticmethod
def channel_test() -> int:
async def channel_test() -> int:
return -1


Expand Down Expand Up @@ -90,21 +91,23 @@ async def get_to_convert(
return False, "未知错误"

@staticmethod
def channel_test() -> int:
with httpx.Client(timeout=5, verify=False) as client:
async def channel_test() -> int:
async with httpx.AsyncClient(timeout=5, verify=False) as client:
try:
start_time = time.time_ns()
latex2png = (
client.get(
await client.get(
"http://www.latex2png.com{}"
+ client.post(
"http://www.latex2png.com/api/convert",
json={
"auth": {"user": "guest", "password": "guest"},
"latex": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}\n",
"resolution": 600,
"color": "000000",
},
+ (
await client.post(
"http://www.latex2png.com/api/convert",
json={
"auth": {"user": "guest", "password": "guest"},
"latex": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}\n",
"resolution": 600,
"color": "000000",
},
)
).json()["url"]
),
time.time_ns() - start_time,
Expand Down Expand Up @@ -156,12 +159,12 @@ async def get_to_convert(
return False, "未知错误"

@staticmethod
def channel_test() -> int:
with httpx.Client(timeout=5, verify=False) as client:
async def channel_test() -> int:
async with httpx.AsyncClient(timeout=5, verify=False) as client:
try:
start_time = time.time_ns()
codecogs = (
client.get(
await client.get(
r"https://latex.codecogs.com/png.image?\huge%20\dpi{600}\\int_{a}^{b}x^2\\,dx=\\frac{b^3}{3}-\\frac{a^3}{5}"
),
time.time_ns() - start_time,
Expand Down Expand Up @@ -223,19 +226,21 @@ async def get_to_convert(
return False, "未知错误"

@staticmethod
def channel_test() -> int:
with httpx.Client(timeout=5, verify=False) as client:
async def channel_test() -> int:
async with httpx.AsyncClient(timeout=5, verify=False) as client:
try:
start_time = time.time_ns()
joeraut = (
client.get(
client.post(
"http://www.latex2png.com/api/convert",
json={
"latexInput": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}",
"outputFormat": "PNG",
"outputScale": "1000%",
},
await client.get(
(
await client.post(
"http://www.latex2png.com/api/convert",
json={
"latexInput": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}",
"outputFormat": "PNG",
"outputScale": "1000%",
},
)
).json()["imageUrl"]
),
time.time_ns() - start_time,
Expand All @@ -255,11 +260,14 @@ class ConvertLatex:

channel: ConvertChannel

def __init__(self, channel: Optional[ConvertChannel] = None) -> None:
def __init__(self, channel: Optional[ConvertChannel] = None):
logger.info("LaTeX 转换服务将在 Bot 连接时异步加载")

async def load_channel(self, channel: ConvertChannel | None = None) -> None:
if channel is None:
logger.info("正在选择 LaTeX 转换服务频道,请稍等...")
self.channel = self.auto_choose_channel()
self.channel = await self.auto_choose_channel()
logger.info(f"已选择 {self.channel.__class__.__name__} 服务频道")
else:
self.channel = channel

Expand Down Expand Up @@ -297,9 +305,15 @@ async def generate_png(
)

@staticmethod
def auto_choose_channel() -> ConvertChannel:

return min(
channel_list,
key=lambda channel: channel.channel_test(),
)()
async def auto_choose_channel() -> ConvertChannel:
async def channel_test_wrapper(
channel: type[ConvertChannel],
) -> Tuple[int, type[ConvertChannel]]:
score = await channel.channel_test()
return score, channel

results = await asyncio.gather(
*(channel_test_wrapper(channel) for channel in channel_list)
)
best_channel = min(results, key=lambda x: x[0])[1]
return best_channel()
7 changes: 7 additions & 0 deletions nonebot_plugin_marshoai/plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""该功能目前正在开发中,暂时不可用,受影响的文件夹 `plugin`, `plugins`
"""

from .load import *
from .models import *
from .register import *
from .utils import *
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,26 @@
]


def get_plugin(name: str) -> Plugin | None:
"""获取插件对象
Args:
name: 插件名称
Returns:
Optional[Plugin]: 插件对象
"""
return _plugins.get(name)


def get_plugins() -> dict[str, Plugin]:
"""获取所有插件
Returns:
dict[str, Plugin]: 插件集合
"""
return _plugins


def load_plugin(module_path: str | Path) -> Optional[Plugin]:
"""加载单个插件,可以是本地插件或是通过 `pip` 安装的插件。
该函数产生的副作用在于将插件加载到 `_plugins` 中。
Expand All @@ -45,20 +65,23 @@ def load_plugin(module_path: str | Path) -> Optional[Plugin]:
module=module,
module_name=module_path,
)
_plugins[plugin.name] = plugin

plugin.metadata = getattr(module, "__marsho_meta__", None)

_plugins[plugin.name] = plugin
if plugin.metadata is None:
logger.opt(colors=True).warning(
f"成功加载小棉插件 <y>{plugin.name}</y>, 但是没有定义元数据"
)
else:
logger.opt(colors=True).success(
f'成功加载小棉插件 <c>"{plugin.metadata.name}"</c>'
)

logger.opt(colors=True).success(
f'Succeeded to load liteyuki plugin "{plugin.name}"'
)
return _plugins[module.__name__]
return plugin

except Exception as e:
logger.opt(colors=True).success(
f'Failed to load liteyuki plugin "<r>{module_path}</r>"'
)
logger.opt(colors=True).success(f'加载小棉插件失败 "<r>{module_path}</r>"')
traceback.print_exc()
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,6 @@
from pydantic import BaseModel


class Plugin(BaseModel):
"""
存储插件信息
Attributes:
----------
name: str
包名称 例如marsho_test
module: ModuleType
插件模块对象
module_name: str
点分割模块路径 例如a.b.c
metadata: "PluginMeta" | None
"""

name: str
"""包名称 例如marsho_test"""
module: ModuleType
"""插件模块对象"""
module_name: str
"""点分割模块路径 例如a.b.c"""
metadata: "PluginMetadata" | None = None
"""元"""


class PluginMetadata(BaseModel):
"""
Marsho 插件 对象元数据
Expand Down Expand Up @@ -58,3 +32,38 @@ class PluginMetadata(BaseModel):
author: str = ""
homepage: str = ""
extra: dict[str, Any] = {}


class Plugin(BaseModel):
"""
存储插件信息
Attributes:
----------
name: str
包名称 例如marsho_test
module: ModuleType
插件模块对象
module_name: str
点分割模块路径 例如a.b.c
metadata: "PluginMeta" | None
"""

name: str
"""包名称 例如marsho_test"""
module: ModuleType
"""插件模块对象"""
module_name: str
"""点分割模块路径 例如a.b.c"""
metadata: PluginMetadata | None = None
"""元"""

class Config:
arbitrary_types_allowed = True

def __hash__(self) -> int:
return hash(self.name)

def __eq__(self, other: Any) -> bool:
return self.name == other.name
55 changes: 55 additions & 0 deletions nonebot_plugin_marshoai/plugin/register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""此模块用于获取function call中函数定义信息以及注册函数
"""

import inspect
from typing import Any, Callable, Coroutine, TypeAlias

import nonebot

from .utils import is_coroutine_callable

SYNC_FUNCTION_CALL: TypeAlias = Callable[..., str]
ASYNC_FUNCTION_CALL: TypeAlias = Callable[..., Coroutine[str, Any, str]]
FUNCTION_CALL: TypeAlias = SYNC_FUNCTION_CALL | ASYNC_FUNCTION_CALL

_loaded_functions: dict[str, FUNCTION_CALL] = {}


def async_wrapper(func: SYNC_FUNCTION_CALL) -> ASYNC_FUNCTION_CALL:
"""将同步函数包装为异步函数,但是不会真正异步执行,仅用于统一调用及函数签名
Args:
func: 同步函数
Returns:
ASYNC_FUNCTION_CALL: 异步函数
"""

async def wrapper(*args, **kwargs) -> str:
return func(*args, **kwargs)

return wrapper


def function_call(*funcs: FUNCTION_CALL):
"""返回一个装饰器,装饰一个函数, 使其注册为一个可被AI调用的function call函数
Args:
func: 函数对象,要有完整的 Google Style Docstring
Returns:
str: 函数定义信息
"""
for func in funcs:
if module := inspect.getmodule(func):
module_name = module.__name__ + "."
else:
module_name = ""
name = func.__name__
if not is_coroutine_callable(func):
func = async_wrapper(func) # type: ignore

_loaded_functions[name] = func
nonebot.logger.opt(colors=True).info(
f"加载 function call: <c>{module_name}{name}</c>"
)
Loading

0 comments on commit a9938d3

Please sign in to comment.