Skip to content

Commit

Permalink
refactor: handle the failed task message (#600)
Browse files Browse the repository at this point in the history
* refactor: handle the failed task message

* chore: release petercat-utils/0.1.41
  • Loading branch information
xingwanying authored Dec 24, 2024
1 parent 04e6ea7 commit 6605804
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 70 deletions.
2 changes: 2 additions & 0 deletions petercat_utils/rag_helper/git_doc_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,15 @@ def __init__(
status=TaskStatus.NOT_STARTED,
from_id=None,
id=None,
retry_count=0,
):
super().__init__(
type=TaskType.GIT_DOC,
from_id=from_id,
id=id,
status=status,
repo_name=repo_name,
retry_count=retry_count,
)
self.commit_id = commit_id
self.node_type = GitDocTaskNodeType(node_type)
Expand Down
66 changes: 38 additions & 28 deletions petercat_utils/rag_helper/git_issue_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ def add_rag_git_issue_task(config: RAGGitIssueConfig):
g.get_repo(config.repo_name)

issue_task = GitIssueTask(
issue_id='',
issue_id="",
node_type=GitIssueTaskNodeType.REPO,
bot_id=config.bot_id,
repo_name=config.repo_name
repo_name=config.repo_name,
)
res = issue_task.save()
issue_task.send()
Expand All @@ -26,17 +26,26 @@ class GitIssueTask(GitTask):
issue_id: str
node_type: GitIssueTaskNodeType

def __init__(self,
issue_id,
node_type: GitIssueTaskNodeType,
bot_id,
repo_name,
status=TaskStatus.NOT_STARTED,
from_id=None,
id=None
):
super().__init__(bot_id=bot_id, type=TaskType.GIT_ISSUE, from_id=from_id, id=id, status=status,
repo_name=repo_name)
def __init__(
self,
issue_id,
node_type: GitIssueTaskNodeType,
bot_id,
repo_name,
status=TaskStatus.NOT_STARTED,
from_id=None,
id=None,
retry_count=0,
):
super().__init__(
bot_id=bot_id,
type=TaskType.GIT_ISSUE,
from_id=from_id,
id=id,
status=status,
repo_name=repo_name,
retry_count=retry_count,
)
self.issue_id = issue_id
self.node_type = GitIssueTaskNodeType(node_type)

Expand Down Expand Up @@ -75,27 +84,28 @@ def handle_repo_node(self):
if len(task_list) > 0:
result = self.get_table().insert(task_list).execute()
for record in result.data:
issue_task = GitIssueTask(id=record["id"],
issue_id=record["issue_id"],
repo_name=record["repo_name"],
node_type=record["node_type"],
bot_id=record["bot_id"],
status=record["status"],
from_id=record["from_task_id"]
)
issue_task = GitIssueTask(
id=record["id"],
issue_id=record["issue_id"],
repo_name=record["repo_name"],
node_type=record["node_type"],
bot_id=record["bot_id"],
status=record["status"],
from_id=record["from_task_id"],
)
issue_task.send()

return (self.get_table().update(
{"status": TaskStatus.COMPLETED.value})
.eq("id", self.id)
.execute())
return (
self.get_table()
.update({"status": TaskStatus.COMPLETED.value})
.eq("id", self.id)
.execute()
)

def handle_issue_node(self):
issue_retrieval.add_knowledge_by_issue(
RAGGitIssueConfig(
repo_name=self.repo_name,
bot_id=self.bot_id,
issue_id=self.issue_id
repo_name=self.repo_name, bot_id=self.bot_id, issue_id=self.issue_id
)
)
return self.update_status(TaskStatus.COMPLETED)
12 changes: 10 additions & 2 deletions petercat_utils/rag_helper/git_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ def __init__(
status=TaskStatus.NOT_STARTED,
from_id=None,
id=None,
retry_count=0,
):
self.type = type
self.id = id
self.from_id = from_id
self.status = status
self.repo_name = repo_name
self.retry_count = retry_count

@staticmethod
def get_table_name(type: TaskType):
Expand Down Expand Up @@ -82,11 +84,17 @@ def send(self):
QueueUrl=SQS_QUEUE_URL,
DelaySeconds=10,
MessageBody=(
json.dumps({"task_id": self.id, "task_type": self.type.value})
json.dumps(
{
"task_id": self.id,
"task_type": self.type.value,
"retry_count": self.retry_count,
}
)
),
)
message_id = response["MessageId"]
print(
f"task_id={self.id}, task_type={self.type.value}, message_id={message_id}"
f"task_id={self.id}, task_type={self.type.value}, message_id={message_id}, retry_count={self.retry_count}"
)
return message_id
20 changes: 5 additions & 15 deletions petercat_utils/rag_helper/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,6 @@
SQS_QUEUE_URL = get_env_variable("SQS_QUEUE_URL")


def send_task_message(task_id: str):
response = sqs.send_message(
QueueUrl=SQS_QUEUE_URL,
DelaySeconds=10,
MessageBody=(json.dumps({"task_id": task_id})),
)
return response["MessageId"]


def get_oldest_task():
supabase = get_client()

Expand All @@ -54,10 +45,7 @@ def get_task_by_id(task_id):
return response.data[0] if (len(response.data) > 0) else None


