Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 27 additions & 15 deletions monai/data/test_time_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Any

import numpy as np
import torch

from monai.config.type_definitions import NdarrayOrTensor
Expand Down Expand Up @@ -68,7 +67,7 @@ class TestTimeAugmentation:
Args:
transform: transform (or composed) to be applied to each realization. At least one transform must be of type
`RandomizableTrait` (i.e. `Randomizable`, `RandomizableTransform`, or `RandomizableTrait`).
. All random transforms must be of type `InvertibleTransform`.
When `apply_inverse_to_pred` is True, all random transforms must be of type `InvertibleTransform`.
batch_size: number of realizations to infer at once.
num_workers: how many subprocesses to use for data.
inferrer_fn: function to use to perform inference.
Expand All @@ -92,6 +91,11 @@ class TestTimeAugmentation:
will return the full data. Dimensions will be same size as when passing a single image through
`inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.
progress: whether to display a progress bar.
apply_inverse_to_pred: whether to apply inverse transformations to the predictions.
If the model's prediction is spatial (e.g. segmentation), this should be `True` to map the predictions
back to the original spatial reference.
If the prediction is non-spatial (e.g. classification label or score), this should be `False` to
aggregate the raw predictions directly. Defaults to `True`.

Example:
.. code-block:: python
Expand Down Expand Up @@ -125,6 +129,7 @@ def __init__(
post_func: Callable = _identity,
return_full_data: bool = False,
progress: bool = True,
apply_inverse_to_pred: bool = True,
) -> None:
self.transform = transform
self.batch_size = batch_size
Expand All @@ -134,6 +139,7 @@ def __init__(
self.image_key = image_key
self.return_full_data = return_full_data
self.progress = progress
self.apply_inverse_to_pred = apply_inverse_to_pred
self._pred_key = CommonKeys.PRED
self.inverter = Invertd(
keys=self._pred_key,
Expand All @@ -152,20 +158,23 @@ def __init__(

def _check_transforms(self):
    """Warn when the configured transform(s) look unsuitable for test-time augmentation.

    `self.transform` is inspected either as a single transform or, when it is a
    `Compose`, as the sequence held in its `transforms` attribute. Two kinds of
    issue are collected:

    - a transform that is `Randomizable` but not an `InvertibleTransform`; this is
      only reported when `self.apply_inverse_to_pred` is true, because invertibility
      is irrelevant when predictions are not mapped back through the transforms;
    - the absence of any `Randomizable` transform.

    All collected issues are emitted together in a single `warnings.warn` call;
    nothing is raised and nothing is returned.
    """
    transforms = self.transform.transforms if isinstance(self.transform, Compose) else [self.transform]
    warns = []
    found_random = False

    for idx, t in enumerate(transforms):
        if not isinstance(t, Randomizable):
            continue
        found_random = True
        # report the transform itself (position and type) so the message identifies the culprit
        if self.apply_inverse_to_pred and not isinstance(t, InvertibleTransform):
            warns.append(f"Transform #{idx} (type {type(t).__name__}) is random but not invertible.")

    if not found_random:
        warns.append("TTA usually requires at least one `Randomizable` transform in the given transform sequence.")

    if warns:
        warnings.warn(
            "TTA has encountered issues with the given transforms:\n " + "\n ".join(warns), stacklevel=2
        )

def __call__(
self, data: dict[str, Any], num_examples: int = 10
Expand Down Expand Up @@ -199,7 +208,10 @@ def __call__(
for b in tqdm(dl) if has_tqdm and self.progress else dl:
# do model forward pass
b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device))
outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
if self.apply_inverse_to_pred:
outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
else:
outs.extend([i[self._pred_key] for i in decollate_batch(b)])

output: NdarrayOrTensor = stack(outs, 0)

Expand Down
39 changes: 38 additions & 1 deletion tests/integration/test_testtimeaugmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_test_time_augmentation(self):
# output might be different size, so pad so that they match
train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)

model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
model = UNet(2, 1, 1, channels=(6, 6), strides=(2,)).to(device)
loss_function = DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

Expand Down Expand Up @@ -181,6 +181,43 @@ def test_image_no_label(self):
tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, orig_key="image")
tta(self.get_data(1, (20, 20), include_label=False))

def test_non_spatial_output(self):
    """
    Exercise TTA with an inferrer whose output has no spatial dimensions (e.g. class scores).

    With `apply_inverse_to_pred=False` the per-realization predictions are aggregated
    as returned by the inferrer, and no spatial inversion is attempted on them.
    """
    spatial_shape = (20, 20)
    sample = {"image": np.random.rand(1, *spatial_shape).astype(np.float32)}

    add_channel = EnsureChannelFirstd("image", channel_dim="no_channel")
    always_flip = RandFlipd("image", prob=1.0, spatial_axis=0)
    pipeline = Compose([add_channel, always_flip])

    def mock_classifier(x):
        # one identical two-class score row per item in the incoming batch
        rows = [[0.2, 0.8] for _ in range(x.shape[0])]
        return torch.tensor(rows, dtype=torch.float32, device=x.device)

    tt_aug = TestTimeAugmentation(
        transform=pipeline,
        batch_size=2,
        num_workers=0,
        inferrer_fn=mock_classifier,
        device="cpu",
        orig_key="image",
        apply_inverse_to_pred=False,
        return_full_data=False,
    )

    # summary statistics: the inferrer is constant, so the mean equals the scores and the spread is zero
    _, mean, std, _ = tt_aug(sample, num_examples=4)
    self.assertEqual(mean.shape, (2,))
    np.testing.assert_allclose(mean, [0.2, 0.8], atol=1e-6)
    np.testing.assert_allclose(std, [0.0, 0.0], atol=1e-6)

    # full data: one row of scores per requested realization
    tt_aug.return_full_data = True
    stacked = tt_aug(sample, num_examples=4)
    self.assertEqual(stacked.shape, (4, 2))


if __name__ == "__main__":
unittest.main()
Loading