Skip to content

Conversation

@ytl0623
Copy link
Contributor

@ytl0623 ytl0623 commented Jan 22, 2026

Fixes #8276

Description

  • Added a new argument apply_inverse_to_pred. Defaults to True to preserve backward compatibility. When set to False, it skips the inverse transformation step and aggregates the model predictions directly.
  • Added a new unit test to simulate a classification task with spatial augmentation, verifying that the aggregation works correctly without spatial inversion.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 22, 2026

📝 Walkthrough

Walkthrough

Adds a boolean parameter apply_inverse_to_pred (default True) to TestTimeAugmentation and stores it as self.apply_inverse_to_pred. _check_transforms now iterates transforms to collect randomizable ones, requires at least one randomizable transform, and—when apply_inverse_to_pred is True—flags non-invertible random transforms, emitting a combined warning listing issues. __call__ conditionally inverts predictions only when apply_inverse_to_pred is True; when False it collects raw predictions directly to support non-spatial outputs. Tests: UNet invocation updated strides=(2, 2)strides=(2,), and a new test_non_spatial_output validates behavior with apply_inverse_to_pred=False, checking mean/std aggregation and return_full_data=True shapes.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 42.86% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed Title clearly and specifically describes the main change: enabling TestTimeAugmentation to handle non-spatial predictions.
Description check ✅ Passed Description covers the new argument, its default value, test additions, and change type. All required template sections are present and filled.
Linked Issues check ✅ Passed Changes fully address issue #8276: adds apply_inverse_to_pred parameter to skip inverse transforms, enabling TTA for non-spatial outputs like classification logits.
Out of Scope Changes check ✅ Passed UNet stride change from (2,2) to (2,) in test file is scoped to test setup and supports the new test method validating non-spatial behavior.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@ericspod
Copy link
Member

hi @ytl0623 thanks for this change, I think it's fine in principle. The _check_transforms method should be changed to account for when the new argument is False, in which case it doesn't need to check for invertibility of transforms. I noticed other issues with the original version of this method so I'd propose something like the following (which I haven't tested):

def _check_transforms(self):
    """Should be at least 1 random transform, and all random transforms should be invertible."""
    transforms = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms
    warns=[]
    randoms=[]
    for idx, t in transforms:
        if isinstance(t, Randomizable):
            randoms.append(t)
            if self.apply_inverse_to_pred and not isinstance(t, InvertibleTransform):
                warns.append(f"Transform {idx} (type {type(t).__name__}) not invertible.")
                
    if len(randoms)==0:
        warns.append("TTA usually requires at least one `Randomizable` transform in the given transform sequence.")
        
    if len(warns)>0:
        warnings.warn("TTA has encountered issues with the given transforms:"+"\n  ".join(warns))

Please check this logic, it might be that we need to check all transforms for invertibility whether they're random or not, but what I have here is equivalent to the original.

@ytl0623
Copy link
Contributor Author

ytl0623 commented Jan 23, 2026

Hi @ericspod, thanks for the suggestion!

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@monai/data/test_time_augmentation.py`:
- Around line 67-71: The docstring for test-time augmentation (function/class
using parameter names transform, batch_size, and apply_inverse_to_pred)
incorrectly states "All random transforms must be of type InvertibleTransform";
update the docstring (and the transform type hint if present) to reflect that
non-invertible random transforms are allowed when apply_inverse_to_pred=False
and only need to be invertible when apply_inverse_to_pred=True; change the
wording in both occurrences (the block around the transform description and the
later paragraph at lines ~115-118) to describe this conditional requirement and,
if applicable, broaden the transform type hint to accept non-invertible
Randomizable types when apply_inverse_to_pred is False.
- Around line 174-175: The warning message built from the local variable warns
is missing a newline after the colon and does not set a stacklevel, so update
the warnings.warn call to prepend a newline (e.g., "TTA has encountered issues
with the given transforms:\n  " + "\n  ".join(warns)) and pass an appropriate
stacklevel (e.g., stacklevel=2) so user stack traces point to the caller; locate
and modify the warnings.warn(...) invocation that uses the warns list in
test_time_augmentation.py.
- Around line 208-213: The non-inverse branch currently returns raw predictions
and skips all Invertd post-processing (to_tensor, output_device, post_func);
update the branch so decollated items still go through the same inverter
pipeline (or a post-processing-only path) before extracting self._pred_key.
Concretely, in the else branch replace outs.extend([i[self._pred_key] for i in
decollate_batch(b)]) with code that calls self.inverter on each
PadListDataCollate.inverse(i) (or calls an Invertd method/flag that runs only
to_tensor/output_device/post_func but not spatial inverse) and then extracts
[self._pred_key]; ensure the call honors to_tensor, output_device and post_func
parameters so behavior matches the apply_inverse_to_pred=True path.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Generalize TestTimeAugmentation to non-spatial predictions

2 participants