def get_task(
task_type: TaskType,
task_id: str,
) -> GitTask:
def get_task(task_type: TaskType, task_id: str, retry_count=0) -> GitTask:
supabase = get_client()
response = (
supabase.table(GitTask.get_table_name(task_type))
Expand All @@ -77,6 +65,7 @@ def get_task(
path=data["path"],
status=data["status"],
from_id=data["from_task_id"],
retry_count=retry_count,
)
if task_type == TaskType.GIT_ISSUE:
return GitIssueTask(
Expand All @@ -87,11 +76,12 @@ def get_task(
bot_id=data["bot_id"],
status=data["status"],
from_id=data["from_task_id"],
retry_count=retry_count,
)


def trigger_task(task_type: TaskType, task_id: Optional[str]):
task = get_task(task_type, task_id) if task_id else get_oldest_task()
def trigger_task(task_type: TaskType, task_id: Optional[str], retry_count: int = 0):
task = get_task(task_type, task_id, retry_count) if task_id else get_oldest_task()
if task is None:
return task
return task.handle()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "petercat_utils"
version = "0.1.40"
version = "0.1.41"
description = ""
authors = ["raoha.rh <[email protected]>"]
readme = "README.md"
Expand Down
29 changes: 16 additions & 13 deletions server/aws/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,32 @@
STATIC_SECRET_NAME = get_env_variable("STATIC_SECRET_NAME")
STATIC_KEYPAIR_ID = get_env_variable("STATIC_KEYPAIR_ID")


def rsa_signer(message):
private_key_str = get_private_key(STATIC_SECRET_NAME)
private_key = rsa.PrivateKey.load_pkcs1(private_key_str.encode('utf-8'))
return rsa.sign(message, private_key, 'SHA-1')
private_key = rsa.PrivateKey.load_pkcs1(private_key_str.encode("utf-8"))
return rsa.sign(message, private_key, "SHA-1")


def create_signed_url(url, expire_minutes=60) -> str:
cloudfront_signer = CloudFrontSigner(STATIC_KEYPAIR_ID, rsa_signer)

# 设置过期时间
expire_date = datetime.now() + timedelta(minutes=expire_minutes)

# 创建签名 URL
signed_url = cloudfront_signer.generate_presigned_url(
url=url,
date_less_than=expire_date
url=url, date_less_than=expire_date
)

return signed_url


def upload_image_to_s3(file, metadata: ImageMetaData, s3_client):
try:
file_content = file.file.read()
md5_hash = hashlib.md5()
md5_hash.update(file.filename.encode('utf-8'))
md5_hash.update(file.filename.encode("utf-8"))
s3_key = md5_hash.hexdigest()
encoded_filename = (
base64.b64encode(metadata.title.encode("utf-8")).decode("utf-8")
Expand All @@ -62,11 +64,12 @@ def upload_image_to_s3(file, metadata: ImageMetaData, s3_client):
ContentType=file.content_type,
Metadata=custom_metadata,
)
# you need to redirect your static domain to your s3 bucket domain
s3_url = f"{STATIC_URL}/{s3_key}"
signed_url = create_signed_url(url=s3_url, expire_minutes=60) \
if (STATIC_SECRET_NAME and STATIC_KEYPAIR_ID) \
else s3_url
return {"message": "File uploaded successfully", "url": signed_url }
signed_url = (
create_signed_url(url=s3_url, expire_minutes=60)
if (STATIC_SECRET_NAME and STATIC_KEYPAIR_ID)
else s3_url
)
return {"message": "File uploaded successfully", "url": signed_url}
except Exception as e:
raise UploadError(detail=str(e))
30 changes: 19 additions & 11 deletions subscriber/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,38 @@
from petercat_utils import task as task_helper
from petercat_utils.data_class import TaskType

MAX_RETRY_COUNT = 5


def lambda_handler(event, context):
if event:
batch_item_failures = []
sqs_batch_response = {}

for record in event["Records"]:
try:
body = record["body"]
print(f"receive message here: {body}")
body = record["body"]
print(f"receive message here: {body}")

message_dict = json.loads(body)
task_id = message_dict["task_id"]
task_type = message_dict["task_type"]
task = task_helper.get_task(TaskType(task_type), task_id)
message_dict = json.loads(body)
task_id = message_dict["task_id"]
task_type = message_dict["task_type"]
retry_count = message_dict["retry_count"]
task = task_helper.get_task(TaskType(task_type), task_id)
try:
if task is None:
return task
task.handle()

# process message
print(f"message content: message={message_dict}, task_id={task_id}, task={task}")
print(
f"message content: message={message_dict}, task_id={task_id}, task={task}, retry_count={retry_count}"
)
except Exception as e:
print(f"message handle error: ${e}")
batch_item_failures.append({"itemIdentifier": record['messageId']})
if retry_count < MAX_RETRY_COUNT:
retry_count += 1
task_helper.trigger_task(task_type, task_id, retry_count)
else:
print(f"message handle error: ${e}")
batch_item_failures.append({"itemIdentifier": record["messageId"]})

sqs_batch_response["batchItemFailures"] = batch_item_failures
return sqs_batch_response

0 comments on commit 6605804

Please sign in to comment.