Skip to content

Commit

Permalink
remove device set
Browse files Browse the repository at this point in the history
Signed-off-by: Yiheng Wang <[email protected]>
  • Loading branch information
yiheng-wang-nv committed Dec 20, 2024
1 parent ea2355b commit 6f4e5cb
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 10 deletions.
8 changes: 2 additions & 6 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,6 @@ def ensure_torch_and_prune_meta(
simple_keys: bool = False,
pattern: str | None = None,
sep: str = ".",
device: None | str | torch.device = None,
):
"""
Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary,
Expand All @@ -552,15 +551,12 @@ def ensure_torch_and_prune_meta(
sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary).
default is ".", see also :py:class:`monai.transforms.DeleteItemsd`.
e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``.
device: target device to put the Tensor data.
Returns:
By default, a `MetaTensor` is returned.
However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
"""
img = convert_to_tensor(
im, track_meta=get_track_meta() and meta is not None, device=device
) # potentially ascontiguousarray
img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray
# if not tracking metadata, return `torch.Tensor`
if not isinstance(img, MetaTensor):
return img
Expand All @@ -572,7 +568,7 @@ def ensure_torch_and_prune_meta(
if simple_keys:
# ensure affine is of type `torch.Tensor`
if MetaKeys.AFFINE in meta:
meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE], device=device) # bc-breaking
meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE]) # bc-breaking
remove_extra_metadata(meta) # bc-breaking

if pattern is not None:
Expand Down
6 changes: 2 additions & 4 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ def __init__(
e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``.
expanduser: if True cast filename to Path and call .expanduser on it, otherwise keep filename as is.
args: additional parameters for reader if providing a reader name.
device: target device to put the loaded image.
kwargs: additional parameters for reader if providing a reader name.
Note:
Expand All @@ -186,7 +185,6 @@ def __init__(
self.pattern = prune_meta_pattern
self.sep = prune_meta_sep
self.expanduser = expanduser
self.device = device

self.readers: list[ImageReader] = []
for r in SUPPORTED_READERS: # set predefined readers as default
Expand Down Expand Up @@ -291,15 +289,15 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
)
img_array: NdarrayOrTensor
img_array, meta_data = reader.get_data(img)
img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype, device=self.device)[0]
img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]
if not isinstance(meta_data, dict):
raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.")
# make sure all elements in metadata are little endian
meta_data = switch_endianness(meta_data, "<")

meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader
img = MetaTensor.ensure_torch_and_prune_meta(
img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep, device=self.device
img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep
)
if self.ensure_channel_first:
img = EnsureChannelFirst()(img)
Expand Down

0 comments on commit 6f4e5cb

Please sign in to comment.