Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable gpu load nifti #8188

Open
wants to merge 21 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add tests
Signed-off-by: Yiheng Wang <[email protected]>
  • Loading branch information
yiheng-wang-nv committed Dec 13, 2024
commit d052a5f1494e6a032540fee4c9169c49ced0e381
17 changes: 9 additions & 8 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@
from __future__ import annotations

import glob
import os
import re
import gzip
import io
import os
import re
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np
from torch.utils.data._utils.collate import np_str_obj_array_pattern

Expand All @@ -36,14 +37,14 @@
from monai.utils import MetaKeys, SpaceKeys, TraceKeys, ensure_tuple, optional_import, require_pkg

if TYPE_CHECKING:
import cupy as cp
import itk
import kvikio
import nibabel as nib
import nrrd
import pydicom
from nibabel.nifti1 import Nifti1Image
from PIL import Image as PILImage
import cupy as cp
import kvikio

has_nrrd = has_itk = has_nib = has_pil = has_pydicom = has_cp = has_kvikio = True
else:
Expand Down Expand Up @@ -948,7 +949,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
img_.append(img) # type: ignore
return img_ if len(filenames) > 1 else img_[0]

def get_data(self, img) -> tuple[np.ndarray, dict]:
def get_data(self, img) -> tuple[np.ndarray | "cp.ndarray", dict]:
"""
Extract data array and metadata from loaded image and return them.
This function returns two objects, first is numpy array of image data, second is dict of metadata.
Expand All @@ -960,7 +961,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.

"""
img_array: list[np.ndarray] = []
img_array: list[np.ndarray | "cp.ndarray"] = []
compatible_meta: dict = {}

for i, filename in zip(ensure_tuple(img), self.filenames):
Expand Down Expand Up @@ -1047,12 +1048,12 @@ def _get_array_data(self, img, filename):
img: a Nibabel image object loaded from an image file.

"""
if self.gpu_load:
if self.to_gpu:
file_size = os.path.getsize(filename)
image = cp.empty(file_size, dtype=cp.uint8)
with kvikio.CuFile(filename, "r") as f:
f.read(image)
if filename.endswith(".gz"):
if filename.endswith(".nii.gz"):
# for compressed data, have to tansfer to CPU to decompress
# and then transfer back to GPU. It is not efficient compared to .nii file
# but it's still faster than Nibabel's default reader.
Expand Down
4 changes: 3 additions & 1 deletion monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,9 @@ def ensure_torch_and_prune_meta(
By default, a `MetaTensor` is returned.
However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
"""
img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None, device=device) # potentially ascontiguousarray
img = convert_to_tensor(
im, track_meta=get_track_meta() and meta is not None, device=device
) # potentially ascontiguousarray
# if not tracking metadata, return `torch.Tensor`
if not isinstance(img, MetaTensor):
return img
Expand Down
19 changes: 19 additions & 0 deletions tests/test_init_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ def test_load_image(self):
inst = LoadImaged("image", reader=r)
self.assertIsInstance(inst, LoadImaged)

@SkipIfNoModule("nibabel")
@SkipIfNoModule("cupy")
@SkipIfNoModule("kvikio")
def test_load_image_to_gpu(self):
for to_gpu in [True, False]:
instance1 = LoadImage(reader="NibabelReader", to_gpu=to_gpu)
self.assertIsInstance(instance1, LoadImage)

instance2 = LoadImaged("image", reader="NibabelReader", to_gpu=to_gpu)
self.assertIsInstance(instance2, LoadImaged)

@SkipIfNoModule("itk")
@SkipIfNoModule("nibabel")
@SkipIfNoModule("PIL")
Expand Down Expand Up @@ -58,6 +69,14 @@ def test_readers(self):
inst = NrrdReader()
self.assertIsInstance(inst, NrrdReader)

@SkipIfNoModule("nibabel")
@SkipIfNoModule("cupy")
@SkipIfNoModule("kvikio")
def test_readers_to_gpu(self):
for to_gpu in [True, False]:
inst = NibabelReader(to_gpu=to_gpu)
self.assertIsInstance(inst, NibabelReader)


if __name__ == "__main__":
unittest.main()
41 changes: 40 additions & 1 deletion tests/test_load_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from monai.data.meta_tensor import MetaTensor
from monai.transforms import LoadImage
from monai.utils import optional_import
from tests.utils import assert_allclose, skip_if_downloading_fails, testing_data_config
from tests.utils import SkipIfNoModule, assert_allclose, skip_if_downloading_fails, testing_data_config

itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
ITKReader, _ = optional_import("monai.data", name="ITKReader", as_type="decorator")
Expand Down Expand Up @@ -74,6 +74,22 @@ def get_data(self, _obj):

TEST_CASE_5 = [{"reader": NibabelReader(mmap=False)}, ["test_image.nii.gz"], (128, 128, 128)]

TEST_CASE_GPU_1 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii.gz"], (128, 128, 128)]

TEST_CASE_GPU_2 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii"], (128, 128, 128)]

TEST_CASE_GPU_3 = [
{"reader": "nibabelreader", "to_gpu": True},
["test_image.nii", "test_image2.nii", "test_image3.nii"],
(3, 128, 128, 128),
]

TEST_CASE_GPU_4 = [
{"reader": "nibabelreader", "to_gpu": True},
["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"],
(3, 128, 128, 128),
]

TEST_CASE_6 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)]

TEST_CASE_7 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)]
Expand Down Expand Up @@ -196,6 +212,29 @@ def test_nibabel_reader(self, input_param, filenames, expected_shape):
assert_allclose(result.affine, torch.eye(4))
self.assertTupleEqual(result.shape, expected_shape)

@SkipIfNoModule("nibabel")
@SkipIfNoModule("cupy")
@SkipIfNoModule("kvikio")
@parameterized.expand([TEST_CASE_GPU_1, TEST_CASE_GPU_2, TEST_CASE_GPU_3, TEST_CASE_GPU_4])
def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape):
test_image = np.random.rand(128, 128, 128)
with tempfile.TemporaryDirectory() as tempdir:
for i, name in enumerate(filenames):
filenames[i] = os.path.join(tempdir, name)
nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
result = LoadImage(image_only=True, **input_param)(filenames)
ext = "".join(Path(name).suffixes)
self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image" + ext))
self.assertEqual(result.meta["space"], "RAS")
assert_allclose(result.affine, torch.eye(4))
self.assertTupleEqual(result.shape, expected_shape)

# verify gpu and cpu loaded data are the same
input_param_cpu = input_param.copy()
input_param_cpu["to_gpu"] = False
result_cpu = LoadImage(image_only=True, **input_param_cpu)(filenames)
self.assertTrue(torch.equal(result_cpu, result.cpu()))

@parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9])
def test_itk_reader(self, input_param, filenames, expected_shape):
test_image = np.random.rand(128, 128, 128)
Expand Down
Loading