Skip to content

Commit

Permalink
更改ofa加载逻辑
Browse files Browse the repository at this point in the history
  • Loading branch information
NaivG authored Sep 8, 2024
1 parent 1df02dc commit baab12b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from Muice import Muice
from ws import QQBot
from ofa_image_process import ImageCaptioningPipeline

logging.basicConfig(format='[%(levelname)s] %(message)s', level=logging.INFO)

Expand All @@ -24,6 +23,7 @@
# OFA图像模型
enable_ofa_image = configs["enable_ofa_image"]
if enable_ofa_image:
from ofa_image_process import ImageCaptioningPipeline
ofa_image_model_name_or_path = configs["ofa_image_model_name_or_path"]
ImageCaptioningPipeline.load_model(ofa_image_model_name_or_path)

Expand Down
10 changes: 6 additions & 4 deletions ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from fastapi import FastAPI, WebSocket
from starlette.websockets import WebSocketDisconnect

from ofa_image_process import ImageCaptioningPipeline
from fish_speech_api import fish_speech_api
from Tools import divide_sentences
from Tools import process_at_message
Expand Down Expand Up @@ -85,6 +84,9 @@ def __init__(self, muice_app):

# 定义公共变量
self.is_at_message = False
if self.enable_ofa_image:
from ofa_image_process import ImageCaptioningPipeline
self.image_captioning_pipeline = ImageCaptioningPipeline()

@self.app.websocket("/ws/api")
async def websocket_endpoint(websocket: WebSocket):
Expand Down Expand Up @@ -180,7 +182,7 @@ async def processing_reply(self, data):
if data['message_type'] == 'private':
logging.info(f"收到QQ{sender_user_id}的消息:{message}")
if sender_user_id in self.trust_qq_list:
if is_image: message = await ImageCaptioningPipeline().generate_caption(image_url)
if is_image: message = await self.image_captioning_pipeline.generate_caption(image_url)
reply_message_list = await self.produce_reply(message, sender_user_id)
if reply_message_list:
logging.debug(f"回复list{reply_message_list}")
Expand Down Expand Up @@ -225,7 +227,7 @@ async def processing_reply(self, data):
if self.group_reply_only_to_trusted:
if sender_user_id in self.trust_qq_list:
if not is_reply_message(self.at_reply,self.reply_rate,self.is_at_message): logging.info(f"未达到消息回复率{self.reply_rate}%,不回复") ; return None
if is_image: message = await ImageCaptioningPipeline().generate_caption(image_url)
if is_image: message = await self.image_captioning_pipeline.generate_caption(image_url)
reply_message_list = await self.produce_group_reply(message, sender_user_id, group_id)
logging.debug(f"回复list{reply_message_list}")
if reply_message_list is None:
Expand All @@ -235,7 +237,7 @@ async def processing_reply(self, data):
return None
else:
if not is_reply_message(self.at_reply,self.reply_rate,self.is_at_message): logging.info(f"未达到消息回复率{self.reply_rate}%,不回复") ; return None
if is_image: message = await ImageCaptioningPipeline().generate_caption(image_url)
if is_image: message = await self.image_captioning_pipeline.generate_caption(image_url)
reply_message_list = await self.produce_group_reply(message, sender_user_id, group_id)
logging.debug(f"回复list{reply_message_list}")
if reply_message_list is None:
Expand Down

0 comments on commit baab12b

Please sign in to comment.