From 28e855a52f49700bc2bce522d6303d99390d3d4a Mon Sep 17 00:00:00 2001
From: zjgemi
Date: Mon, 2 Sep 2024 17:46:30 +0800
Subject: [PATCH] fix: add HDF5Datasets type Artifact

Signed-off-by: zjgemi
---
 src/dflow/python/__init__.py |  5 ++-
 src/dflow/python/op.py       |  4 +-
 src/dflow/python/opio.py     | 68 ++++++++++++++++++++++++-----
 src/dflow/python/utils.py    | 85 +++++++++++++++++++++++++++++++-----
 4 files changed, 138 insertions(+), 24 deletions(-)

diff --git a/src/dflow/python/__init__.py b/src/dflow/python/__init__.py
index 085855ac..75b11c0f 100644
--- a/src/dflow/python/__init__.py
+++ b/src/dflow/python/__init__.py
@@ -1,8 +1,9 @@
 from .op import OP
-from .opio import OPIO, Artifact, BigParameter, OPIOSign, Parameter, NestedDict
+from .opio import (OPIO, Artifact, BigParameter, HDF5Datasets, NestedDict,
+                   OPIOSign, Parameter)
 from .python_op_template import (FatalError, PythonOPTemplate, Slices,
                                  TransientError, upload_packages)
 
 __all__ = ["OP", "OPIO", "Artifact", "BigParameter", "OPIOSign", "Parameter",
            "FatalError", "PythonOPTemplate", "Slices", "TransientError",
-           "upload_packages", "NestedDict"]
+           "upload_packages", "NestedDict", "HDF5Datasets"]
diff --git a/src/dflow/python/op.py b/src/dflow/python/op.py
index 0b1b9611..ff55175a 100644
--- a/src/dflow/python/op.py
+++ b/src/dflow/python/op.py
@@ -19,8 +19,8 @@
 from ..io import (InputArtifact, InputParameter, OutputArtifact,
                   OutputParameter, type_to_str)
 from ..utils import dict2list, get_key, randstr, s3_config
-from .vendor.typeguard import check_type
 from .opio import OPIO, Artifact, BigParameter, OPIOSign, Parameter
+from .vendor.typeguard import check_type
 
 iwd = os.getcwd()
 
@@ -190,6 +190,8 @@ def _check_signature(
                 ss = Set[Union[str, None]]
             elif ss == Set[Path]:
                 ss = Set[Union[Path, None]]
+            else:
+                continue
         if isinstance(ss, Parameter):
             ss = ss.type
         # skip type checking if the variable is None
diff --git a/src/dflow/python/opio.py b/src/dflow/python/opio.py
index c5acff51..28c25e53 100644
--- a/src/dflow/python/opio.py
+++ b/src/dflow/python/opio.py
@@ -1,4 +1,5 @@
 import json
+import tarfile
 from collections.abc import MutableMapping
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Set, Union
@@ -8,27 +9,72 @@
 from ..io import PVC, type_to_str
 
 
-class nested_dict:
-    def __init__(self, type):
-        self.type = type
+class NestedDict:
+    pass
+
+
+class NestedDictStr(NestedDict):
+    pass
+
+
+class NestedDictPath(NestedDict):
+    pass
+
+
+class HDF5Dataset:
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def get_data(self):
+        data = self.dataset[()]
+        if self.dataset.attrs.get("dtype") == "utf-8":
+            data = data.decode("utf-8")
+        elif self.dataset.attrs.get("dtype") == "binary":
+            data = data.tobytes()
+        return data
+
+    def recover(self):
+        if self.dataset.attrs["type"] == "file":
+            path = Path(self.dataset.attrs["path"])
+            if path.is_absolute():
+                path = path.relative_to(path.root)
+            path.parent.mkdir(parents=True, exist_ok=True)
+            data = self.get_data()
+            if isinstance(data, str):
+                path.write_text(data)
+            elif isinstance(data, bytes):
+                path.write_bytes(data)
+            return path
+        elif self.dataset.attrs["type"] == "dir":
+            path = Path(self.dataset.attrs["path"])
+            if path.is_absolute():
+                path = path.relative_to(path.root)
+            path.parent.mkdir(parents=True, exist_ok=True)
+            tgz_path = path.parent / (path.name + ".tgz")
+            tgz_path.write_bytes(self.get_data())
+            tf = tarfile.open(tgz_path, "r:gz")
+            tf.extractall(".")
+            tf.close()
+            return path
+        else:
+            return self.get_data()
 
-    def __repr__(self):
-        return "dflow.python.NestedDict[%s]" % type_to_str(self.type)
 
-    def __eq__(self, other):
-        if not isinstance(other, nested_dict):
-            return False
-        return self.type == other.type
+class HDF5Datasets:
+    pass
 
 
 NestedDict = {
-    str: nested_dict(str),
-    Path: nested_dict(Path),
+    str: NestedDictStr,
+    Path: NestedDictPath,
 }
 
 ArtifactAllowedTypes = [str, Path, Set[str], Set[Path], List[str], List[Path],
                         Dict[str, str], Dict[str, Path], NestedDict[str],
                         NestedDict[Path]]
+for t in ArtifactAllowedTypes.copy():
+    ArtifactAllowedTypes.append(Union[t, HDF5Datasets])
+ArtifactAllowedTypes.append(HDF5Datasets)
 
 
 @CustomHandler.handles
diff --git a/src/dflow/python/utils.py b/src/dflow/python/utils.py
index c8f85269..ce880c3b 100644
--- a/src/dflow/python/utils.py
+++ b/src/dflow/python/utils.py
@@ -1,17 +1,19 @@
 import os
 import shutil
 import signal
+import tarfile
 import traceback
 import uuid
 from pathlib import Path
-from typing import Dict, List, Set
+from typing import Dict, List, Set, Union
 
 from ..common import jsonpickle
 from ..config import config
 from ..utils import (artifact_classes, assemble_path_object,
                      catalog_of_local_artifact, convert_dflow_list, copy_file,
                      expand, flatten, randstr, remove_empty_dir_tag)
-from .opio import Artifact, BigParameter, NestedDict, Parameter
+from .opio import (Artifact, BigParameter, HDF5Dataset, HDF5Datasets,
+                   NestedDict, Parameter)
 
 
 def get_slices(path_object, slices):
@@ -78,7 +80,35 @@ def handle_input_artifact(name, sign, slices=None, data_root="/tmp",
 
     path_object = get_slices(path_object, slices)
 
-    if sign.type in [str, Path]:
+    sign_type = sign.type
+    if getattr(sign_type, "__origin__", None) == Union:
+        args = sign_type.__args__
+        if HDF5Datasets in args:
+            if isinstance(path_object, list) and all([isinstance(
+                    p, str) and p.endswith(".h5") for p in path_object]):
+                sign_type = HDF5Datasets
+            elif args[0] == HDF5Datasets:
+                sign_type = args[1]
+            elif args[1] == HDF5Datasets:
+                sign_type = args[0]
+
+    if sign_type == HDF5Datasets:
+        import h5py
+        assert isinstance(path_object, list)
+        res = None
+        for path in path_object:
+            f = h5py.File(path, "r")
+            datasets = {k: HDF5Dataset(f[k]) for k in f.keys()}
+            datasets = expand(datasets)
+            if isinstance(datasets, list):
+                if res is None:
+                    res = []
+                res += datasets
+            elif isinstance(datasets, dict):
+                if res is None:
+                    res = {}
+                res.update(datasets)
+    if sign_type in [str, Path]:
         if path_object is None or isinstance(path_object, str):
             res = path_object
         elif isinstance(path_object, list) and len(path_object) == 1 and (
@@ -87,8 +117,8 @@ def handle_input_artifact(name, sign, slices=None, data_root="/tmp",
             res = path_object[0]
         else:
             res = art_path
-        res = path_or_none(res) if sign.type == Path else res
-    elif sign.type in [List[str], List[Path], Set[str], Set[Path]]:
+        res = path_or_none(res) if sign_type == Path else res
+    elif sign_type in [List[str], List[Path], Set[str], Set[Path]]:
         if path_object is None:
             return None
         elif isinstance(path_object, str):
@@ -99,17 +129,17 @@ def handle_input_artifact(name, sign, slices=None, data_root="/tmp",
         else:
            res = list(flatten(path_object).values())
 
-        if sign.type == List[str]:
+        if sign_type == List[str]:
             pass
-        elif sign.type == List[Path]:
+        elif sign_type == List[Path]:
             res = path_or_none(res)
-        elif sign.type == Set[str]:
+        elif sign_type == Set[str]:
             res = set(res)
         else:
             res = set(path_or_none(res))
-    elif sign.type in [Dict[str, str], NestedDict[str]]:
+    elif sign_type in [Dict[str, str], NestedDict[str]]:
         res = path_object
-    elif sign.type in [Dict[str, Path], NestedDict[Path]]:
+    elif sign_type in [Dict[str, Path], NestedDict[Path]]:
         res = path_or_none(path_object)
 
     if res is None:
@@ -169,6 +199,41 @@ def slice_to_dir(slice):
 def handle_output_artifact(name, value, sign, slices=None, data_root="/tmp",
                            create_dir=False):
     path_list = []
+    if sign.type == HDF5Datasets:
+        import h5py
+        os.makedirs(data_root + '/outputs/artifacts/' + name, exist_ok=True)
+        h5_name = "%s.h5" % uuid.uuid4()
+        h5_path = '%s/outputs/artifacts/%s/%s' % (data_root, name, h5_name)
+        with h5py.File(h5_path, "w") as f:
+            for s, v in flatten(value).items():
+                if isinstance(v, Path):
+                    if v.is_file():
+                        try:
+                            data = v.read_text(encoding="utf-8")
+                            dtype = "utf-8"
+                        except Exception:
+                            import numpy as np
+                            data = np.void(v.read_bytes())
+                            dtype = "binary"
+                        d = f.create_dataset(s, data=data)
+                        d.attrs["type"] = "file"
+                        d.attrs["path"] = str(v)
+                        d.attrs["dtype"] = dtype
+                    elif v.is_dir():
+                        tgz_path = Path("%s.tgz" % v)
+                        tf = tarfile.open(tgz_path, "w:gz", dereference=True)
+                        tf.add(v)
+                        tf.close()
+                        import numpy as np
+                        d = f.create_dataset(s, data=np.void(
+                            tgz_path.read_bytes()))
+                        d.attrs["type"] = "dir"
+                        d.attrs["path"] = str(v)
+                        d.attrs["dtype"] = "binary"
+                else:
+                    d = f.create_dataset(s, data=v)
+                    d.attrs["type"] = "data"
+        path_list.append({"dflow_list_item": h5_name, "order": slices or 0})
     if sign.type in [str, Path]:
         os.makedirs(data_root + '/outputs/artifacts/' + name, exist_ok=True)
         if slices is None:
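
Note for reviewers (not part of the patch): below is a minimal usage sketch of the new type; the two OP classes and their names are hypothetical. An output declared as Artifact(HDF5Datasets) is packed by handle_output_artifact into a single .h5 file, one dataset per flattened key; an input declared as Artifact(HDF5Datasets) (or a Union with it, via the extended ArtifactAllowedTypes) arrives as HDF5Dataset wrappers, whose recover() rewrites "file"/"dir" datasets to disk and whose get_data() returns in-memory values.

    from pathlib import Path
    from typing import List

    from dflow.python import OP, OPIO, Artifact, HDF5Datasets, OPIOSign


    class PackFiles(OP):
        # hypothetical producer: packs a list of files into one .h5 artifact
        @classmethod
        def get_input_sign(cls):
            return OPIOSign({"files": Artifact(List[Path])})

        @classmethod
        def get_output_sign(cls):
            return OPIOSign({"packed": Artifact(HDF5Datasets)})

        @OP.exec_sign_check
        def execute(self, ip: OPIO) -> OPIO:
            # each entry of the returned dict becomes one dataset in the .h5
            return OPIO({"packed": {"f%d" % i: p
                                    for i, p in enumerate(ip["files"])}})


    class UnpackFiles(OP):
        # hypothetical consumer: receives a dict of HDF5Dataset wrappers
        @classmethod
        def get_input_sign(cls):
            return OPIOSign({"packed": Artifact(HDF5Datasets)})

        @classmethod
        def get_output_sign(cls):
            return OPIOSign({"files": Artifact(List[Path])})

        @OP.exec_sign_check
        def execute(self, ip: OPIO) -> OPIO:
            # recover() restores "file"/"dir" datasets under the working
            # directory and returns the path; "data" datasets yield the value
            paths = [ds.recover() for ds in ip["packed"].values()]
            return OPIO({"files": paths})

When several sliced tasks write to the same HDF5Datasets output, each task contributes its own uuid-named .h5 file (ordered by its slice index), and on the consuming side handle_input_artifact merges the datasets of all the files into a single dict or list.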