From f4531588232449ad9231aff797e857a474f88397 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 8 Nov 2024 08:11:20 +0000 Subject: [PATCH] reformat to add gpu load support on nibabelreader Signed-off-by: Yiheng Wang --- monai/data/__init__.py | 2 +- monai/data/image_reader.py | 143 +++++++++++------------------------ monai/data/meta_tensor.py | 13 +++- monai/transforms/io/array.py | 31 ++------ 4 files changed, 59 insertions(+), 130 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 14d0dfb193..340c5eb8fa 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -50,7 +50,7 @@ from .folder_layout import FolderLayout, FolderLayoutBase from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd from .image_dataset import ImageDataset -from .image_reader import ImageReader, ITKReader, NibabelReader, NibabelGPUReader, NrrdReader, NumpyReader, PILReader, PydicomReader +from .image_reader import ImageReader, ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader, PydicomReader from .image_writer import ( SUPPORTED_WRITERS, ImageWriter, diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 68ef5420ae..ae94fcc053 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -58,7 +58,7 @@ cp, has_cp = optional_import("cupy") kvikio, has_kvikio = optional_import("kvikio") -__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NibabelGPUReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] +__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] class ImageReader(ABC): @@ -155,6 +155,17 @@ def _stack_images(image_list: list, meta_dict: dict): return np.stack(image_list, axis=0) +def _stack_gpu_images(image_list: list, meta_dict: dict): + if len(image_list) <= 1: + return image_list[0] + if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): + channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) + return cp.concatenate(image_list, axis=channel_dim) + # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified + meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 + return cp.stack(image_list, axis=0) + + @require_pkg(pkg_name="itk") class ITKReader(ImageReader): """ @@ -887,12 +898,15 @@ def __init__( channel_dim: str | int | None = None, as_closest_canonical: bool = False, squeeze_non_spatial_dims: bool = False, + gpu_load: bool = False, **kwargs, ): super().__init__() self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.as_closest_canonical = as_closest_canonical self.squeeze_non_spatial_dims = squeeze_non_spatial_dims + # TODO: add warning if not have required libs + self.gpu_load = gpu_load self.kwargs = kwargs def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: @@ -923,6 +937,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_: list[Nifti1Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) + self.filenames = filenames kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -946,7 +961,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: img_array: list[np.ndarray] = [] compatible_meta: dict = {} - for i in ensure_tuple(img): + for i, filename in zip(ensure_tuple(img), self.filenames): header = self._get_meta_dict(i) header[MetaKeys.AFFINE] = self._get_affine(i) header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i) @@ -956,7 +971,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: header[MetaKeys.AFFINE] = self._get_affine(i) header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) header[MetaKeys.SPACE] = SpaceKeys.RAS - data = self._get_array_data(i) + data = self._get_array_data(i, filename) if self.squeeze_non_spatial_dims: for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1): if data.shape[d - 1] == 1: @@ -969,7 +984,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: else: header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim _copy_compatible_dict(header, compatible_meta) - + if self.gpu_load: + return _stack_gpu_images(img_array, compatible_meta), compatible_meta return _stack_images(img_array, compatible_meta), compatible_meta def _get_meta_dict(self, img) -> dict: @@ -1022,7 +1038,7 @@ def _get_spatial_shape(self, img): spatial_rank = max(min(ndim, 3), 1) return np.asarray(size[:spatial_rank]) - def _get_array_data(self, img): + def _get_array_data(self, img, filename): """ Get the raw array data of the image, converted to Numpy array. @@ -1030,103 +1046,32 @@ def _get_array_data(self, img): img: a Nibabel image object loaded from an image file. """ + if self.gpu_load: + file_size = os.path.getsize(filename) + image = cp.empty(file_size, dtype=cp.uint8) + # suggestion from Ming: more tests, diff size + # cucim + nifti + with kvikio.CuFile(filename, "r") as f: + f.read(image) + if filename.endswith(".gz"): + # for compressed data, have to tansfer to CPU to decompress + # and then transfer back to GPU. It is not efficient compared to .nii file + # but it's still faster than Nibabel's default reader. + # TODO: can benchmark more, it may no need to do this since we don't have to use .gz + # since it's waste times especially in training + compressed_data = cp.asnumpy(image) + with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file: + decompressed_data = gz_file.read() + + file_size = len(decompressed_data) + image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8)) + data_shape = img.shape + data_offset = img.dataobj.offset + data_dtype = img.dataobj.dtype + return image[data_offset:].view(data_dtype).reshape(data_shape, order="F") return np.asanyarray(img.dataobj, order="C") -@require_pkg(pkg_name="nibabel") -@require_pkg(pkg_name="cupy") -@require_pkg(pkg_name="kvikio") -class NibabelGPUReader(NibabelReader): - - def read(self, filename: PathLike, **kwargs): - """ - Read image data from specified file or files, it can read a list of images - and stack them together as multi-channel data in `get_data()`. - Note that the returned object is Nibabel image object or list of Nibabel image objects. - - Args: - data: file name. - - """ - file_size = os.path.getsize(filename) - image = cp.empty(file_size, dtype=cp.uint8) - with kvikio.CuFile(filename, "r") as f: - f.read(image) - - if filename.endswith(".gz"): - # for compressed data, have to tansfer to CPU to decompress - # and then transfer back to GPU. It is not efficient compared to .nii file - # but it's still faster than Nibabel's default reader. - # TODO: can benchmark more, it may no need to do this since we don't have to use .gz - # since it's waste times especially in training - compressed_data = cp.asnumpy(image) - with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file: - decompressed_data = gz_file.read() - - file_size = len(decompressed_data) - image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8)) - return image - - def get_data(self, img): - """ - Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. - It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. - When loading a list of files, they are stacked together at a new dimension as the first dimension, - and the metadata of the first image is used to present the output metadata. - - Args: - img: a Nibabel image object loaded from an image file. - - """ - - # TODO: use a formal way for device - device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - - header = self._get_header(img) - data_offset = header.get_data_offset() - data_shape = header.get_data_shape() - data_dtype = header.get_data_dtype() - affine = header.get_best_affine() - meta = {} - meta[MetaKeys.AFFINE] = affine - meta[MetaKeys.ORIGINAL_AFFINE] = affine - # TODO: as_closest_canonical - # TODO: correct_nifti_header_if_necessary - meta[MetaKeys.SPATIAL_SHAPE] = data_shape - # TODO: figure out why always RAS for NibabelReader ? - # meta[MetaKeys.SPACE] = SpaceKeys.RAS - - data = img[data_offset:].view(data_dtype).reshape(data_shape, order="F") - # TODO: check channel - # if self.squeeze_non_spatial_dims: - if self.channel_dim is None: # default to "no_channel" or -1 - meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1 - ) - else: - meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim - - return MetaTensor(data, affine=affine, meta=meta, device=device) - - def _get_header(self, img): - """ - Get the all the metadata of the image and convert to dict type. - - Args: - img: a Nibabel image object loaded from an image file. - - """ - header_bytes = cp.asnumpy(img[:348]) - header = nib.Nifti1Header.from_fileobj(io.BytesIO(header_bytes)) - # swap to little endian as PyTorch doesn't support big endian - try: - header = header.as_byteswapped("<") - except ValueError: - pass - return header - - class NumpyReader(ImageReader): """ Load NPY or NPZ format data based on Numpy library, they can be arrays or pickled objects. diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index ac171e8508..959108eb47 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -532,7 +532,12 @@ def clone(self, **kwargs): @staticmethod def ensure_torch_and_prune_meta( - im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "." + im: NdarrayTensor, + meta: dict | None, + simple_keys: bool = False, + pattern: str | None = None, + sep: str = ".", + device: None | str | torch.device = None, ): """ Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary, @@ -547,13 +552,13 @@ def ensure_torch_and_prune_meta( sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary). default is ".", see also :py:class:`monai.transforms.DeleteItemsd`. e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``. + device: target device to put the Tensor data. Returns: By default, a `MetaTensor` is returned. However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned. """ - img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray - + img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None, device=device) # potentially ascontiguousarray # if not tracking metadata, return `torch.Tensor` if not isinstance(img, MetaTensor): return img @@ -565,7 +570,7 @@ def ensure_torch_and_prune_meta( if simple_keys: # ensure affine is of type `torch.Tensor` if MetaKeys.AFFINE in meta: - meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE]) # bc-breaking + meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE], device=device) # bc-breaking remove_extra_metadata(meta) # bc-breaking if pattern is not None: diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 455e38ac08..2eb00ab38d 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -35,7 +35,6 @@ ImageReader, ITKReader, NibabelReader, - NibabelGPUReader, NrrdReader, NumpyReader, PILReader, @@ -140,6 +139,7 @@ def __init__( prune_meta_pattern: str | None = None, prune_meta_sep: str = ".", expanduser: bool = True, + device: None | str | torch.device = None, *args, **kwargs, ) -> None: @@ -164,6 +164,7 @@ def __init__( e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``. expanduser: if True cast filename to Path and call .expanduser on it, otherwise keep filename as is. args: additional parameters for reader if providing a reader name. + device: target device to put the loaded image. kwargs: additional parameters for reader if providing a reader name. Note: @@ -185,6 +186,7 @@ def __init__( self.pattern = prune_meta_pattern self.sep = prune_meta_sep self.expanduser = expanduser + self.device = device self.readers: list[ImageReader] = [] for r in SUPPORTED_READERS: # set predefined readers as default @@ -257,18 +259,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader ) img, err = None, [] if reader is not None: - if isinstance(reader, NibabelGPUReader): - # TODO: handle multiple filenames later - buffer = reader.read(filename[0]) - img = reader.get_data(buffer) - img.meta[Key.FILENAME_OR_OBJ] = filename[0] - # TODO: check ensure channel first - if self.ensure_channel_first: - img = EnsureChannelFirst()(img) - if self.image_only: - return img - return img, img.meta - img = reader.read(filename) # runtime specified reader else: for reader in self.readers[::-1]: @@ -278,17 +268,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader break else: # try the user designated readers try: - if isinstance(reader, NibabelGPUReader): - # TODO: handle multiple filenames later - buffer = reader.read(filename[0]) - img = reader.get_data(buffer) - img.meta[Key.FILENAME_OR_OBJ] = filename[0] - # TODO: check ensure channel first - if self.ensure_channel_first: - img = EnsureChannelFirst()(img) - if self.image_only: - return img - return img, img.meta img = reader.read(filename) except Exception as e: err.append(traceback.format_exc()) @@ -312,7 +291,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader ) img_array: NdarrayOrTensor img_array, meta_data = reader.get_data(img) - img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0] + img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype, device=self.device)[0] if not isinstance(meta_data, dict): raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.") # make sure all elements in metadata are little endian @@ -320,7 +299,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader img = MetaTensor.ensure_torch_and_prune_meta( - img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep + img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep, device=self.device ) if self.ensure_channel_first: img = EnsureChannelFirst()(img)