Skip to content

Commit

Permalink
Merge pull request #845 from deepmodeling/zjgemi
Browse files Browse the repository at this point in the history
fix: handle sub_path slices of artifact list for debug mode
  • Loading branch information
zjgemi authored Aug 6, 2024
2 parents d502bd4 + 0d9e1a0 commit 73d811b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
15 changes: 15 additions & 0 deletions src/dflow/python/python_op_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,9 @@ def __init__(self,
self.init_progress = "%s/%s" % (op_class.progress_current,
op_class.progress_total)
self.memoize_key = memoize_key
if hasattr(op_class, "func") and config["mode"] == "debug":
op_class = PicklableFunctionOP(op_class)
op = op_class

self.op_class = op_class
self.input_sign = input_sign
Expand Down Expand Up @@ -821,3 +824,15 @@ class TransientError(Exception):

class FatalError(Exception):
pass


class PicklableFunctionOP:
def __init__(self, op_class):
self.__module__ = op_class.__module__
self.__name__ = op_class.__name__
if hasattr(op_class, "_source"):
self._source = op_class._source
self.__module__ = "__main__"
elif self.__module__ in ["__main__", "__mp_main__"]:
self._source = get_source_code(op_class.func)
self.func = True
7 changes: 5 additions & 2 deletions src/dflow/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,10 @@ def argo_enumerate(
if config["mode"] == "debug":
values = "".join([", '%s': %s[i]" % (k, to_expr(v))
for k, v in kwargs.items()])
return Expression("[{'order': i%s} for i in range(len(%s))]" % (
expr = Expression("[{'order': i%s} for i in range(len(%s))]" % (
values, to_expr(list(kwargs.values())[0])))
expr.kwargs = kwargs
return expr
return ArgoEnumerate(**kwargs)


Expand Down Expand Up @@ -1029,7 +1031,8 @@ def handle_sub_path_slices_of_artifact_list(self, slices, artifacts):
if isinstance(self.with_param, ArgoEnumerate):
self.with_param = argo_enumerate(**self.with_param.kwargs, **param)
else:
self.with_param = argo_enumerate(**param)
self.with_param = argo_enumerate(
**getattr(self.with_param, "kwargs", {}), **param)
slices.slices = "{{item.order}}"
slices.sub_path = False
slices.input_artifact = []
Expand Down

0 comments on commit 73d811b

Please sign in to comment.