Skip to content

Commit

Permalink
enable gpu load nifti (#8188)
Browse files Browse the repository at this point in the history
Related to #8241 .

### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Yiheng Wang <[email protected]>
Signed-off-by: Yiheng Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
4 people authored Dec 23, 2024
1 parent efff647 commit d36f0c8
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 12 deletions.
86 changes: 77 additions & 9 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
from __future__ import annotations

import glob
import gzip
import io
import os
import re
import tempfile
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Sequence
Expand Down Expand Up @@ -51,6 +54,9 @@
pydicom, has_pydicom = optional_import("pydicom")
nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True)

cp, has_cp = optional_import("cupy")
kvikio, has_kvikio = optional_import("kvikio")

__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]


Expand Down Expand Up @@ -137,14 +143,18 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict):
)


def _stack_images(image_list: list, meta_dict: dict):
def _stack_images(image_list: list, meta_dict: dict, to_cupy: bool = False):
if len(image_list) <= 1:
return image_list[0]
if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
if to_cupy and has_cp:
return cp.concatenate(image_list, axis=channel_dim)
return np.concatenate(image_list, axis=channel_dim)
# stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
if to_cupy and has_cp:
return cp.stack(image_list, axis=0)
return np.stack(image_list, axis=0)


Expand Down Expand Up @@ -864,12 +874,18 @@ class NibabelReader(ImageReader):
Load NIfTI format images based on Nibabel library.
Args:
as_closest_canonical: if True, load the image as closest to canonical axis format.
squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
channel_dim: the channel dimension of the input image, default is None.
this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field.
if None, `original_channel_dim` will be either `no_channel` or `-1`.
most Nifti files are usually "channel last", no need to specify this argument for them.
as_closest_canonical: if True, load the image as closest to canonical axis format.
squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
Default is False. CuPy and Kvikio are required for this option.
Note: For compressed NIfTI files, some operations may still be performed on CPU memory,
and the acceleration may not be significant. In some cases, it may be slower than loading on CPU.
In practical use, it's recommended to add a warm up call before the actual loading.
A related tutorial will be prepared in the future, and the document will be updated accordingly.
kwargs: additional args for `nibabel.load` API. more details about available args:
https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py
Expand All @@ -880,14 +896,42 @@ def __init__(
channel_dim: str | int | None = None,
as_closest_canonical: bool = False,
squeeze_non_spatial_dims: bool = False,
to_gpu: bool = False,
**kwargs,
):
super().__init__()
self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
self.as_closest_canonical = as_closest_canonical
self.squeeze_non_spatial_dims = squeeze_non_spatial_dims
if to_gpu and (not has_cp or not has_kvikio):
warnings.warn(
"NibabelReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading."
)
to_gpu = False

if to_gpu:
self.warmup_kvikio()

self.to_gpu = to_gpu
self.kwargs = kwargs

def warmup_kvikio(self):
"""
Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.
This can accelerate the data loading process when `to_gpu` is set to True.
"""
if has_cp and has_kvikio:
a = cp.arange(100)
with tempfile.NamedTemporaryFile() as tmp_file:
tmp_file_name = tmp_file.name
f = kvikio.CuFile(tmp_file_name, "w")
f.write(a)
f.close()

b = cp.empty_like(a)
f = kvikio.CuFile(tmp_file_name, "r")
f.read(b)

def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
"""
Verify whether the specified file or files format is supported by Nibabel reader.
Expand Down Expand Up @@ -916,6 +960,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
img_: list[Nifti1Image] = []

filenames: Sequence[PathLike] = ensure_tuple(data)
self.filenames = filenames
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
Expand All @@ -936,10 +981,13 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.
"""
# TODO: the actual type is list[np.ndarray | cp.ndarray]
# should figure out how to define correct types without having cupy not found error
# https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
img_array: list[np.ndarray] = []
compatible_meta: dict = {}

for i in ensure_tuple(img):
for i, filename in zip(ensure_tuple(img), self.filenames):
header = self._get_meta_dict(i)
header[MetaKeys.AFFINE] = self._get_affine(i)
header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)
Expand All @@ -949,7 +997,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
header[MetaKeys.AFFINE] = self._get_affine(i)
header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
header[MetaKeys.SPACE] = SpaceKeys.RAS
data = self._get_array_data(i)
data = self._get_array_data(i, filename)
if self.squeeze_non_spatial_dims:
for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1):
if data.shape[d - 1] == 1:
Expand All @@ -963,7 +1011,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
_copy_compatible_dict(header, compatible_meta)

return _stack_images(img_array, compatible_meta), compatible_meta
return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta

def _get_meta_dict(self, img) -> dict:
"""
Expand Down Expand Up @@ -1015,14 +1063,34 @@ def _get_spatial_shape(self, img):
spatial_rank = max(min(ndim, 3), 1)
return np.asarray(size[:spatial_rank])

def _get_array_data(self, img):
def _get_array_data(self, img, filename):
"""
Get the raw array data of the image, converted to Numpy array.
Args:
img: a Nibabel image object loaded from an image file.
"""
filename: file name of the image.
"""
if self.to_gpu:
file_size = os.path.getsize(filename)
image = cp.empty(file_size, dtype=cp.uint8)
with kvikio.CuFile(filename, "r") as f:
f.read(image)
if filename.endswith(".nii.gz"):
# for compressed data, have to tansfer to CPU to decompress
# and then transfer back to GPU. It is not efficient compared to .nii file
# and may be slower than CPU loading in some cases.
warnings.warn("Loading compressed NIfTI file into GPU may not be efficient.")
compressed_data = cp.asnumpy(image)
with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
decompressed_data = gz_file.read()

