Add clip_latent_global_means option to conditional SFNO #1230

Merged: mcgibbon merged 7 commits into main from feature/clip-latent-global-means on Jun 11, 2026

mcgibbon commented Jun 5, 2026

Contributor

Adds a clip_latent_global_means option to conditional SFNO that bounds the per-channel spatial mean of the post-encoder latent at inference to the range observed during the most recent training epoch. During training the model tracks a per-channel min/max envelope of that mean; during eval, the latent is shifted by clamp(mean) - mean so the mean falls within the envelope (a no-op when it already does). The envelope is reset at the start of each training epoch via a new set_epoch hook on TrainStepperABC. Defaults to False so existing models are unaffected and the existing regression checkpoint still matches.
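The eval-mode shift described above can be sketched in plain Python. This is only an illustration of the arithmetic, not the PR's PyTorch code: the helper names `clamp` and `shift_into_envelope` are hypothetical, each channel is a flat list of spatial values, and the spatial mean is assumed to be unweighted (the real model may weight by area).

```python
from statistics import fmean


def clamp(value, low, high):
    """Restrict value to the closed interval [low, high]."""
    return max(low, min(value, high))


def shift_into_envelope(latent, gm_min, gm_max):
    """Shift each channel so its mean lies within [gm_min[c], gm_max[c]].

    latent: list of channels, each a flat list of spatial values.
    Assumption: unweighted spatial mean, used here only for illustration.
    """
    shifted = []
    for channel, low, high in zip(latent, gm_min, gm_max):
        mean = fmean(channel)
        # Clip residual: zero when the mean is already inside the envelope.
        residual = clamp(mean, low, high) - mean
        shifted.append([value + residual for value in channel])
    return shifted


latent = [[1.0, 3.0], [10.0, 14.0]]  # channel means are 2.0 and 12.0
gm_min = [0.0, 0.0]
gm_max = [5.0, 8.0]
out = shift_into_envelope(latent, gm_min, gm_max)
# Channel 0 is unchanged; channel 1 is shifted by -4.0 so its mean becomes 8.0.
```

Because the same residual is added at every spatial location, only the channel mean changes; deviations from the mean are preserved.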

Changes:

  • fme.core.generics.train_stepper.TrainStepperABC.set_epoch — new no-op default hook for per-epoch in-module state.

  • fme.core.generics.trainer.Trainer.train_one_epoch — invokes stepper.set_epoch on fresh-epoch boundaries (skipped on mid-epoch resume so partial-epoch state continues to accumulate).

  • fme.ace.stepper.single_module.Stepper.set_epoch / TrainStepper.set_epoch and fme.coupled.stepper.CoupledStepper.set_epoch / CoupledTrainStepper.set_epoch — walk submodules and call request_latent_global_mean_envelope_reset where present, so model components opt in without coupling the stepper to their internals.

  • fme.core.models.conditional_sfno.sfnonet.SFNONetConfig.clip_latent_global_means and SphericalFourierNeuralOperatorNet — buffers _gm_min / _gm_max, lazy reset flag, distributed reduce of per-batch min/max during training, eval-mode clip-residual shift, and request_latent_global_mean_envelope_reset.

  • fme.ace.registry.stochastic_sfno.NoiseConditionedSFNOBuilder.clip_latent_global_means — exposes the option in the public builder.

  • Tests in fme.core.models.conditional_sfno.test_sfnonet covering envelope initialization, training-mode tracking, eval-mode shift with a tight envelope, parity with clip_latent_global_means=False when the envelope is uninitialized, lazy reset, and a parallel-marked test that the envelope is synchronized across data-parallel ranks.

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated
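The hook wiring in the list above can be sketched roughly as follows. This is a simplified stand-in, not the fme code: `TrainStepperSketch`, `StepperSketch`, `OptInComponent`, `PlainComponent`, and the `resuming_mid_epoch` flag are hypothetical names, and a plain list of components stands in for walking a module's submodules. Only `set_epoch` and `request_latent_global_mean_envelope_reset` come from the PR description.

```python
from abc import ABC


class TrainStepperSketch(ABC):
    """Stand-in for the train stepper base class."""

    def set_epoch(self, epoch: int) -> None:
        """No-op by default, so existing steppers need no changes."""


class OptInComponent:
    """Hypothetical model component that opts in to per-epoch resets."""

    def __init__(self):
        self.reset_requested = False

    def request_latent_global_mean_envelope_reset(self):
        self.reset_requested = True


class PlainComponent:
    """Hypothetical component without the reset method; it is skipped."""


class StepperSketch(TrainStepperSketch):
    def __init__(self, components):
        self._components = list(components)

    def set_epoch(self, epoch: int) -> None:
        # Duck-typed opt-in: call the reset method only where it exists,
        # so the stepper does not depend on model internals.
        for component in self._components:
            reset = getattr(
                component, "request_latent_global_mean_envelope_reset", None
            )
            if callable(reset):
                reset()


def train_one_epoch(stepper, epoch, resuming_mid_epoch=False):
    # Skip the hook on mid-epoch resume so partial-epoch state keeps accumulating.
    if not resuming_mid_epoch:
        stepper.set_epoch(epoch)
    # ... iterate over training batches here ...


opted_in = OptInComponent()
stepper = StepperSketch([opted_in, PlainComponent()])
train_one_epoch(stepper, epoch=1)
```

The duck-typed lookup is one way to let components opt in; the actual steppers may detect the method differently.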

mcgibbon and others added 4 commits June 5, 2026 16:57
Adds a TrainStepperABC.set_epoch(epoch) hook with a no-op default and
wires it from the trainer at fresh-epoch boundaries (mid-epoch resume
preserves in-module state so partial-epoch accumulators continue from
where they left off).

Stepper and CoupledStepper implement set_epoch by walking submodules
and invoking request_latent_global_mean_envelope_reset where present,
giving model components a way to reset per-epoch in-module statistics
without coupling the stepper to model internals.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

When enabled, the per-channel spatial mean of the post-encoder latent
is tracked during training and, in eval, the latent is shifted so that
this mean falls within the observed envelope (a no-op when the mean is
already inside it). This bounds the global mean of the latent that the
transformer blocks see at inference to the range observed during the
most recent training epoch.

The envelope is reset at the start of each training epoch (lazily, on
the next training-mode forward) via
request_latent_global_mean_envelope_reset, which the stepper invokes
through the TrainStepperABC.set_epoch hook.

Exposed as a single clip_latent_global_means: bool option on
SFNONetConfig and NoiseConditionedSFNOBuilder; defaults to False so
existing models are unaffected.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
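A minimal sketch of the envelope tracking and lazy reset described in this commit message, in plain Python. `EnvelopeSketch` and its method names are hypothetical. The infinite sentinel values are an assumption suggested by the `torch.isfinite` check in the diff, not something the PR text confirms, and the distributed reduce across ranks is omitted.

```python
import math


class EnvelopeSketch:
    """Tracks a per-channel min/max envelope of latent global means."""

    def __init__(self, n_channels):
        # Assumed sentinels: non-finite values mark an uninitialized envelope.
        self.gm_min = [math.inf] * n_channels
        self.gm_max = [-math.inf] * n_channels
        self._reset_pending = False

    def request_reset(self):
        # Lazy: only set a flag; the envelope is cleared on the next update.
        self._reset_pending = True

    def is_initialized(self):
        return all(math.isfinite(value) for value in self.gm_max)

    def update(self, per_sample_means):
        """per_sample_means: one list of per-channel means for each sample."""
        if self._reset_pending:
            n_channels = len(self.gm_min)
            self.gm_min = [math.inf] * n_channels
            self.gm_max = [-math.inf] * n_channels
            self._reset_pending = False
        for sample in per_sample_means:
            for c, mean in enumerate(sample):
                self.gm_min[c] = min(self.gm_min[c], mean)
                self.gm_max[c] = max(self.gm_max[c], mean)


envelope = EnvelopeSketch(n_channels=2)
envelope.update([[1.0, 5.0], [3.0, 4.0]])  # first epoch, one batch of two samples
envelope.request_reset()                   # new epoch begins; envelope not yet cleared
```

Deferring the reset to the next training-mode update means an evaluation run between epochs still sees the envelope from the epoch that just finished.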
mcgibbon requested a review from Arcomano1234 June 10, 2026 16:38
elif torch.isfinite(self._gm_max).all():
    clipped = torch.clamp(global_means, min=self._gm_min, max=self._gm_max)
    # Shift x by the clip residual so its per-channel spatial
    # mean falls within the envelope observed during training.

Contributor

Nit: this is a little in the weeds, but "observed during training" isn't exactly true. The clip bounds come only from the training epoch that corresponds to the checkpoint being used, right?

Contributor Author

Technically it's correct, but I agree it's clearer to specify "during last-epoch training".

Contributor

I guess I was confused at first, before reading the code, because I was worried about large gradients at the beginning of training. Then I realized the envelope is isolated to a single epoch.

Contributor Author

Let's clarify it.

torch.testing.assert_close(out_clip, out_no_clip)


def test_clip_latent_global_means_eval_shifts_when_outside_envelope():

Contributor

Nit: This test is a little misleading. All it does is check that out is finite, which I assume would be true even if the clipping did not work as intended. I suggest either making sure things are clipped as expected or removing the test, as I think there is quite a bit of overlap between this test and the others added.

Comment thread fme/core/models/conditional_sfno/sfnonet.py
Comment thread fme/core/generics/train_stepper.py

Arcomano1234 left a comment

Contributor

Thanks for addressing the comments, this looks good to me.

mcgibbon merged commit 347a9bc into main Jun 11, 2026
7 checks passed
mcgibbon deleted the feature/clip-latent-global-means branch June 11, 2026 19:18