Skip to content

Commit

Permalink
Fix generalized dice computation (#7970)
Browse files Browse the repository at this point in the history
Fixes #7966 

### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Suraj Pai <[email protected]>
Signed-off-by: Suraj Pai <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
surajpaib and KumoLiu authored Sep 7, 2024
1 parent b539cbb commit d02ba11
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 94 deletions.
125 changes: 77 additions & 48 deletions monai/metrics/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,47 @@
import torch

from monai.metrics.utils import do_metric_reduction, ignore_background
from monai.utils import MetricReduction, Weight, look_up_option
from monai.utils import MetricReduction, Weight, deprecated_arg, deprecated_arg_default, look_up_option

from .metric import CumulativeIterationMetric


class GeneralizedDiceScore(CumulativeIterationMetric):
"""Compute the Generalized Dice Score metric between tensors, as the complement of the Generalized Dice Loss defined in:
"""
Compute the Generalized Dice Score metric between tensors.
This metric is the complement of the Generalized Dice Loss defined in:
Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning
loss function for highly unbalanced segmentations. DLMIA 2017.
loss function for highly unbalanced segmentations. DLMIA 2017.
The inputs `y_pred` and `y` are expected to be one-hot, binarized channel-first
or batch-first tensors, i.e., CHW[D] or BCHW[D].
The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D].
Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
Args:
include_background (bool, optional): whether to include the background class (assumed to be in channel 0), in the
include_background: Whether to include the background class (assumed to be in channel 0) in the
score computation. Defaults to True.
reduction (str, optional): define mode of reduction to the metrics. Available reduction modes:
{``"none"``, ``"mean_batch"``, ``"sum_batch"``}. Default to ``"mean_batch"``. If "none", will not do reduction.
weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
reduction: Define mode of reduction to the metrics. Available reduction modes:
{``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
ground truth volume into a weight factor. Defaults to ``"square"``.
Raises:
ValueError: when the `weight_type` is not one of {``"none"``, ``"mean"``, ``"sum"``}.
ValueError: When the `reduction` is not one of MetricReduction enum.
"""

@deprecated_arg_default(
"reduction",
old_default=MetricReduction.MEAN_BATCH,
new_default=MetricReduction.MEAN,
since="1.4.0",
replaced="1.5.0",
msg_suffix=(
"Old versions computed `mean` when `mean_batch` was provided due to bug in reduction, "
"If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'."
),
)
def __init__(
self,
include_background: bool = True,
Expand All @@ -50,79 +63,90 @@ def __init__(
) -> None:
super().__init__()
self.include_background = include_background
reduction_options = [
"none",
"mean_batch",
"sum_batch",
MetricReduction.NONE,
MetricReduction.MEAN_BATCH,
MetricReduction.SUM_BATCH,
]
self.reduction = reduction
if self.reduction not in reduction_options:
raise ValueError(f"reduction must be one of {reduction_options}")
self.reduction = look_up_option(reduction, MetricReduction)
self.weight_type = look_up_option(weight_type, Weight)
self.sum_over_classes = self.reduction in {
MetricReduction.SUM,
MetricReduction.MEAN,
MetricReduction.MEAN_CHANNEL,
MetricReduction.SUM_CHANNEL,
}

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""Computes the Generalized Dice Score and returns a tensor with its per image values.
"""
Computes the Generalized Dice Score and returns a tensor with its per image values.
Args:
y_pred (torch.Tensor): binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions.
y (torch.Tensor): binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
Returns:
torch.Tensor: Generalized Dice Score averaged across batch and class
Raises:
ValueError: if `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
"""
return compute_generalized_dice(
y_pred=y_pred, y=y, include_background=self.include_background, weight_type=self.weight_type
y_pred=y_pred,
y=y,
include_background=self.include_background,
weight_type=self.weight_type,
sum_over_classes=self.sum_over_classes,
)

@deprecated_arg(
"reduction",
since="1.3.3",
removed="1.7.0",
msg_suffix="Reduction will be ignored. Set reduction during init. as gen.dice needs it during compute",
)
def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor:
"""
Execute reduction logic for the output of `compute_generalized_dice`.
Args:
reduction (Union[MetricReduction, str, None], optional): define mode of reduction to the metrics.
Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``}.
Defaults to ``"mean"``. If "none", will not do reduction.
Returns:
torch.Tensor: Aggregated metric value.
Raises:
ValueError: If the data to aggregate is not a PyTorch Tensor.
"""
data = self.get_buffer()
if not isinstance(data, torch.Tensor):
raise ValueError("The data to aggregate must be a PyTorch Tensor.")

# Validate reduction argument if specified
if reduction is not None:
reduction_options = ["none", "mean", "sum", "mean_batch", "sum_batch"]
if reduction not in reduction_options:
raise ValueError(f"reduction must be one of {reduction_options}")

# Do metric reduction and return
f, _ = do_metric_reduction(data, reduction or self.reduction)
f, _ = do_metric_reduction(data, self.reduction)

return f


def compute_generalized_dice(
y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, weight_type: Weight | str = Weight.SQUARE
y_pred: torch.Tensor,
y: torch.Tensor,
include_background: bool = True,
weight_type: Weight | str = Weight.SQUARE,
sum_over_classes: bool = False,
) -> torch.Tensor:
"""Computes the Generalized Dice Score and returns a tensor with its per image values.
"""
Computes the Generalized Dice Score and returns a tensor with its per image values.
Args:
y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format
y_pred (torch.Tensor): Binarized segmentation model output. It should be binarized, in one-hot format
and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the
remaining are the spatial dimensions.
y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
include_background (bool, optional): whether to include score computation on the first channel of the
y (torch.Tensor): Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
include_background: Whether to include score computation on the first channel of the
predicted output. Defaults to True.
weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
transform ground truth volume into a weight factor. Defaults to ``"square"``.
sum_over_labels (bool): Whether to sum the numerator and denominator across all labels before the final computation.
Returns:
torch.Tensor: per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
Raises:
ValueError: if `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
or `y_pred` and `y` don't have the same shape.
"""
# Ensure tensors have at least 3 dimensions and have the same shape
Expand Down Expand Up @@ -158,16 +182,21 @@ def compute_generalized_dice(
b[infs] = 0
b[infs] = torch.max(b)

# Compute the weighted numerator and denominator, summing along the class axis
numer = 2.0 * (intersection * w).sum(dim=1)
denom = (denominator * w).sum(dim=1)
# Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True
if sum_over_classes:
numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True)
denom = (denominator * w).sum(dim=1, keepdim=True)
y_pred_o = y_pred_o.sum(dim=-1, keepdim=True)
else:
numer = 2.0 * (intersection * w)
denom = denominator * w
y_pred_o = y_pred_o

# Compute the score
generalized_dice_score = numer / denom

# Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1.
# Where denom == 0 but the prediction volume is not 0, score is 0
y_pred_o = y_pred_o.sum(dim=-1)
denom_zeros = denom == 0
generalized_dice_score[denom_zeros] = torch.where(
(y_pred_o == 0)[denom_zeros],
Expand Down
Loading

0 comments on commit d02ba11

Please sign in to comment.