Add clip_latent_global_means option to conditional SFNO #1230
Adds a TrainStepperABC.set_epoch(epoch) hook with a no-op default and wires it from the trainer at fresh-epoch boundaries (mid-epoch resume preserves in-module state so partial-epoch accumulators continue from where they left off). Stepper and CoupledStepper implement set_epoch by walking submodules and invoking request_latent_global_mean_envelope_reset where present, giving model components a way to reset per-epoch in-module statistics without coupling the stepper to model internals.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
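A minimal, torch-free sketch of the opt-in pattern this commit describes. Only set_epoch and request_latent_global_mean_envelope_reset are names from the commit message; TrainStepperSketch, StepperSketch, EnvelopeModule, PlainModule, the train_one_epoch signature, and the resuming_mid_epoch flag are hypothetical stand-ins, and the real stepper presumably walks torch submodules rather than a plain list.

```python
from abc import ABC


class TrainStepperSketch(ABC):
    """Hypothetical stand-in for TrainStepperABC (not the real class)."""

    def set_epoch(self, epoch: int) -> None:
        # No-op default: steppers without per-epoch state need no override.
        return None


class EnvelopeModule:
    """Hypothetical component that opts in by defining the reset method."""

    def __init__(self) -> None:
        self.reset_requested = False

    def request_latent_global_mean_envelope_reset(self) -> None:
        # Only record the request; the actual reset would happen lazily.
        self.reset_requested = True


class PlainModule:
    """Hypothetical component with no per-epoch state."""


class StepperSketch(TrainStepperSketch):
    def __init__(self, modules) -> None:
        # Assumption: a flat list stands in for walking torch submodules.
        self._modules = list(modules)

    def set_epoch(self, epoch: int) -> None:
        # Duck-typed lookup: call the reset hook only where it exists.
        for module in self._modules:
            reset = getattr(
                module, "request_latent_global_mean_envelope_reset", None
            )
            if callable(reset):
                reset()


def train_one_epoch(stepper, epoch: int, resuming_mid_epoch: bool) -> None:
    # Skip the hook on mid-epoch resume so accumulators keep their state.
    if not resuming_mid_epoch:
        stepper.set_epoch(epoch)
    # ... the training loop itself is omitted from this sketch ...
```

Because the stepper looks the method up by name, it never needs to import or know about the model class that owns the statistics.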
When enabled, the per-channel spatial mean of the post-encoder latent is tracked during training and, in eval, the latent is shifted so that mean falls within the observed envelope (a no-op when the mean is already inside it). This bounds the global mean of the latent that the transformer blocks see at inference to the range observed in training. The envelope is reset at the start of each training epoch (lazily, on the next training-mode forward) via request_latent_global_mean_envelope_reset, which the stepper invokes through the TrainStepperABC.set_epoch hook. Exposed as a single clip_latent_global_means: bool option on SFNONetConfig and NoiseConditionedSFNOBuilder; defaults to False so existing models are unaffected.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
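The behaviour described above can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the PR's implementation: LatentMeanEnvelope and its method names are hypothetical, the latent is a single sample given as a list of channels (each a flat list of spatial values), the mean is unweighted, and the distributed reduce across ranks is left out.

```python
import math


class LatentMeanEnvelope:
    """Hypothetical, torch-free sketch of the per-channel mean envelope."""

    def __init__(self, n_channels: int) -> None:
        # Infinite bounds mark the envelope as uninitialized.
        self.gm_min = [math.inf] * n_channels
        self.gm_max = [-math.inf] * n_channels
        self._reset_pending = False

    def request_reset(self) -> None:
        # Lazy: the bounds are cleared on the next training-mode forward.
        self._reset_pending = True

    def forward(self, x, training: bool):
        means = [sum(channel) / len(channel) for channel in x]
        if training:
            if self._reset_pending:
                n = len(self.gm_min)
                self.gm_min = [math.inf] * n
                self.gm_max = [-math.inf] * n
                self._reset_pending = False
            # Widen the envelope to include this batch's channel means.
            self.gm_min = [min(lo, m) for lo, m in zip(self.gm_min, means)]
            self.gm_max = [max(hi, m) for hi, m in zip(self.gm_max, means)]
            return x
        if not all(math.isfinite(hi) for hi in self.gm_max):
            # Nothing tracked yet: eval behaves as if clipping were off.
            return x
        shifted = []
        for channel, m, lo, hi in zip(x, means, self.gm_min, self.gm_max):
            # Clip residual: zero when the mean is already inside [lo, hi].
            residual = min(max(m, lo), hi) - m
            shifted.append([value + residual for value in channel])
        return shifted
```

Adding the same residual to every spatial value moves a channel's mean onto the nearest envelope bound while leaving its spatial anomalies unchanged.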
    elif torch.isfinite(self._gm_max).all():
        clipped = torch.clamp(global_means, min=self._gm_min, max=self._gm_max)
        # Shift x by the clip residual so its per-channel spatial
        # mean falls within the envelope observed during training.
Nit: this is a little in the weeds, but "observed during training" isn't exactly true. The clip bounds come only from the training epoch that corresponds to the checkpoint being used, right?
Technically it's correct, but I agree it's clearer to specify "during last-epoch training".
I guess I was confused at first, before reading the code, because I was worried about large gradients at the beginning of training, but then I realized the envelope is isolated to a single epoch.
    torch.testing.assert_close(out_clip, out_no_clip)


def test_clip_latent_global_means_eval_shifts_when_outside_envelope():
Nit: This test is a little misleading. All it does is check that out is finite, which I assume would be true even if the clipping did not work as intended. I suggest either asserting that values are clipped as expected or removing the test, since I think it overlaps quite a bit with the other tests added.
…est, add parallel envelope test
Arcomano1234 left a comment
Thanks for addressing the comments, this looks good to me.
Adds a clip_latent_global_means option to conditional SFNO that bounds the per-channel spatial mean of the post-encoder latent at inference to the range observed during the most recent training epoch. During training the model tracks a per-channel min/max envelope of that mean; during eval, the latent is shifted by clamp(mean) - mean so the mean falls within the envelope (a no-op when it already does). The envelope is reset at the start of each training epoch via a new set_epoch hook on TrainStepperABC. Defaults to False so existing models are unaffected and the existing regression checkpoint still matches.

Changes:
- fme.core.generics.train_stepper.TrainStepperABC.set_epoch: new no-op default hook for per-epoch in-module state.
- fme.core.generics.trainer.Trainer.train_one_epoch: invokes stepper.set_epoch on fresh-epoch boundaries (skipped on mid-epoch resume so partial-epoch state continues to accumulate).
- fme.ace.stepper.single_module.Stepper.set_epoch / TrainStepper.set_epoch and fme.coupled.stepper.CoupledStepper.set_epoch / CoupledTrainStepper.set_epoch: walk submodules and call request_latent_global_mean_envelope_reset where present, so model components opt in without coupling the stepper to their internals.
- fme.core.models.conditional_sfno.sfnonet.SFNONetConfig.clip_latent_global_means and SphericalFourierNeuralOperatorNet: buffers _gm_min / _gm_max, lazy reset flag, distributed reduce of per-batch min/max during training, eval-mode clip-residual shift, and request_latent_global_mean_envelope_reset.
- fme.ace.registry.stochastic_sfno.NoiseConditionedSFNOBuilder.clip_latent_global_means: exposes the option in the public builder.
- Tests in fme.core.models.conditional_sfno.test_sfnonet covering envelope initialization, training-mode tracking, eval-mode shift with a tight envelope, parity with clip_latent_global_means=False when the envelope is uninitialized, lazy reset, and a parallel-marked test that the envelope is synchronized across data-parallel ranks.

- Tests added
- If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated