Skip to content

Commit

Permalink
feat(llm):
Browse files Browse the repository at this point in the history
- change schedule time to every hour
- fix bugs
  • Loading branch information
MorvanZhou committed Jul 16, 2024
1 parent d679a90 commit 36a0a14
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/retk/core/ai/llm/knowledge/extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
async def get_extended_nodes(
uid: str,
) -> List[ExtendedNode]:
docs = await client.coll.llm_extended_node.find({"uid": uid}).to_list(None)
docs = await client.coll.llm_extended_node.find({"uid": uid}).sort("_id", -1).to_list(None)
return docs


Expand Down
2 changes: 2 additions & 0 deletions src/retk/core/ai/llm/knowledge/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Tuple

from retk import const
from retk.logger import logger
from .utils import parse_json_pattern, remove_links
from ..api.base import BaseLLMService, MessagesType

Expand Down Expand Up @@ -58,5 +59,6 @@ async def extend(
try:
title, content = parse_json_pattern(msg)
except ValueError as e:
logger.error(f"parse_json_pattern error: {e}. msg: {msg}")
return str(e), const.CodeEnum.LLM_INVALID_RESPONSE_FORMAT
return f"{title}\n\n{content}", const.CodeEnum.OK
8 changes: 5 additions & 3 deletions src/retk/core/ai/llm/knowledge/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
import re
from typing import Tuple

JSON_PTN = re.compile(r"^({\s*?\"title\":\s?\".+?\",\s*?\"content\":\s?\".+?\"\s*?})", re.DOTALL | re.MULTILINE)
JSON_PTN = re.compile(r"^{\s*?\"title\":\s?\"(.+?)\",\s*?\"content\":\s?\"(.+?)\"\s*?}", re.DOTALL | re.MULTILINE)
IMG_PTN = re.compile(r"!\[.*?\]\(.+?\)")
LINK_PTN = re.compile(r"\[(.*?)]\(.+?\)")


def parse_json_pattern(text: str) -> Tuple[str, str]:
m = JSON_PTN.search(text)
if m:
json_str = m.group(1)
d = json.loads(json_str)
title, content = m.group(1), m.group(2)
title = title.replace("\n", "\\n")
content = content.replace("\n", "\\n")
d = json.loads(f'{{"title": "{title}", "content": "{content}"}}')
return d["title"], d["content"]
raise ValueError(f"Invalid JSON pattern: {text}")

Expand Down
2 changes: 1 addition & 1 deletion src/retk/core/scheduler/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def init_tasks():
run_every_at(
job_id="deliver_unscheduled_node_extend",
func=tasks.extend_node.deliver_unscheduled_extend_nodes,
second=0,
minute=0,
)
return

Expand Down
6 changes: 3 additions & 3 deletions src/retk/core/scheduler/tasks/extend_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ async def async_deliver_unscheduled_extend_nodes() -> str:
for item in batch:
req_id = "".join([str(random.randint(0, 9)) for _ in range(10)])
node = await db[CollNameEnum.nodes.value].find_one({"id": item["nid"]})
# md = md[:int(8000 * 1.8)]
md = node["md"][:int(5000 * 1.8)] # cut md
s0 = time.perf_counter()
_summary, code = await knowledge.summary(
llm_service=knowledge.LLM_SERVICES[item["summaryService"]],
model=item["summaryModel"],
md=node["md"],
md=md,
req_id=req_id,
)
s1 = time.perf_counter()
Expand All @@ -58,7 +58,7 @@ async def async_deliver_unscheduled_extend_nodes() -> str:
)
e1 = time.perf_counter()
if code != const.CodeEnum.OK:
logger.error(f"knowledge extend error: {code}")
logger.error(f"knowledge extend error: code={code}")
continue
oneline_e = _extended.replace('\n', '\\n')
logger.debug(f"extended: {oneline_e}")
Expand Down
6 changes: 5 additions & 1 deletion tests/test_ai_llm_knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ async def test_extend(self):
print(f"{service.__class__.__name__} {model.name}\n{text}\n\n")

def test_json_pattern(self):
title, content = parse_json_pattern("""{"title": "tttt", "content": "cccc\n21\n2"}""")
self.assertEqual("tttt", title)
self.assertEqual("cccc\n21\n2", content)

cases = [
"""\
{
Expand All @@ -132,7 +136,7 @@ def test_json_pattern(self):
"content": "cccc"
}
23423saq1是当前
"""
""",
]
for case in cases:
case = dedent(case)
Expand Down

0 comments on commit 36a0a14

Please sign in to comment.