Skip to content

Commit

Permalink
Merge pull request #870 from deepmodeling/zjgemi
Browse files Browse the repository at this point in the history
fix: add onExit hook
  • Loading branch information
zjgemi authored Oct 22, 2024
2 parents 143d9b0 + acbd12d commit b57edb3
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
22 changes: 13 additions & 9 deletions src/dflow/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,17 @@ def argo_concat(
return ArgoConcat(param)


def upload_python_packages(python_packages):
hit = list(filter(lambda x: x[0] == python_packages,
uploaded_python_packages))
if len(hit) > 0:
return hit[0][1]
else:
artifact = upload_artifact(python_packages)
uploaded_python_packages.append((python_packages, artifact))
return artifact


class Step:
"""
Step
Expand Down Expand Up @@ -400,15 +411,8 @@ def __init__(

if hasattr(self.template, "python_packages") and \
self.template.python_packages:
hit = list(filter(lambda x: x[0] == self.template.python_packages,
uploaded_python_packages))
if len(hit) > 0:
self.set_artifacts({"dflow_python_packages": hit[0][1]})
else:
artifact = upload_artifact(self.template.python_packages)
self.set_artifacts({"dflow_python_packages": artifact})
uploaded_python_packages.append(
(self.template.python_packages, artifact))
artifact = upload_python_packages(self.template.python_packages)
self.set_artifacts({"dflow_python_packages": artifact})

if self.key is not None:
self.template.inputs.parameters["dflow_key"] = InputParameter(
Expand Down
15 changes: 13 additions & 2 deletions src/dflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .executor import Executor
from .op_template import (ContainerOPTemplate, OPTemplate, ScriptOPTemplate,
get_k8s_client)
from .step import Step
from .step import Step, upload_python_packages
from .steps import Steps
from .task import Task
from .utils import copy_s3, get_key, linktree, randstr, set_key
Expand Down Expand Up @@ -112,6 +112,7 @@ def __init__(
str, DockerSecret, List[Union[str, DockerSecret]]]] = None,
artifact_repo_key: Optional[str] = None,
parameters: Optional[Dict[str, Any]] = None,
on_exit: Optional[OPTemplate] = None,
) -> None:
self.host = host if host is not None else config["host"]
self.token = token if token is not None else config["token"]
Expand Down Expand Up @@ -166,6 +167,7 @@ def __init__(
parse_repo(self.artifact_repo_key, self.namespace,
k8s_api_server=self.k8s_api_server, token=self.token,
k8s_config_file=self.k8s_config_file)
self.on_exit = on_exit

def get_k8s_core_v1_api(self):
if self.k8s_client is None:
Expand Down Expand Up @@ -607,6 +609,14 @@ def convert_to_argo(self, reuse_step=None):
workflow_urn = config["lineage"].register_workflow(self.name)
self.parameters["dflow_workflow_urn"] = workflow_urn

if self.on_exit is not None:
if hasattr(self.on_exit, "python_packages") and \
self.on_exit.python_packages:
artifact = upload_python_packages(self.on_exit.python_packages)
self.on_exit.inputs.artifacts[
"dflow_python_packages"].source = artifact
self.handle_template(self.on_exit)

self.deduplicate_templates()
return V1alpha1Workflow(
metadata=metadata,
Expand All @@ -624,7 +634,8 @@ def convert_to_argo(self, reuse_step=None):
pod_gc=V1alpha1PodGC(strategy=self.pod_gc_strategy),
image_pull_secrets=self.image_pull_secrets,
artifact_repository_ref=None if self.artifact_repo_key is None
else V1alpha1ArtifactRepositoryRef(key=self.artifact_repo_key)
else V1alpha1ArtifactRepositoryRef(key=self.artifact_repo_key),
on_exit=self.on_exit.name if self.on_exit is not None else None
),
status={"outputs": {"parameters": list(global_parameters.values()),
"artifacts": list(global_artifacts.values())}})
Expand Down

0 comments on commit b57edb3

Please sign in to comment.