image = cp.frombuffer(decompressed_data, dtype=cp.uint8)
data_shape = img.shape
data_offset = img.dataobj.offset
data_dtype = img.dataobj.dtype
return image[data_offset:].view(data_dtype).reshape(data_shape, order="F")
return np.asanyarray(img.dataobj, order="C")


Expand Down
1 change: 0 additions & 1 deletion monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,6 @@ def ensure_torch_and_prune_meta(
However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
"""
img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray

# if not tracking metadata, return `torch.Tensor`
if not isinstance(img, MetaTensor):
return img
Expand Down
1 change: 0 additions & 1 deletion monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
f" The current registered: {self.readers}.\n{msg}"
)

img_array: NdarrayOrTensor
img_array, meta_data = reader.get_data(img)
img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]
Expand Down
19 changes: 19 additions & 0 deletions tests/test_init_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ def test_load_image(self):
inst = LoadImaged("image", reader=r)
self.assertIsInstance(inst, LoadImaged)

@SkipIfNoModule("nibabel")
@SkipIfNoModule("cupy")
@SkipIfNoModule("kvikio")
def test_load_image_to_gpu(self):
for to_gpu in [True, False]:
instance1 = LoadImage(reader="NibabelReader", to_gpu=to_gpu)
self.assertIsInstance(instance1, LoadImage)

instance2 = LoadImaged("image", reader="NibabelReader", to_gpu=to_gpu)
self.assertIsInstance(instance2, LoadImaged)

@SkipIfNoModule("itk")
@SkipIfNoModule("nibabel")
@SkipIfNoModule("PIL")
Expand Down Expand Up @@ -58,6 +69,14 @@ def test_readers(self):
inst = NrrdReader()
self.assertIsInstance(inst, NrrdReader)

@SkipIfNoModule("nibabel")
@SkipIfNoModule("cupy")
@SkipIfNoModule("kvikio")
def test_readers_to_gpu(self):
for to_gpu in [True, False]:
inst = NibabelReader(to_gpu=to_gpu)
self.assertIsInstance(inst, NibabelReader)


if __name__ == "__main__":
unittest.main()
41 changes: 40 additions & 1 deletion tests/test_load_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from monai.data.meta_tensor import MetaTensor
from monai.transforms import LoadImage
from monai.utils import optional_import
from tests.utils import assert_allclose, skip_if_downloading_fails, testing_data_config
from tests.utils import SkipIfNoModule, assert_allclose, skip_if_downloading_fails, testing_data_config

itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
ITKReader, _ = optional_import("monai.data", name="ITKReader", as_type="decorator")
Expand Down Expand Up @@ -74,6 +74,22 @@ def get_data(self, _obj):

TEST_CASE_5 = [{"reader": NibabelReader(mmap=False)}, ["test_image.nii.gz"], (128, 128, 128)]

TEST_CASE_GPU_1 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii.gz"], (128, 128, 128)]

TEST_CASE_GPU_2 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii"], (128, 128, 128)]

TEST_CASE_GPU_3 = [
{"reader": "nibabelreader", "to_gpu": True},
["test_image.nii", "test_image2.nii", "test_image3.nii"],
(3, 128, 128, 128),
]

TEST_CASE_GPU_4 = [
{"reader": "nibabelreader", "to_gpu": True},
["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"],
(3, 128, 128, 128),
]

TEST_CASE_6 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)]

TEST_CASE_7 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)]
Expand Down Expand Up @@ -196,6 +212,29 @@ def test_nibabel_reader(self, input_param, filenames, expected_shape):
assert_allclose(result.affine, torch.eye(4))
self.assertTupleEqual(result.shape, expected_shape)

@SkipIfNoModule("nibabel")
@SkipIfNoModule("cupy")
@SkipIfNoModule("kvikio")
@parameterized.expand([TEST_CASE_GPU_1, TEST_CASE_GPU_2, TEST_CASE_GPU_3, TEST_CASE_GPU_4])
def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape):
test_image = np.random.rand(128, 128, 128)
with tempfile.TemporaryDirectory() as tempdir:
for i, name in enumerate(filenames):
filenames[i] = os.path.join(tempdir, name)
nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
result = LoadImage(image_only=True, **input_param)(filenames)
ext = "".join(Path(name).suffixes)
self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image" + ext))
self.assertEqual(result.meta["space"], "RAS")
assert_allclose(result.affine, torch.eye(4))
self.assertTupleEqual(result.shape, expected_shape)

# verify gpu and cpu loaded data are the same
input_param_cpu = input_param.copy()
input_param_cpu["to_gpu"] = False
result_cpu = LoadImage(image_only=True, **input_param_cpu)(filenames)
self.assertTrue(torch.equal(result_cpu, result.cpu()))

@parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9])
def test_itk_reader(self, input_param, filenames, expected_shape):
test_image = np.random.rand(128, 128, 128)
Expand Down

0 comments on commit d36f0c8

Please sign in to comment.