feat(llm):
- fix model passing
MorvanZhou committed Jul 19, 2024
1 parent bf670b5 commit b7ad105
Showing 5 changed files with 49 additions and 22 deletions.
47 changes: 34 additions & 13 deletions src/retk/core/ai/llm/knowledge/ops.py
@@ -19,8 +19,10 @@ class ExtendCase:
     _id: ObjectId
     uid: str
     nid: str
-    service: str
-    model: str
+    summary_service: str
+    summary_model: str
+    extend_service: str
+    extend_model: str
     md: str
     stripped_md: str = ""
     summary: str = ""
@@ -56,16 +58,26 @@ async def _batch_send(
 ) -> List[ExtendCase]:
     svr_group = {}
     for case in cases:
-        if case.service not in svr_group:
-            svr_group[case.service] = {}
-        if case.model not in svr_group[case.service]:
-            svr_group[case.service][case.model] = {"case": [], "msgs": []}
+        if is_extend:
+            service = case.extend_service
+            model = case.extend_model
+            content = case.summary
+        else:
+            service = case.summary_service
+            model = case.summary_model
+            content = case.stripped_md
+
+        if service not in svr_group:
+            svr_group[service] = {}
+        if model not in svr_group[service]:
+            svr_group[service][model] = {"case": [], "msgs": []}
         _m: MessagesType = [
             {"role": "system", "content": system_prompt},
-            {"role": "user", "content": case.summary if is_extend else case.stripped_md},
+            {"role": "user", "content": content},
         ]
-        svr_group[case.service][case.model]["case"].append(case)
-        svr_group[case.service][case.model]["msgs"].append(_m)
+        svr_group[service][model]["case"].append(case)
+        svr_group[service][model]["msgs"].append(_m)

     for service, models in svr_group.items():
         for model, model_cases in models.items():
             llm_service = LLM_SERVICES_MAP[service]
@@ -85,10 +97,14 @@
                 oneline_s = _text.replace('\n', '\\n')
                 phase = "extend" if is_extend else "summary"
                 logger.debug(
-                    f"reqId={req_id} | knowledge {phase} | {case.service} {case.model} | response='{oneline_s}'"
+                    f"reqId={req_id} | knowledge {phase} "
+                    f"| {service} {model} | response='{oneline_s}'"
                 )
                 if code != const.CodeEnum.OK:
-                    logger.error(f"reqId={req_id} | knowledge {phase} | {case.service} {case.model} | error: {code}")
+                    logger.error(
+                        f"reqId={req_id} | knowledge {phase} "
+                        f"| {service} {model} | error: {code}"
+                    )
     return cases


@@ -122,8 +138,13 @@ async def batch_extend(
         try:
             title, content = parse_json_pattern(case.extend)
         except ValueError as e:
-            oneline = case.extend.replace('\n', '\\n')
-            logger.error(f"reqId={req_id} | parse_json_pattern error: {e}. msg: {oneline}")
+            oneline_e = case.extend.replace('\n', '\\n')
+            oneline_s = case.summary.replace('\n', '\\n')
+            logger.error(
+                f"reqId={req_id} | {case.extend_service} {case.extend_model} "
+                f"| parse_json_pattern error: {e} "
+                f"| summary: {oneline_s} "
+                f"| extension: {oneline_e}")
             case.extend_code = const.CodeEnum.LLM_INVALID_RESPONSE_FORMAT
         else:
             case.extend = f"{title}\n\n{content}"
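The core of this change is that each `ExtendCase` now carries a separate service/model pair for the summary and extend phases, and `_batch_send` groups requests by whichever pair matches the current phase. Below is a minimal standalone sketch of that grouping logic; the dataclass is trimmed down, and the service/model strings, prompt text, and message layout are illustrative placeholders rather than the project's real configuration.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ExtendCase:
    # Trimmed-down stand-in for retk's ExtendCase; only the fields needed
    # to show the per-phase grouping are included.
    uid: str
    nid: str
    summary_service: str
    summary_model: str
    extend_service: str
    extend_model: str
    stripped_md: str = ""
    summary: str = ""


def group_cases(cases: List[ExtendCase], is_extend: bool) -> Dict[str, Dict[str, dict]]:
    """Group cases by (service, model), choosing the pair for the current phase."""
    svr_group: Dict[str, Dict[str, dict]] = {}
    for case in cases:
        if is_extend:
            service, model, content = case.extend_service, case.extend_model, case.summary
        else:
            service, model, content = case.summary_service, case.summary_model, case.stripped_md
        bucket = svr_group.setdefault(service, {}).setdefault(model, {"case": [], "msgs": []})
        bucket["case"].append(case)
        bucket["msgs"].append([
            {"role": "system", "content": "<system prompt>"},  # placeholder prompt
            {"role": "user", "content": content},
        ])
    return svr_group


# A single case can now be summarized by one provider and extended by another.
cases = [ExtendCase(
    uid="u1", nid="n1",
    summary_service="service-a", summary_model="model-a",  # placeholder names
    extend_service="service-b", extend_model="model-b",    # placeholder names
    stripped_md="note body", summary="short summary",
)]
print(list(group_cases(cases, is_extend=False)))  # ['service-a']
print(list(group_cases(cases, is_extend=True)))   # ['service-b']
```

In the real code, each bucket is then dispatched through `LLM_SERVICES_MAP[service]`, as shown in the hunk above.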
4 changes: 2 additions & 2 deletions src/retk/core/ai/llm/knowledge/utils.py
@@ -22,8 +22,8 @@ def get_title_content(m):
     m = JSON_PTN2.search(text)
     if m:
         return get_title_content(m)

-    raise ValueError(f"Invalid JSON pattern: {text}")
+    oneline = text.replace("\n", "\\n")
+    raise ValueError(f"Invalid JSON pattern: {oneline}")


 def remove_links(text: str) -> str:
6 changes: 4 additions & 2 deletions src/retk/core/scheduler/tasks/extend_node.py
@@ -41,8 +41,10 @@ async def async_deliver_unscheduled_extend_nodes() -> str:
                 _id=item["_id"],
                 uid=item["uid"],
                 nid=item["nid"],
-                service=item["summaryService"],
-                model=item["summaryModel"],
+                summary_service=item["summaryService"],
+                summary_model=item["summaryModel"],
+                extend_service=item["extendService"],
+                extend_model=item["extendModel"],
                 md=node["md"],
             )
         )
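For reference, a small sketch of the document shape the scheduler loop above now reads: only the four camelCase keys come from the diff; the values are placeholders.

```python
# Hypothetical unscheduled-extend item; key names match the diff above,
# values are placeholders.
item = {
    "_id": "...",
    "uid": "u1",
    "nid": "n1",
    "summaryService": "service-a",  # placeholder
    "summaryModel": "model-a",      # placeholder
    "extendService": "service-b",   # placeholder
    "extendModel": "model-b",       # placeholder
}

# The camelCase DB keys map onto ExtendCase's snake_case fields.
case_kwargs = dict(
    uid=item["uid"],
    nid=item["nid"],
    summary_service=item["summaryService"],
    summary_model=item["summaryModel"],
    extend_service=item["extendService"],
    extend_model=item["extendModel"],
)
print(case_kwargs["extend_model"])  # model-b
```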
2 changes: 1 addition & 1 deletion src/retk/routes/utils.py
@@ -121,7 +121,7 @@ async def __process_auth_headers(
         raise json_exception(
             request_id=request_id,
             code=const.CodeEnum.INVALID_AUTH if is_refresh_token else const.CodeEnum.EXPIRED_OR_NO_ACCESS_TOKEN,
-            log_msg="empty token",
+            log_msg="EmptyToken",
         )
     au = AuthedUser(
         u=None,
12 changes: 8 additions & 4 deletions tests/test_ai_llm_knowledge.py
@@ -88,8 +88,10 @@ async def test_summary(self):
                 _id=ObjectId(),
                 uid="testuid",
                 nid="testnid",
-                service=service,
-                model=model.value.key,
+                summary_service=service,
+                summary_model=model.value.key,
+                extend_service=service,
+                extend_model=model.value.key,
                 md=md,
             ) for md in md_source
         ]
@@ -116,8 +118,10 @@ async def test_extend(self):
                 _id=ObjectId(),
                 uid="testuid",
                 nid="testnid",
-                service=service,
-                model=model.value.key,
+                summary_service=service,
+                summary_model=model.value.key,
+                extend_service=service,
+                extend_model=model.value.key,
                 md=md,
                 summary=md
             ) for md in md_summary
