Fix #8239: Enhance SoftclDiceLoss and SoftDiceclDiceLoss with DiceLoss-compatible API #8703
base: dev
Changes from all commits: 0fc51ea, e8a2579, caf39ef, 568ec36, 3e1a055
@@ -11,10 +11,18 @@
 from __future__ import annotations

+import warnings
+from collections.abc import Callable
+
 import torch
 import torch.nn.functional as F
 from torch.nn.modules.loss import _Loss

+from monai.losses.dice import DiceLoss
+from monai.networks import one_hot
+from monai.utils import LossReduction
+from monai.utils.deprecate_utils import deprecated_arg
+

 def soft_erode(img: torch.Tensor) -> torch.Tensor:  # type: ignore
     """
|
|
@@ -92,26 +100,6 @@ def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor:
     return skel


-def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
-    """
-    Function to compute soft dice loss
-
-    Adapted from:
-        https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L22
-
-    Args:
-        y_true: the shape should be BCH(WD)
-        y_pred: the shape should be BCH(WD)
-
-    Returns:
-        dice loss
-    """
-    intersection = torch.sum((y_true * y_pred)[:, 1:, ...])
-    coeff = (2.0 * intersection + smooth) / (torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth)
-    soft_dice: torch.Tensor = 1.0 - coeff
-    return soft_dice
-
-
 class SoftclDiceLoss(_Loss):
     """
     Compute the Soft clDice loss defined in:
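The standalone `soft_dice` helper is removed in favor of the existing `monai.losses.DiceLoss`. A rough sketch of the replacement usage, with hypothetical tensors; the values are not bit-identical, since `DiceLoss` averages per-channel, per-batch scores rather than pooling all foreground voxels into a single ratio as the deleted helper did:

```python
import torch
from monai.losses import DiceLoss

# Hypothetical one-hot example: batch 2, 3 classes, 8x8 spatial (BNHW).
y_pred = torch.softmax(torch.randn(2, 3, 8, 8), dim=1)
y_true = torch.nn.functional.one_hot(torch.randint(0, 3, (2, 8, 8)), num_classes=3)
y_true = y_true.permute(0, 3, 1, 2).float()

# include_background=False mirrors the deleted helper's [:, 1:, ...] slicing;
# smooth_nr/smooth_dr take over the role of the single `smooth` constant.
dice = DiceLoss(include_background=False, smooth_nr=1.0, smooth_dr=1.0)
print(dice(y_pred, y_true))
```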
|
|
@@ -121,64 +109,256 @@ class SoftclDiceLoss(_Loss):

     Adapted from:
         https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L7

+    The data `input` (BNHW[D], where N is the number of classes) is compared with the ground truth `target`
+    (BNHW[D]). Axis N of `input` is expected to be logits or probabilities for each class; if passing logits,
+    set `sigmoid=True` or `softmax=True`, or specify `other_act`. The same axis of `target` can be 1 or N
+    (one-hot format).
     """

-    def __init__(self, iter_: int = 3, smooth: float = 1.0) -> None:
+    def __init__(
+        self,
+        iter_: int = 3,
+        smooth_nr: float = 1.0,
+        smooth_dr: float = 1.0,
+        include_background: bool = True,
+        to_onehot_y: bool = False,
+        sigmoid: bool = False,
+        softmax: bool = False,
+        other_act: Callable | None = None,
+        reduction: LossReduction | str = LossReduction.MEAN,
+    ) -> None:
         """
         Args:
-            iter_: Number of iterations for skeletonization
-            smooth: Smoothing parameter
+            iter_: number of iterations for skeletonization. Must be a non-negative integer.
+            smooth_nr: a small constant added to the numerator to avoid zero.
+            smooth_dr: a small constant added to the denominator to avoid nan.
+            include_background: if False, channel index 0 (background category) is excluded from the calculation.
+                If the non-background segmentations are small compared to the total image size, they can get
+                overwhelmed by the signal from the background, so excluding it in such cases helps convergence.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
+            sigmoid: if True, apply a sigmoid function to the prediction.
+            softmax: if True, apply a softmax function to the prediction.
+            other_act: callable function to apply another activation layer, for example
+                ``other_act = torch.tanh``. Defaults to ``None``.
+            reduction: {``"none"``, ``"mean"``, ``"sum"``}
+                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+
+                - ``"none"``: no reduction will be applied.
+                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
+                - ``"sum"``: the output will be summed.
+
+        Raises:
+            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
+            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]
+                is set (incompatible values).
+
         """
-        super().__init__()
+        super().__init__(reduction=LossReduction(reduction).value)
+        if other_act is not None and not callable(other_act):
+            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
+        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
+            raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
+        if not isinstance(iter_, int):
+            raise TypeError(f"iter_ must be an integer but got {type(iter_).__name__}.")
+        if iter_ < 0:
+            raise ValueError(f"iter_ must be a non-negative integer but got {iter_}.")
         self.iter = iter_
-        self.smooth = smooth
+        self.smooth_nr = float(smooth_nr)
+        self.smooth_dr = float(smooth_dr)
+        self.include_background = include_background
+        self.to_onehot_y = to_onehot_y
+        self.sigmoid = sigmoid
+        self.softmax = softmax
+        self.other_act = other_act

-    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
-        skel_pred = soft_skel(y_pred, self.iter)
-        skel_true = soft_skel(y_true, self.iter)
-        tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
-            torch.sum(skel_pred[:, 1:, ...]) + self.smooth
-        )
-        tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
-            torch.sum(skel_true[:, 1:, ...]) + self.smooth
-        )
-        cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
+    @deprecated_arg("y_pred", since="1.5", removed="1.8", new_name="input", msg_suffix="please use `input` instead.")
+    @deprecated_arg("y_true", since="1.5", removed="1.8", new_name="target", msg_suffix="please use `target` instead.")
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
| """ | ||
| Args: | ||
| input: the shape should be BNH[WD], where N is the number of classes. | ||
| target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. | ||
|
|
||
| Raises: | ||
| AssertionError: When input and target (after one hot transform if set) | ||
| have different shapes. | ||
|
|
||
| """ | ||
| n_pred_ch = input.shape[1] | ||
|
|
||
| if self.sigmoid: | ||
| input = torch.sigmoid(input) | ||
|
|
||
| if self.softmax: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2) | ||
| else: | ||
| input = torch.softmax(input, dim=1) | ||
|
|
||
| def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: | ||
| skel_pred = soft_skel(y_pred, self.iter) | ||
| skel_true = soft_skel(y_true, self.iter) | ||
| tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / ( | ||
| torch.sum(skel_pred[:, 1:, ...]) + self.smooth | ||
| if self.other_act is not None: | ||
| input = self.other_act(input) | ||
|
|
||
| if self.to_onehot_y: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2) | ||
| else: | ||
| target = one_hot(target, num_classes=n_pred_ch) | ||
|
|
||
| if not self.include_background: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2) | ||
| else: | ||
| target = target[:, 1:] | ||
| input = input[:, 1:] | ||
|
|
||
| if target.shape != input.shape: | ||
| raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") | ||
|
|
||
| skel_pred = soft_skel(input, self.iter) | ||
| skel_true = soft_skel(target, self.iter) | ||
|
|
||
| # Compute per-batch clDice by reducing over channel and spatial dimensions | ||
| # reduce_axis includes all dimensions except batch (dim 0) | ||
| reduce_axis: list[int] = list(range(1, len(input.shape))) | ||
|
|
||
| tprec = (torch.sum(torch.multiply(skel_pred, target), dim=reduce_axis) + self.smooth_nr) / ( | ||
| torch.sum(skel_pred, dim=reduce_axis) + self.smooth_dr | ||
| ) | ||
| tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / ( | ||
| torch.sum(skel_true[:, 1:, ...]) + self.smooth | ||
| tsens = (torch.sum(torch.multiply(skel_true, input), dim=reduce_axis) + self.smooth_nr) / ( | ||
| torch.sum(skel_true, dim=reduce_axis) + self.smooth_dr | ||
| ) | ||
| cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) | ||
| # Add small epsilon for numerical stability in harmonic mean | ||
| cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + 1e-7) | ||
|
|
||
| # Apply reduction | ||
| if self.reduction == LossReduction.MEAN.value: | ||
| cl_dice = torch.mean(cl_dice) | ||
| elif self.reduction == LossReduction.SUM.value: | ||
| cl_dice = torch.sum(cl_dice) | ||
| elif self.reduction == LossReduction.NONE.value: | ||
| pass # keep per-batch values | ||
| else: | ||
| raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') | ||
|
|
||
| return cl_dice | ||
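A usage sketch of the reworked `SoftclDiceLoss` API, assuming this branch (the class is re-exported from `monai.losses` as in current MONAI). With `reduction="none"`, the loss now returns one value per batch element:

```python
import torch
from monai.losses import SoftclDiceLoss

pred = torch.randn(4, 2, 32, 32)             # raw logits, BNHW
label = torch.randint(0, 2, (4, 1, 32, 32))  # integer class indices, B1HW

# softmax=True converts logits to probabilities; to_onehot_y=True expands the
# single-channel target to N channels before skeletonization.
loss_fn = SoftclDiceLoss(iter_=3, softmax=True, to_onehot_y=True, reduction="none")
per_sample = loss_fn(pred, label)
print(per_sample.shape)  # torch.Size([4]): one clDice loss per batch element
```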
|
|
 class SoftDiceclDiceLoss(_Loss):
     """
-    Compute the Soft clDice loss defined in:
+    Compute both Dice loss and clDice loss, and return the weighted sum of the two.
+    The details of Dice loss are described in ``monai.losses.DiceLoss``;
+    the details of clDice loss are described in ``monai.losses.SoftclDiceLoss``.
+
+    Adapted from:
         Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function
         for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311)

     Adapted from:
         https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L38
     """

-    def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0) -> None:
+    def __init__(
+        self,
+        iter_: int = 3,
+        alpha: float = 0.5,
+        smooth_nr: float = 1.0,
+        smooth_dr: float = 1.0,
+        include_background: bool = True,
+        to_onehot_y: bool = False,
+        sigmoid: bool = False,
+        softmax: bool = False,
+        other_act: Callable | None = None,
+        reduction: LossReduction | str = LossReduction.MEAN,
+    ) -> None:
         """
         Args:
-            iter_: Number of iterations for skeletonization
-            smooth: Smoothing parameter
-            alpha: Weighing factor for cldice
+            iter_: number of iterations for skeletonization, used by clDice. Must be a non-negative integer.
+            alpha: weighting factor for the clDice component; total loss = (1 - alpha) * dice + alpha * cldice.
+                Defaults to 0.5.
+            smooth_nr: a small constant added to the numerator to avoid zero, used by both Dice and clDice.
+            smooth_dr: a small constant added to the denominator to avoid nan, used by both Dice and clDice.
+            include_background: if False, channel index 0 (background category) is excluded from the calculation.
+                If the non-background segmentations are small compared to the total image size, they can get
+                overwhelmed by the signal from the background, so excluding it in such cases helps convergence.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
+            sigmoid: if True, apply a sigmoid function to the prediction.
+            softmax: if True, apply a softmax function to the prediction.
+            other_act: callable function to apply another activation layer, for example
+                ``other_act = torch.tanh``. Defaults to ``None``.
+            reduction: {``"none"``, ``"mean"``, ``"sum"``}
+                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+
+                - ``"none"``: no reduction will be applied.
+                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
+                - ``"sum"``: the output will be summed.
+
+        Raises:
+            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
+            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]
+                is set (incompatible values).
+
         """
         super().__init__()
-        self.iter = iter_
-        self.smooth = smooth
-        self.alpha = alpha
-
-    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
-        dice = soft_dice(y_true, y_pred, self.smooth)
-        skel_pred = soft_skel(y_pred, self.iter)
-        skel_true = soft_skel(y_true, self.iter)
-        tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
-            torch.sum(skel_pred[:, 1:, ...]) + self.smooth
-        )
-        tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
-            torch.sum(skel_true[:, 1:, ...]) + self.smooth
-        )
-        cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
-        total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice
+        if not 0.0 <= alpha <= 1.0:
+            raise ValueError(f"alpha must be in [0, 1] but got {alpha}.")
+        self.dice = DiceLoss(
+            include_background=include_background,
+            to_onehot_y=False,
+            sigmoid=sigmoid,
+            softmax=softmax,
+            other_act=other_act,
+            reduction=reduction,
+            smooth_nr=smooth_nr,
+            smooth_dr=smooth_dr,
+        )
+        self.cldice = SoftclDiceLoss(
+            iter_=iter_,
+            smooth_nr=smooth_nr,
+            smooth_dr=smooth_dr,
+            include_background=include_background,
+            to_onehot_y=False,
+            sigmoid=sigmoid,
+            softmax=softmax,
+            other_act=other_act,
+            reduction=reduction,
+        )
+        self.alpha = alpha
+        self.to_onehot_y = to_onehot_y
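For reference, the composed loss is equivalent to wiring the two sub-losses manually, as this sketch under the PR's API shows. Note the design choice: `to_onehot_y` is handled once in the wrapper, so both sub-losses receive `to_onehot_y=False` and the target is converted a single time:

```python
import torch
from monai.losses import DiceLoss, SoftclDiceLoss

alpha = 0.5
pred = torch.rand(2, 3, 16, 16)                       # probabilities, BNHW
target = torch.randint(0, 2, (2, 3, 16, 16)).float()  # binary multi-channel target, BNHW

# Each term is reduced independently, then combined with the alpha weighting,
# matching `(1.0 - self.alpha) * dice_loss + self.alpha * cldice_loss` in forward.
dice = DiceLoss(smooth_nr=1.0, smooth_dr=1.0)(pred, target)
cldice = SoftclDiceLoss(iter_=3, smooth_nr=1.0, smooth_dr=1.0)(pred, target)
total = (1.0 - alpha) * dice + alpha * cldice
```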
| @deprecated_arg("y_pred", since="1.5", removed="1.8", new_name="input", msg_suffix="please use `input` instead.") | ||
| @deprecated_arg("y_true", since="1.5", removed="1.8", new_name="target", msg_suffix="please use `target` instead.") | ||
| def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same with the names here.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have addressed all the above, thank you! |
||
| """ | ||
| Args: | ||
| input: the shape should be BNH[WD], where N is the number of classes. | ||
| target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. | ||
|
|
||
| Raises: | ||
| ValueError: When number of dimensions for input and target are different. | ||
| ValueError: When number of channels for target is neither 1 nor the same as input. | ||
|
|
||
| """ | ||
| if input.dim() != target.dim(): | ||
| raise ValueError( | ||
| f"the number of dimensions for input and target should be the same, got shape {input.shape} and {target.shape}." | ||
| ) | ||
|
|
||
| if target.shape[1] != 1 and target.shape[1] != input.shape[1]: | ||
| raise ValueError( | ||
| f"number of channels for target is neither 1 nor the same as input, got shape {input.shape} and {target.shape}." | ||
| ) | ||
|
|
||
| if self.to_onehot_y: | ||
| n_pred_ch = input.shape[1] | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2) | ||
| else: | ||
| target = one_hot(target, num_classes=n_pred_ch) | ||
|
|
||
| dice_loss = self.dice(input, target) | ||
| cldice_loss = self.cldice(input, target) | ||
| total_loss: torch.Tensor = (1.0 - self.alpha) * dice_loss + self.alpha * cldice_loss | ||
|
|
||
| return total_loss | ||
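An end-to-end sketch of the combined loss with the new API, assuming this branch; integer labels are one-hot-encoded once by the wrapper:

```python
import torch
from monai.losses import SoftDiceclDiceLoss

pred = torch.randn(2, 3, 16, 16, requires_grad=True)  # raw logits, BNHW
label = torch.randint(0, 3, (2, 1, 16, 16))           # integer class indices, B1HW

# softmax=True and to_onehot_y=True are forwarded appropriately; the wrapper
# converts the target once, then delegates to DiceLoss and SoftclDiceLoss.
loss_fn = SoftDiceclDiceLoss(iter_=3, alpha=0.5, softmax=True, to_onehot_y=True)
loss = loss_fn(pred, label)
loss.backward()  # scalar with the default "mean" reduction, usable in a training loop
```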