Merge pull request #850 from deepmodeling/zjgemi
fix: add HDF5Datasets type Artifact
zjgemi authored Sep 2, 2024
2 parents fa3201f + 28e855a commit 93c5bc5
Showing 4 changed files with 138 additions and 24 deletions.
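
This commit adds HDF5Datasets to dflow's artifact type system: an artifact can now be declared as datasets packed into HDF5 files, alone or in a Union with a conventional path type. The diffs below add the type itself and register the allowed signatures (opio.py), make the signature checker tolerate them (op.py), and teach the artifact handlers to pack and unpack .h5 files (utils.py). A minimal sketch of what this enables at the OP level, assuming dflow's usual OP subclass pattern, which this diff does not touch; the OP and its field names are hypothetical:

from pathlib import Path
from typing import List, Union

from dflow.python import OP, OPIO, Artifact, HDF5Datasets, OPIOSign


class Analyze(OP):
    @classmethod
    def get_input_sign(cls):
        # Union[List[Path], HDF5Datasets]: accept plain files or the same
        # data packed into .h5 files by an upstream step.
        return OPIOSign({
            "frames": Artifact(Union[List[Path], HDF5Datasets]),
        })

    @classmethod
    def get_output_sign(cls):
        # A pure HDF5Datasets output: handle_output_artifact (utils.py
        # below) packs the returned values into a single .h5 file.
        return OPIOSign({"summary": Artifact(HDF5Datasets)})

    def execute(self, op_in: OPIO) -> OPIO:
        # With .h5 inputs, op_in["frames"] arrives as HDF5Dataset handles
        # (a list or dict); with plain inputs, as a list of Paths.
        return OPIO({"summary": {"n_frames": len(op_in["frames"])}})

Because a Union signature accepts either form, the same OP can sit downstream of a step that emits plain files or of one that emits packed HDF5 artifacts.
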
src/dflow/python/__init__.py (5 changes: 3 additions & 2 deletions)

@@ -1,8 +1,9 @@
 from .op import OP
-from .opio import OPIO, Artifact, BigParameter, OPIOSign, Parameter, NestedDict
+from .opio import (OPIO, Artifact, BigParameter, HDF5Datasets, NestedDict,
+                   OPIOSign, Parameter)
 from .python_op_template import (FatalError, PythonOPTemplate, Slices,
                                  TransientError, upload_packages)
 
 __all__ = ["OP", "OPIO", "Artifact", "BigParameter", "OPIOSign", "Parameter",
            "FatalError", "PythonOPTemplate", "Slices", "TransientError",
-           "upload_packages", "NestedDict"]
+           "upload_packages", "NestedDict", "HDF5Datasets"]

src/dflow/python/op.py (4 changes: 3 additions & 1 deletion)

@@ -19,8 +19,8 @@
 from ..io import (InputArtifact, InputParameter, OutputArtifact,
                   OutputParameter, type_to_str)
 from ..utils import dict2list, get_key, randstr, s3_config
-from .vendor.typeguard import check_type
 from .opio import OPIO, Artifact, BigParameter, OPIOSign, Parameter
+from .vendor.typeguard import check_type
 
 iwd = os.getcwd()
 
@@ -190,6 +190,8 @@ def _check_signature(
                 ss = Set[Union[str, None]]
             elif ss == Set[Path]:
                 ss = Set[Union[Path, None]]
+            else:
+                continue
         if isinstance(ss, Parameter):
            ss = ss.type
         # skip type checking if the variable is None

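The new else branch is easy to miss: as I read the hunk, _check_signature only relaxes the artifact types it knows how to check (None is allowed inside sliced artifacts) and now skips type checking for everything else, including the Union[..., HDF5Datasets] signatures registered in opio.py below, which check_type could not validate against a list of .h5 paths. A hypothetical standalone illustration of that control flow:

from pathlib import Path
from typing import Set, Union

from dflow.python import HDF5Datasets


def relax_or_skip(ss):
    # Mirrors the hunk: known set types are relaxed because sliced
    # artifacts may contain None; everything else now returns None,
    # i.e. the caller "continue"s past check_type instead of failing.
    if ss == Set[str]:
        return Set[Union[str, None]]
    elif ss == Set[Path]:
        return Set[Union[Path, None]]
    return None


print(relax_or_skip(Set[str]))                        # Set[Optional[str]]
print(relax_or_skip(Union[Set[Path], HDF5Datasets]))  # None: check skipped
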
src/dflow/python/opio.py (68 changes: 57 additions & 11 deletions)

@@ -1,4 +1,5 @@
 import json
+import tarfile
 from collections.abc import MutableMapping
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Set, Union
@@ -8,27 +9,72 @@
 from ..io import PVC, type_to_str
 
 
-class nested_dict:
-    def __init__(self, type):
-        self.type = type
+class NestedDict:
+    pass
+
+
+class NestedDictStr(NestedDict):
+    pass
+
+
+class NestedDictPath(NestedDict):
+    pass
+
+
+class HDF5Dataset:
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def get_data(self):
+        data = self.dataset[()]
+        if self.dataset.attrs.get("dtype") == "utf-8":
+            data = data.decode("utf-8")
+        elif self.dataset.attrs.get("dtype") == "binary":
+            data = data.tobytes()
+        return data
+
+    def recover(self):
+        if self.dataset.attrs["type"] == "file":
+            path = Path(self.dataset.attrs["path"])
+            if path.is_absolute():
+                path = path.relative_to(path.root)
+            path.parent.mkdir(parents=True, exist_ok=True)
+            data = self.get_data()
+            if isinstance(data, str):
+                path.write_text(data)
+            elif isinstance(data, bytes):
+                path.write_bytes(data)
+            return path
+        elif self.dataset.attrs["type"] == "dir":
+            path = Path(self.dataset.attrs["path"])
+            if path.is_absolute():
+                path = path.relative_to(path.root)
+            path.parent.mkdir(parents=True, exist_ok=True)
+            tgz_path = path.parent / (path.name + ".tgz")
+            tgz_path.write_bytes(self.get_data())
+            tf = tarfile.open(tgz_path, "r:gz")
+            tf.extractall(".")
+            tf.close()
+            return path
+        else:
+            return self.get_data()
 
-    def __repr__(self):
-        return "dflow.python.NestedDict[%s]" % type_to_str(self.type)
 
-    def __eq__(self, other):
-        if not isinstance(other, nested_dict):
-            return False
-        return self.type == other.type
+class HDF5Datasets:
+    pass
 
 
 NestedDict = {
-    str: nested_dict(str),
-    Path: nested_dict(Path),
+    str: NestedDictStr,
+    Path: NestedDictPath,
 }
 
 ArtifactAllowedTypes = [str, Path, Set[str], Set[Path], List[str], List[Path],
                         Dict[str, str], Dict[str, Path], NestedDict[str],
                         NestedDict[Path]]
+for t in ArtifactAllowedTypes.copy():
+    ArtifactAllowedTypes.append(Union[t, HDF5Datasets])
+ArtifactAllowedTypes.append(HDF5Datasets)
 
 
 @CustomHandler.handles

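A standalone round trip through the conventions HDF5Dataset relies on (not part of the diff; the file and dataset names are made up): each dataset carries "type"/"path"/"dtype" attributes, get_data() decodes according to "dtype", and recover() rewrites a "file" entry to its original relative path.

import h5py

# HDF5Dataset is the per-dataset handle added by this commit; it is not
# re-exported at package level, so import it from the module directly.
from dflow.python.opio import HDF5Dataset

with h5py.File("demo.h5", "w") as f:
    d = f.create_dataset("msg", data="hello")
    d.attrs["type"] = "file"            # recover() will restore a file
    d.attrs["path"] = "inputs/msg.txt"  # relative path to write back to
    d.attrs["dtype"] = "utf-8"          # get_data() decodes bytes -> str

with h5py.File("demo.h5", "r") as f:
    ds = HDF5Dataset(f["msg"])
    print(ds.get_data())   # "hello", decoded because dtype == "utf-8"
    path = ds.recover()    # writes inputs/msg.txt and returns its Path
    print(path)            # inputs/msg.txt
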
src/dflow/python/utils.py (85 changes: 75 additions & 10 deletions)

@@ -1,17 +1,19 @@
 import os
 import shutil
 import signal
+import tarfile
 import traceback
 import uuid
 from pathlib import Path
-from typing import Dict, List, Set
+from typing import Dict, List, Set, Union
 
 from ..common import jsonpickle
 from ..config import config
 from ..utils import (artifact_classes, assemble_path_object,
                      catalog_of_local_artifact, convert_dflow_list, copy_file,
                      expand, flatten, randstr, remove_empty_dir_tag)
-from .opio import Artifact, BigParameter, NestedDict, Parameter
+from .opio import (Artifact, BigParameter, HDF5Dataset, HDF5Datasets,
+                   NestedDict, Parameter)
 
 
 def get_slices(path_object, slices):
@@ -78,7 +80,35 @@ def handle_input_artifact(name, sign, slices=None, data_root="/tmp",
 
     path_object = get_slices(path_object, slices)
 
-    if sign.type in [str, Path]:
+    sign_type = sign.type
+    if getattr(sign_type, "__origin__", None) == Union:
+        args = sign_type.__args__
+        if HDF5Datasets in args:
+            if isinstance(path_object, list) and all([isinstance(
+                    p, str) and p.endswith(".h5") for p in path_object]):
+                sign_type = HDF5Datasets
+            elif args[0] == HDF5Datasets:
+                sign_type = args[1]
+            elif args[1] == HDF5Datasets:
+                sign_type = args[0]
+
+    if sign_type == HDF5Datasets:
+        import h5py
+        assert isinstance(path_object, list)
+        res = None
+        for path in path_object:
+            f = h5py.File(path, "r")
+            datasets = {k: HDF5Dataset(f[k]) for k in f.keys()}
+            datasets = expand(datasets)
+            if isinstance(datasets, list):
+                if res is None:
+                    res = []
+                res += datasets
+            elif isinstance(datasets, dict):
+                if res is None:
+                    res = {}
+                res.update(datasets)
+    if sign_type in [str, Path]:
         if path_object is None or isinstance(path_object, str):
             res = path_object
         elif isinstance(path_object, list) and len(path_object) == 1 and (
@@ -87,8 +117,8 @@
             res = path_object[0]
         else:
             res = art_path
-        res = path_or_none(res) if sign.type == Path else res
-    elif sign.type in [List[str], List[Path], Set[str], Set[Path]]:
+        res = path_or_none(res) if sign_type == Path else res
+    elif sign_type in [List[str], List[Path], Set[str], Set[Path]]:
         if path_object is None:
             return None
         elif isinstance(path_object, str):
@@ -99,17 +129,17 @@
         else:
             res = list(flatten(path_object).values())
 
-        if sign.type == List[str]:
+        if sign_type == List[str]:
             pass
-        elif sign.type == List[Path]:
+        elif sign_type == List[Path]:
             res = path_or_none(res)
-        elif sign.type == Set[str]:
+        elif sign_type == Set[str]:
             res = set(res)
         else:
             res = set(path_or_none(res))
-    elif sign.type in [Dict[str, str], NestedDict[str]]:
+    elif sign_type in [Dict[str, str], NestedDict[str]]:
         res = path_object
-    elif sign.type in [Dict[str, Path], NestedDict[Path]]:
+    elif sign_type in [Dict[str, Path], NestedDict[Path]]:
        res = path_or_none(path_object)
 
     if res is None:
@@ -169,6 +199,41 @@ def slice_to_dir(slice):
 def handle_output_artifact(name, value, sign, slices=None, data_root="/tmp",
                            create_dir=False):
     path_list = []
+    if sign.type == HDF5Datasets:
+        import h5py
+        os.makedirs(data_root + '/outputs/artifacts/' + name, exist_ok=True)
+        h5_name = "%s.h5" % uuid.uuid4()
+        h5_path = '%s/outputs/artifacts/%s/%s' % (data_root, name, h5_name)
+        with h5py.File(h5_path, "w") as f:
+            for s, v in flatten(value).items():
+                if isinstance(v, Path):
+                    if v.is_file():
+                        try:
+                            data = v.read_text(encoding="utf-8")
+                            dtype = "utf-8"
+                        except Exception:
+                            import numpy as np
+                            data = np.void(v.read_bytes())
+                            dtype = "binary"
+                        d = f.create_dataset(s, data=data)
+                        d.attrs["type"] = "file"
+                        d.attrs["path"] = str(v)
+                        d.attrs["dtype"] = dtype
+                    elif v.is_dir():
+                        tgz_path = Path("%s.tgz" % v)
+                        tf = tarfile.open(tgz_path, "w:gz", dereference=True)
+                        tf.add(v)
+                        tf.close()
+                        import numpy as np
+                        d = f.create_dataset(s, data=np.void(
+                            tgz_path.read_bytes()))
+                        d.attrs["type"] = "dir"
+                        d.attrs["path"] = str(v)
+                        d.attrs["dtype"] = "binary"
+                else:
+                    d = f.create_dataset(s, data=v)
+                    d.attrs["type"] = "data"
+        path_list.append({"dflow_list_item": h5_name, "order": slices or 0})
+    if sign.type in [str, Path]:
         os.makedirs(data_root + '/outputs/artifacts/' + name, exist_ok=True)
         if slices is None:

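How the two handlers cooperate, reduced to the essentials. This is a sketch outside dflow's code path: it assumes flatten and expand from dflow.utils round-trip nested values the way handle_input_artifact's use of expand implies, and the file name is made up. Writing mirrors handle_output_artifact (one dataset per flattened key); reading mirrors handle_input_artifact (wrap each dataset in HDF5Dataset, then expand back to the nested shape).

import h5py

from dflow.python.opio import HDF5Dataset
from dflow.utils import expand, flatten

value = {"a": [1.0, 2.0], "b": 3.0}        # a nested OP output
with h5py.File("out.h5", "w") as f:
    for key, v in flatten(value).items():  # flat keys, one dataset each
        d = f.create_dataset(key, data=v)
        d.attrs["type"] = "data"            # plain data, nothing to restore

with h5py.File("out.h5", "r") as f:
    handles = {k: HDF5Dataset(f[k]) for k in f.keys()}
    datasets = expand(handles)              # back to {"a": [...], "b": ...}
    print(datasets["b"].get_data())         # 3.0
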
