Skip to content

Commit

Permalink
Using keys in .npz files for NumpyReader (#7148)
Browse files Browse the repository at this point in the history
Fixes #7147.


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <[email protected]>
  • Loading branch information
KumoLiu authored Oct 20, 2023
1 parent 6e5fdc0 commit 77b1759
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
2 changes: 1 addition & 1 deletion monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
img = np.load(name, allow_pickle=True, **kwargs_)
if Path(name).name.endswith(".npz"):
# load expected items from NPZ file
npz_keys = [f"arr_{i}" for i in range(len(img))] if self.npz_keys is None else self.npz_keys
npz_keys = list(img.keys()) if self.npz_keys is None else self.npz_keys
for k in npz_keys:
img_.append(img[k])
else:
Expand Down
36 changes: 23 additions & 13 deletions tests/test_numpy_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np

from monai.data import DataLoader, Dataset, NumpyReader
from monai.transforms import LoadImaged
from monai.transforms import LoadImage, LoadImaged
from tests.utils import assert_allclose


Expand Down Expand Up @@ -97,22 +97,32 @@ def test_kwargs(self):

def test_dataloader(self):
test_data = np.random.randint(0, 256, size=[3, 4, 5])
datalist = []
datalist_dict, datalist_array = [], []
with tempfile.TemporaryDirectory() as tempdir:
for i in range(4):
filepath = os.path.join(tempdir, f"test_data{i}.npz")
np.savez(filepath, test_data)
datalist.append({"image": filepath})

num_workers = 2 if sys.platform == "linux" else 0
loader = DataLoader(
Dataset(data=datalist, transform=LoadImaged(keys="image", reader=NumpyReader())),
batch_size=2,
num_workers=num_workers,
)
for d in loader:
for c in d["image"]:
assert_allclose(c, test_data, type_test=False)
datalist_dict.append({"image": filepath})
datalist_array.append(filepath)

num_workers = 2 if sys.platform == "linux" else 0
loader = DataLoader(
Dataset(data=datalist_dict, transform=LoadImaged(keys="image", reader=NumpyReader())),
batch_size=2,
num_workers=num_workers,
)
for d in loader:
for c in d["image"]:
assert_allclose(c, test_data, type_test=False)

loader = DataLoader(
Dataset(data=datalist_array, transform=LoadImage(reader=NumpyReader())),
batch_size=2,
num_workers=num_workers,
)
for d in loader:
for c in d:
assert_allclose(c, test_data, type_test=False)

def test_channel_dim(self):
test_data = np.random.randint(0, 256, size=[3, 4, 5, 2])
Expand Down

0 comments on commit 77b1759

Please sign in to comment.