enable gpu load nifti #8188

Open · wants to merge 21 commits into base: dev
Changes from 1 commit
reformat to add gpu load support on nibabelreader
Signed-off-by: Yiheng Wang <[email protected]>
yiheng-wang-nv committed Nov 8, 2024
commit f4531588232449ad9231aff797e857a474f88397
2 changes: 1 addition & 1 deletion monai/data/__init__.py
@@ -50,7 +50,7 @@
from .folder_layout import FolderLayout, FolderLayoutBase
from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd
from .image_dataset import ImageDataset
from .image_reader import ImageReader, ITKReader, NibabelReader, NibabelGPUReader, NrrdReader, NumpyReader, PILReader, PydicomReader
from .image_reader import ImageReader, ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader, PydicomReader
from .image_writer import (
SUPPORTED_WRITERS,
ImageWriter,
143 changes: 44 additions & 99 deletions monai/data/image_reader.py
@@ -58,7 +58,7 @@
cp, has_cp = optional_import("cupy")
kvikio, has_kvikio = optional_import("kvikio")

__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NibabelGPUReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]
__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]


class ImageReader(ABC):
@@ -155,6 +155,17 @@ def _stack_images(image_list: list, meta_dict: dict):
return np.stack(image_list, axis=0)


def _stack_gpu_images(image_list: list, meta_dict: dict):
if len(image_list) <= 1:
return image_list[0]
if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
return cp.concatenate(image_list, axis=channel_dim)
# stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
return cp.stack(image_list, axis=0)
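A minimal usage sketch (not part of the diff, assuming cupy is installed) of what _stack_gpu_images does when ORIGINAL_CHANNEL_DIM is unspecified: the volumes are stacked along a new leading channel dimension without leaving the GPU. The arrays below are hypothetical placeholders.

import cupy as cp

vol_a = cp.zeros((64, 64, 32), dtype=cp.float32)  # placeholder volume
vol_b = cp.ones((64, 64, 32), dtype=cp.float32)   # placeholder volume
stacked = cp.stack([vol_a, vol_b], axis=0)        # shape (2, 64, 64, 32), stays on the device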


@require_pkg(pkg_name="itk")
class ITKReader(ImageReader):
"""
@@ -887,12 +898,15 @@ def __init__(
channel_dim: str | int | None = None,
as_closest_canonical: bool = False,
squeeze_non_spatial_dims: bool = False,
gpu_load: bool = False,
**kwargs,
):
super().__init__()
self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
self.as_closest_canonical = as_closest_canonical
self.squeeze_non_spatial_dims = squeeze_non_spatial_dims
# TODO: add a warning if the required libs are not available
self.gpu_load = gpu_load
self.kwargs = kwargs

def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
@@ -923,6 +937,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
img_: list[Nifti1Image] = []

filenames: Sequence[PathLike] = ensure_tuple(data)
self.filenames = filenames
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
@@ -946,7 +961,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
img_array: list[np.ndarray] = []
compatible_meta: dict = {}

for i in ensure_tuple(img):
for i, filename in zip(ensure_tuple(img), self.filenames):
header = self._get_meta_dict(i)
header[MetaKeys.AFFINE] = self._get_affine(i)
header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)
@@ -956,7 +971,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
header[MetaKeys.AFFINE] = self._get_affine(i)
header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
header[MetaKeys.SPACE] = SpaceKeys.RAS
data = self._get_array_data(i)
data = self._get_array_data(i, filename)
if self.squeeze_non_spatial_dims:
for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1):
if data.shape[d - 1] == 1:
@@ -969,7 +984,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
else:
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
_copy_compatible_dict(header, compatible_meta)

if self.gpu_load:
return _stack_gpu_images(img_array, compatible_meta), compatible_meta
return _stack_images(img_array, compatible_meta), compatible_meta
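A hedged usage sketch of the gpu_load path introduced above (the file name is hypothetical; requires cupy and kvikio in addition to nibabel):

from monai.data import NibabelReader

reader = NibabelReader(gpu_load=True)
img = reader.read("volume.nii.gz")   # returns Nibabel image object(s); filenames are remembered for get_data
data, meta = reader.get_data(img)    # data is expected to be a cupy.ndarray resident on the GPU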

def _get_meta_dict(self, img) -> dict:
@@ -1022,111 +1038,40 @@ def _get_spatial_shape(self, img):
spatial_rank = max(min(ndim, 3), 1)
return np.asarray(size[:spatial_rank])

def _get_array_data(self, img):
def _get_array_data(self, img, filename):
"""
Get the raw array data of the image, converted to Numpy array.

Args:
img: a Nibabel image object loaded from an image file.
Contributor review comment: Please also add docstring for filename here.
"""
if self.gpu_load:
file_size = os.path.getsize(filename)
image = cp.empty(file_size, dtype=cp.uint8)
Contributor review comment: Will it be an issue when data dtype is not uint8?
# suggestion from Ming: more tests, diff size
# cucim + nifti
with kvikio.CuFile(filename, "r") as f:
f.read(image)
if filename.endswith(".gz"):
# for compressed data, we have to transfer to CPU to decompress,
# then transfer back to GPU. This is not as efficient as reading a plain .nii file,
# but it is still faster than Nibabel's default reader.
# TODO: benchmark further; this step may be unnecessary if .gz files are avoided,
# since the round trip wastes time, especially during training
compressed_data = cp.asnumpy(image)
with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
decompressed_data = gz_file.read()

file_size = len(decompressed_data)
image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8))
data_shape = img.shape
data_offset = img.dataobj.offset
data_dtype = img.dataobj.dtype
return image[data_offset:].view(data_dtype).reshape(data_shape, order="F")
return np.asanyarray(img.dataobj, order="C")
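A rough numpy illustration (not part of the diff) of why reading the raw file as uint8 is compatible with other voxel dtypes: the uint8 buffer is only a byte container, and .view(data_dtype) reinterprets those bytes as the dtype recorded in the NIfTI header without copying; cupy arrays behave analogously.

import numpy as np

raw = np.arange(4, dtype=np.float32).tobytes()              # pretend file payload
buf = np.frombuffer(raw, dtype=np.uint8)                    # byte view, like the CuFile read
restored = buf.view(np.float32).reshape((2, 2), order="F")  # header dtype/shape/order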


@require_pkg(pkg_name="nibabel")
@require_pkg(pkg_name="cupy")
@require_pkg(pkg_name="kvikio")
class NibabelGPUReader(NibabelReader):

def read(self, filename: PathLike, **kwargs):
"""
Read image data from the specified file into a GPU buffer.
Note that the returned object is a cupy array containing the raw file bytes, which is parsed in `get_data()`.

Args:
    filename: file name.

"""
file_size = os.path.getsize(filename)
image = cp.empty(file_size, dtype=cp.uint8)
with kvikio.CuFile(filename, "r") as f:
f.read(image)

if filename.endswith(".gz"):
# for compressed data, we have to transfer to CPU to decompress,
# then transfer back to GPU. This is not as efficient as reading a plain .nii file,
# but it is still faster than Nibabel's default reader.
# TODO: benchmark further; this step may be unnecessary if .gz files are avoided,
# since the round trip wastes time, especially during training
compressed_data = cp.asnumpy(image)
with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
decompressed_data = gz_file.read()

file_size = len(decompressed_data)
image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8))
return image

def get_data(self, img):
"""
Extract the data array and metadata from the loaded GPU buffer and return them as a single MetaTensor.
It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in the meta dict.

Args:
    img: a cupy array of raw file bytes returned by `read()`.

"""

# TODO: use a formal way for device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

header = self._get_header(img)
data_offset = header.get_data_offset()
data_shape = header.get_data_shape()
data_dtype = header.get_data_dtype()
affine = header.get_best_affine()
meta = {}
meta[MetaKeys.AFFINE] = affine
meta[MetaKeys.ORIGINAL_AFFINE] = affine
# TODO: as_closest_canonical
# TODO: correct_nifti_header_if_necessary
meta[MetaKeys.SPATIAL_SHAPE] = data_shape
# TODO: figure out why always RAS for NibabelReader ?
# meta[MetaKeys.SPACE] = SpaceKeys.RAS

data = img[data_offset:].view(data_dtype).reshape(data_shape, order="F")
# TODO: check channel
# if self.squeeze_non_spatial_dims:
if self.channel_dim is None: # default to "no_channel" or -1
meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1
)
else:
meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim

return MetaTensor(data, affine=affine, meta=meta, device=device)

def _get_header(self, img):
"""
Get all the metadata of the image and convert it to dict type.

Args:
img: a Nibabel image object loaded from an image file.

"""
header_bytes = cp.asnumpy(img[:348])
header = nib.Nifti1Header.from_fileobj(io.BytesIO(header_bytes))
# swap to little endian as PyTorch doesn't support big endian
try:
header = header.as_byteswapped("<")
except ValueError:
pass
return header
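A small sketch (not part of the diff, assuming nibabel is installed) of the header-parsing trick used by this removed class: the fixed-size NIfTI-1 header occupies the first 348 bytes of the file, so it can be parsed from an in-memory buffer without touching any voxel data. The file name below is hypothetical.

import io
import nibabel as nib

with open("volume.nii", "rb") as f:   # plain, uncompressed NIfTI file
    header_bytes = f.read(348)
header = nib.Nifti1Header.from_fileobj(io.BytesIO(header_bytes))
print(header.get_data_shape(), header.get_data_dtype(), header.get_data_offset())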


class NumpyReader(ImageReader):
"""
Load NPY or NPZ format data based on Numpy library, they can be arrays or pickled objects.
13 changes: 9 additions & 4 deletions monai/data/meta_tensor.py
@@ -532,7 +532,12 @@ def clone(self, **kwargs):

@staticmethod
def ensure_torch_and_prune_meta(
im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "."
im: NdarrayTensor,
meta: dict | None,
simple_keys: bool = False,
pattern: str | None = None,
sep: str = ".",
device: None | str | torch.device = None,
):
"""
Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary,
@@ -547,13 +552,13 @@ def ensure_torch_and_prune_meta(
sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary).
default is ".", see also :py:class:`monai.transforms.DeleteItemsd`.
e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``.
device: target device to put the Tensor data.

Returns:
By default, a `MetaTensor` is returned.
However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
"""
img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray

img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None, device=device) # potentially ascontiguousarray
# if not tracking metadata, return `torch.Tensor`
if not isinstance(img, MetaTensor):
return img
@@ -565,7 +570,7 @@ def ensure_torch_and_prune_meta(
if simple_keys:
# ensure affine is of type `torch.Tensor`
if MetaKeys.AFFINE in meta:
meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE]) # bc-breaking
meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE], device=device) # bc-breaking
remove_extra_metadata(meta) # bc-breaking

if pattern is not None:
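A hedged sketch (example values are hypothetical) of the new `device` argument to ensure_torch_and_prune_meta as used in this diff: the image, and the affine when `simple_keys=True`, are converted to tensors directly on the target device.

import numpy as np
import torch
from monai.data import MetaTensor

arr = np.random.rand(1, 8, 8, 8).astype(np.float32)   # placeholder image array
meta = {"affine": np.eye(4)}                            # minimal metadata
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
img = MetaTensor.ensure_torch_and_prune_meta(arr, meta, simple_keys=True, device=dev)
print(img.device)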
31 changes: 5 additions & 26 deletions monai/transforms/io/array.py
@@ -35,7 +35,6 @@
ImageReader,
ITKReader,
NibabelReader,
NibabelGPUReader,
NrrdReader,
NumpyReader,
PILReader,
@@ -140,6 +139,7 @@ def __init__(
prune_meta_pattern: str | None = None,
prune_meta_sep: str = ".",
expanduser: bool = True,
device: None | str | torch.device = None,
*args,
**kwargs,
) -> None:
@@ -164,6 +164,7 @@ def __init__(
e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``.
expanduser: if True cast filename to Path and call .expanduser on it, otherwise keep filename as is.
args: additional parameters for reader if providing a reader name.
device: target device to put the loaded image.
kwargs: additional parameters for reader if providing a reader name.

Note:
@@ -185,6 +186,7 @@ def __init__(
self.pattern = prune_meta_pattern
self.sep = prune_meta_sep
self.expanduser = expanduser
self.device = device

self.readers: list[ImageReader] = []
for r in SUPPORTED_READERS: # set predefined readers as default
@@ -257,18 +259,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
)
img, err = None, []
if reader is not None:
if isinstance(reader, NibabelGPUReader):
# TODO: handle multiple filenames later
buffer = reader.read(filename[0])
img = reader.get_data(buffer)
img.meta[Key.FILENAME_OR_OBJ] = filename[0]
# TODO: check ensure channel first
if self.ensure_channel_first:
img = EnsureChannelFirst()(img)
if self.image_only:
return img
return img, img.meta

img = reader.read(filename) # runtime specified reader
else:
for reader in self.readers[::-1]:
@@ -278,17 +268,6 @@
break
else: # try the user designated readers
try:
if isinstance(reader, NibabelGPUReader):
# TODO: handle multiple filenames later
buffer = reader.read(filename[0])
img = reader.get_data(buffer)
img.meta[Key.FILENAME_OR_OBJ] = filename[0]
# TODO: check ensure channel first
if self.ensure_channel_first:
img = EnsureChannelFirst()(img)
if self.image_only:
return img
return img, img.meta
img = reader.read(filename)
except Exception as e:
err.append(traceback.format_exc())
@@ -312,15 +291,15 @@
)
img_array: NdarrayOrTensor
img_array, meta_data = reader.get_data(img)
img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]
img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype, device=self.device)[0]
if not isinstance(meta_data, dict):
raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.")
# make sure all elements in metadata are little endian
meta_data = switch_endianness(meta_data, "<")

meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader
img = MetaTensor.ensure_torch_and_prune_meta(
img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep
img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep, device=self.device
)
if self.ensure_channel_first:
img = EnsureChannelFirst()(img)
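A hedged end-to-end sketch of the new `device` argument on LoadImage, combined with the reader-level gpu_load flag from this PR (the file path and device are hypothetical; extra keyword arguments such as gpu_load are forwarded to the named reader):

import torch
from monai.transforms import LoadImage

dev = "cuda:0" if torch.cuda.is_available() else "cpu"
loader = LoadImage(reader="NibabelReader", image_only=True, device=dev, gpu_load=True)
img = loader("volume.nii.gz")   # expected to be a MetaTensor placed on `dev`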