diff --git a/src/dflow/python/op.py b/src/dflow/python/op.py index de86150e..ef75d76e 100644 --- a/src/dflow/python/op.py +++ b/src/dflow/python/op.py @@ -53,6 +53,7 @@ class OP(ABC): slices = {} tmp_root = "/tmp" create_slice_dir = False + pool_size = None def __init__( self, @@ -417,6 +418,11 @@ def handle_outputs(self, outputs, symlink=False): sign = output_sign[name] if isinstance(sign, Artifact): slices = self.slices.get(name) + if self.pool_size and slices is not None: + if sign.type == str: + sign.type = List[str] + elif sign.type == Path: + sign.type = List[Path] handle_output_artifact( name, outputs[name], sign, slices, self.tmp_root, self.create_slice_dir and slices, symlink=symlink) diff --git a/src/dflow/python/python_op_template.py b/src/dflow/python/python_op_template.py index 951195e6..6732c639 100644 --- a/src/dflow/python/python_op_template.py +++ b/src/dflow/python/python_op_template.py @@ -604,6 +604,8 @@ def render_script(self): else: slices = self.get_slices(output_parameter_slices, name) script += " op_obj.slices['%s'] = %s\n" % (name, slices) + script += " op_obj.pool_size = %s\n" % getattr( + self.slices, "pool_size", None) script += " import signal\n" script += " def sigterm_handler(signum, frame):\n" @@ -671,13 +673,6 @@ def render_script(self): for name in sliced_outputs: script += " output['%s'] = [o.get('%s') if o is not None"\ " else None for o in output_list]\n" % (name, name) - if isinstance(output_sign[name], Artifact): - if output_sign[name].type == str: - script += " output_sign['%s'].type = List[str]"\ - "\n" % name - elif output_sign[name].type == Path: - script += " output_sign['%s'].type = List[Path"\ - "]\n" % name else: script += " try:\n" script += " try:\n"