Skip to content

Commit

Permalink
reformat to add gpu load support on nibabelreader
Browse files Browse the repository at this point in the history
Signed-off-by: Yiheng Wang <[email protected]>
  • Loading branch information
yiheng-wang-nv committed Nov 8, 2024
1 parent da41742 commit f453158
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 130 deletions.
2 changes: 1 addition & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from .folder_layout import FolderLayout, FolderLayoutBase
from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd
from .image_dataset import ImageDataset
from .image_reader import ImageReader, ITKReader, NibabelReader, NibabelGPUReader, NrrdReader, NumpyReader, PILReader, PydicomReader
from .image_reader import ImageReader, ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader, PydicomReader
from .image_writer import (
SUPPORTED_WRITERS,
ImageWriter,
Expand Down
143 changes: 44 additions & 99 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
cp, has_cp = optional_import("cupy")
kvikio, has_kvikio = optional_import("kvikio")

__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NibabelGPUReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]
__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]


class ImageReader(ABC):
Expand Down Expand Up @@ -155,6 +155,17 @@ def _stack_images(image_list: list, meta_dict: dict):
return np.stack(image_list, axis=0)


def _stack_gpu_images(image_list: list, meta_dict: dict):
if len(image_list) <= 1:
return image_list[0]
if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
return cp.concatenate(image_list, axis=channel_dim)
# stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
return cp.stack(image_list, axis=0)


@require_pkg(pkg_name="itk")
class ITKReader(ImageReader):
"""
Expand Down Expand Up @@ -887,12 +898,15 @@ def __init__(
channel_dim: str | int | None = None,
as_closest_canonical: bool = False,
squeeze_non_spatial_dims: bool = False,
gpu_load: bool = False,
**kwargs,
):
super().__init__()
self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
self.as_closest_canonical = as_closest_canonical
self.squeeze_non_spatial_dims = squeeze_non_spatial_dims
# TODO: add warning if not have required libs
self.gpu_load = gpu_load
self.kwargs = kwargs

def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
Expand Down Expand Up @@ -923,6 +937,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
img_: list[Nifti1Image] = []

filenames: Sequence[PathLike] = ensure_tuple(data)
self.filenames = filenames
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
Expand All @@ -946,7 +961,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
img_array: list[np.ndarray] = []
compatible_meta: dict = {}

for i in ensure_tuple(img):
for i, filename in zip(ensure_tuple(img), self.filenames):
header = self._get_meta_dict(i)
header[MetaKeys.AFFINE] = self._get_affine(i)
header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)
Expand All @@ -956,7 +971,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
header[MetaKeys.AFFINE] = self._get_affine(i)
header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
header[MetaKeys.SPACE] = SpaceKeys.RAS
data = self._get_array_data(i)
data = self._get_array_data(i, filename)
if self.squeeze_non_spatial_dims:
for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1):
if data.shape[d - 1] == 1:
Expand All @@ -969,7 +984,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
else:
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
_copy_compatible_dict(header, compatible_meta)

if self.gpu_load:
return _stack_gpu_images(img_array, compatible_meta), compatible_meta
return _stack_images(img_array, compatible_meta), compatible_meta

def _get_meta_dict(self, img) -> dict:
Expand Down Expand Up @@ -1022,111 +1038,40 @@ def _get_spatial_shape(self, img):
spatial_rank = max(min(ndim, 3), 1)
return np.asarray(size[:spatial_rank])

def _get_array_data(self, img):
def _get_array_data(self, img, filename):
"""
Get the raw array data of the image, converted to Numpy array.
Args:
img: a Nibabel image object loaded from an image file.
"""
if self.gpu_load:
file_size = os.path.getsize(filename)
image = cp.empty(file_size, dtype=cp.uint8)
# suggestion from Ming: more tests, diff size
# cucim + nifti
with kvikio.CuFile(filename, "r") as f:
f.read(image)
if filename.endswith(".gz"):
# for compressed data, have to tansfer to CPU to decompress
# and then transfer back to GPU. It is not efficient compared to .nii file
# but it's still faster than Nibabel's default reader.
# TODO: can benchmark more, it may no need to do this since we don't have to use .gz
# since it's waste times especially in training
compressed_data = cp.asnumpy(image)
with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
decompressed_data = gz_file.read()

file_size = len(decompressed_data)
image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8))
data_shape = img.shape
data_offset = img.dataobj.offset
data_dtype = img.dataobj.dtype
return image[data_offset:].view(data_dtype).reshape(data_shape, order="F")
return np.asanyarray(img.dataobj, order="C")


@require_pkg(pkg_name="nibabel")
@require_pkg(pkg_name="cupy")
@require_pkg(pkg_name="kvikio")
class NibabelGPUReader(NibabelReader):

def read(self, filename: PathLike, **kwargs):
"""
Read image data from specified file or files, it can read a list of images
and stack them together as multi-channel data in `get_data()`.
Note that the returned object is Nibabel image object or list of Nibabel image objects.
Args:
data: file name.
"""
file_size = os.path.getsize(filename)
image = cp.empty(file_size, dtype=cp.uint8)
with kvikio.CuFile(filename, "r") as f:
f.read(image)

if filename.endswith(".gz"):
# for compressed data, have to tansfer to CPU to decompress
# and then transfer back to GPU. It is not efficient compared to .nii file
# but it's still faster than Nibabel's default reader.
# TODO: can benchmark more, it may no need to do this since we don't have to use .gz
# since it's waste times especially in training
compressed_data = cp.asnumpy(image)
with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
decompressed_data = gz_file.read()

file_size = len(decompressed_data)
image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8))
return image

def get_data(self, img):
"""
Extract data array and metadata from loaded image and return them.
This function returns two objects, first is numpy array of image data, second is dict of metadata.
It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict.
When loading a list of files, they are stacked together at a new dimension as the first dimension,
and the metadata of the first image is used to present the output metadata.
Args:
img: a Nibabel image object loaded from an image file.
"""

# TODO: use a formal way for device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

header = self._get_header(img)
data_offset = header.get_data_offset()
data_shape = header.get_data_shape()
data_dtype = header.get_data_dtype()
affine = header.get_best_affine()
meta = {}
meta[MetaKeys.AFFINE] = affine
meta[MetaKeys.ORIGINAL_AFFINE] = affine
# TODO: as_closest_canonical
# TODO: correct_nifti_header_if_necessary
meta[MetaKeys.SPATIAL_SHAPE] = data_shape
# TODO: figure out why always RAS for NibabelReader ?
# meta[MetaKeys.SPACE] = SpaceKeys.RAS

data = img[data_offset:].view(data_dtype).reshape(data_shape, order="F")
# TODO: check channel
# if self.squeeze_non_spatial_dims:
if self.channel_dim is None: # default to "no_channel" or -1
meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1
)
else:
meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim

return MetaTensor(data, affine=affine, meta=meta, device=device)

def _get_header(self, img):
"""
Get the all the metadata of the image and convert to dict type.
Args:
img: a Nibabel image object loaded from an image file.
"""
header_bytes = cp.asnumpy(img[:348])
header = nib.Nifti1Header.from_fileobj(io.BytesIO(header_bytes))
# swap to little endian as PyTorch doesn't support big endian
try:
header = header.as_byteswapped("<")
except ValueError:
pass
return header


class NumpyReader(ImageReader):
"""
Load NPY or NPZ format data based on Numpy library, they can be arrays or pickled objects.
Expand Down
13 changes: 9 additions & 4 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,12 @@ def clone(self, **kwargs):

