Skip to content

Commit

Permalink
Merge pull request #657 from deepmodeling/zjgemi
Browse files Browse the repository at this point in the history
fix: optional artifact when grouping slices
  • Loading branch information
zjgemi authored Sep 21, 2023
2 parents 3d02571 + ce9361d commit 9a5ab61
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions src/dflow/python/python_op_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,18 +529,23 @@ def render_script(self):
if self.slices is not None and self.slices.pool_size is not None:
sliced_inputs = self.slices.input_artifact + \
self.slices.input_parameter
if len(sliced_inputs) > 1:
script += " assert %s\n" % " == ".join(
["len(input['%s'])" % i for i in sliced_inputs])
script += " n_slices = len(input['%s'])\n" % \
sliced_inputs[0]
script += " n_slices = None\n"
for name in sliced_inputs:
# for optional artifact
script += " if input['%s'] is not None:\n" % name
script += " if n_slices is None:\n"
script += " n_slices = len(input['%s'])\n" % name
script += " else:\n"
script += " assert len(input['%s']) == n_slices\n" \
% name
script += " assert n_slices is not None\n"
script += " input_list = []\n"
script += " from copy import deepcopy\n"
script += " for i in range(n_slices):\n"
script += " input1 = deepcopy(input)\n"
for name in sliced_inputs:
script += " input1['%s'] = list(input['%s'])[i]\n" % (
name, name)
script += " input1['%s'] = list(input['%s'])[i] if "\
"input['%s'] is not None else None\n" % (name, name, name)
script += " input_list.append(input1)\n"
if self.slices.pool_size == 1:
script += " output_list = []\n"
Expand Down

0 comments on commit 9a5ab61

Please sign in to comment.