@staticmethod
def ensure_torch_and_prune_meta(
im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "."
im: NdarrayTensor,
meta: dict | None,
simple_keys: bool = False,
pattern: str | None = None,
sep: str = ".",
device: None | str | torch.device = None,
):
"""
Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary,
Expand All @@ -547,13 +552,13 @@ def ensure_torch_and_prune_meta(
sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary).
default is ".", see also :py:class:`monai.transforms.DeleteItemsd`.
e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``.
device: target device to put the Tensor data.
Returns:
By default, a `MetaTensor` is returned.
However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
"""
img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray

img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None, device=device) # potentially ascontiguousarray
# if not tracking metadata, return `torch.Tensor`
if not isinstance(img, MetaTensor):
return img
Expand All @@ -565,7 +570,7 @@ def ensure_torch_and_prune_meta(
if simple_keys:
# ensure affine is of type `torch.Tensor`
if MetaKeys.AFFINE in meta:
meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE]) # bc-breaking
meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE], device=device) # bc-breaking
remove_extra_metadata(meta) # bc-breaking

if pattern is not None:
Expand Down
31 changes: 5 additions & 26 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
ImageReader,
ITKReader,
NibabelReader,
NibabelGPUReader,
NrrdReader,
NumpyReader,
PILReader,
Expand Down Expand Up @@ -140,6 +139,7 @@ def __init__(
prune_meta_pattern: str | None = None,
prune_meta_sep: str = ".",
expanduser: bool = True,
device: None | str | torch.device = None,
*args,
**kwargs,
) -> None:
Expand All @@ -164,6 +164,7 @@ def __init__(
e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``.
expanduser: if True cast filename to Path and call .expanduser on it, otherwise keep filename as is.
args: additional parameters for reader if providing a reader name.
device: target device to put the loaded image.
kwargs: additional parameters for reader if providing a reader name.
Note:
Expand All @@ -185,6 +186,7 @@ def __init__(
self.pattern = prune_meta_pattern
self.sep = prune_meta_sep
self.expanduser = expanduser
self.device = device

self.readers: list[ImageReader] = []
for r in SUPPORTED_READERS: # set predefined readers as default
Expand Down Expand Up @@ -257,18 +259,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
)
img, err = None, []
if reader is not None:
if isinstance(reader, NibabelGPUReader):
# TODO: handle multiple filenames later
buffer = reader.read(filename[0])
img = reader.get_data(buffer)
img.meta[Key.FILENAME_OR_OBJ] = filename[0]
# TODO: check ensure channel first
if self.ensure_channel_first:
img = EnsureChannelFirst()(img)
if self.image_only:
return img
return img, img.meta

img = reader.read(filename) # runtime specified reader
else:
for reader in self.readers[::-1]:
Expand All @@ -278,17 +268,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
break
else: # try the user designated readers
try:
if isinstance(reader, NibabelGPUReader):
# TODO: handle multiple filenames later
buffer = reader.read(filename[0])
img = reader.get_data(buffer)
img.meta[Key.FILENAME_OR_OBJ] = filename[0]
# TODO: check ensure channel first
if self.ensure_channel_first:
img = EnsureChannelFirst()(img)
if self.image_only:
return img
return img, img.meta
img = reader.read(filename)
except Exception as e:
err.append(traceback.format_exc())
Expand All @@ -312,15 +291,15 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
)
img_array: NdarrayOrTensor
img_array, meta_data = reader.get_data(img)
img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]
img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype, device=self.device)[0]
if not isinstance(meta_data, dict):
raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.")
# make sure all elements in metadata are little endian
meta_data = switch_endianness(meta_data, "<")

meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader
img = MetaTensor.ensure_torch_and_prune_meta(
img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep
img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep, device=self.device
)
if self.ensure_channel_first:
img = EnsureChannelFirst()(img)
Expand Down

0 comments on commit f453158

Please sign in to comment.