Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions fme/downscaling/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
PairedBatchData,
StaticInputs,
adjust_fine_coord_range,
coords_require_lon_roll,
find_roll_anchor,
load_coords_from_path,
roll_latlon_coords,
)
from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate
from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector
Expand Down Expand Up @@ -747,6 +750,58 @@ def metadata(self):
else 0,
)

def _lon_roll_amount(self, coarse_lon: torch.Tensor) -> tuple[int, float]:
"""
Number of positions to roll the fine grid (and the lon_start it aligns to)
so the fine cells stay aligned to coarse_lon's coarse cells.

coarse_lon is the actual coarse domain grid, so it already carries the
convention to align to. The roll is anchored on the western coarse-cell
*edge* (half a coarse cell below coarse_lon.min(), which is a cell *center*)
so the fine grid rolls by a whole number of coarse cells and its cells stay
aligned to the coarse cells. Anchoring on the center instead would roll by an
extra downscale_factor // 2 fine points, splitting the boundary coarse cell
across the seam.
"""
lon_start = float(coarse_lon.min())
fine_lon = self.full_fine_coords.lon
fine_spacing = float(fine_lon[1] - fine_lon[0])
western_edge = lon_start - self.downscale_factor * fine_spacing / 2.0
return find_roll_anchor(fine_lon, western_edge), lon_start

def with_rolled_lon(self, coarse_lon: torch.Tensor) -> "DiffusionModel":
"""
Return a new model with full_fine_coords and static_inputs rolled to match
coarse_lon's longitude convention, sharing the network weights.

Returns self unchanged when coarse_lon does not cross the prime meridian.
The new model is built through the constructor (rather than a shallow copy)
so its coords are re-validated and derived state is rebuilt fresh; the raw
module is unwrapped and passed so __init__ re-wraps it exactly once.
"""
if not coords_require_lon_roll(coarse_lon):
return self
roll_amount, lon_start = self._lon_roll_amount(coarse_lon)
return DiffusionModel(
config=self.config,
module=self.module.module,
normalizer=self.normalizer,
loss=self.loss,
coarse_shape=self.coarse_shape,
downscale_factor=self.downscale_factor,
sigma_data=self.sigma_data,
full_fine_coords=roll_latlon_coords(
self.full_fine_coords, roll_amount, lon_start
),
in_names=self.in_names,
out_names=self.out_names,
static_inputs=(
self.static_inputs.roll(roll_amount, lon_start)
if self.static_inputs is not None
else None
),
)


@dataclasses.dataclass
class _CheckpointModelConfigSelector:
Expand Down
20 changes: 20 additions & 0 deletions fme/downscaling/predictors/serial_denoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,26 @@ def static_inputs(self) -> StaticInputs | None:
def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates:
return self._primary.get_fine_coords_for_batch(batch)

def with_rolled_lon(self, coarse_lon: torch.Tensor) -> "DenoisingMoEPredictor":
"""New predictor with every expert's coords rolled to match coarse_lon.

All experts are rolled (not just the primary) so the shared-grid invariant
enforced in __init__ still holds -- nothing relies on the non-primary
experts' coordinates being left unrolled. Rebuilt through __init__ so
_dispatch_module is reconstructed from the rolled experts. Returns self
unchanged when no roll is needed.
"""
rolled = [expert.with_rolled_lon(coarse_lon) for expert in self._experts]
if all(r is e for r, e in zip(rolled, self._experts)):
return self
return DenoisingMoEPredictor(
experts=rolled,
sigma_ranges=self._sigma_ranges,
num_diffusion_generation_steps=self._num_diffusion_generation_steps,
churn=self._churn,
expert_renames=self._expert_renames,
)

@torch.no_grad()
def generate(
self,
Expand Down
161 changes: 161 additions & 0 deletions fme/downscaling/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,167 @@ def _make_global_fine_coords_and_static(fine_shape: tuple[int, int]):
return full_fine_coords, static_inputs


def test_with_rolled_lon_no_roll_returns_same():
"""with_rolled_lon returns the original model when no roll is needed."""
coarse_shape = (8, 16)
fine_shape = (16, 32)
static_inputs = make_static_inputs(fine_shape)
model = _get_diffusion_model(
coarse_shape=coarse_shape,
downscale_factor=2,
full_fine_coords=static_inputs.coords,
static_inputs=static_inputs,
)
coarse_lon = _get_monotonic_coordinate(coarse_shape[1], stop=fine_shape[1])
assert model.with_rolled_lon(coarse_lon) is model


def test_with_rolled_lon_shifts_coords_and_shares_weights():
"""with_rolled_lon: new model with rolled coords, shared network weights."""
coarse_shape = (8, 16)
fine_shape = (16, 32)
full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape)
model = _get_diffusion_model(
coarse_shape=coarse_shape,
downscale_factor=2,
full_fine_coords=full_fine_coords,
static_inputs=static_inputs,
)

coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], dtype=torch.float32)
rolled = model.with_rolled_lon(coarse_lon)

# Reconstruction wraps a fresh module around the SAME raw weights.
assert rolled.module is not model.module
assert next(rolled.module.parameters()) is next(model.module.parameters())
assert not torch.equal(rolled.full_fine_coords.lon, model.full_fine_coords.lon)
assert torch.all(rolled.full_fine_coords.lon[1:] > rolled.full_fine_coords.lon[:-1])
assert rolled.full_fine_coords.lon[0].item() < 0
assert rolled.static_inputs is not None
# Compare against model.static_inputs (on-device) rather than the CPU-side original
assert not torch.equal(
rolled.static_inputs.fields[0].data, model.static_inputs.fields[0].data
)


def test_with_rolled_lon_is_idempotent():
"""Rolling an already-rolled model with the same domain is a no-op.

Guards against accidental double-rolling: the second roll resolves to 0
(full rotation), so the twice-rolled model has identical coords and static
inputs to the once-rolled one.
"""
coarse_shape = (8, 16)
fine_shape = (16, 32)
full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape)
model = _get_diffusion_model(
coarse_shape=coarse_shape,
downscale_factor=2,
full_fine_coords=full_fine_coords,
static_inputs=static_inputs,
)

coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], dtype=torch.float32)
rolled = model.with_rolled_lon(coarse_lon)
twice = rolled.with_rolled_lon(coarse_lon)

assert torch.equal(twice.full_fine_coords.lon, rolled.full_fine_coords.lon)
assert rolled.static_inputs is not None and twice.static_inputs is not None
assert torch.equal(
twice.static_inputs.fields[0].data, rolled.static_inputs.fields[0].data
)


def test_roll_diffusion_model_keeps_fine_aligned_to_coarse_cells():
"""A seam-crossing domain must roll the fine grid by whole coarse cells.

The roll is anchored on the western coarse-cell edge, not its center. If it
anchored on the center it would roll an extra downscale_factor // 2 fine
points, leaving no fine margin below the western coarse cell -- which makes
get_fine_coords_for_batch raise -- and splitting that cell across the seam.
"""
coarse_shape = (4, 8)
fine_shape = (16, 32)
factor = 4
full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape)
model = _get_diffusion_model(
coarse_shape=coarse_shape,
downscale_factor=factor,
full_fine_coords=full_fine_coords,
static_inputs=static_inputs,
)

# Four of the eight global 45-degree coarse cells, crossing the 0/360 seam and
# expressed in negative convention (physically 292.5, 337.5 and 22.5, 67.5).
# Coarse-lat centers [6, 10] are interior, leaving fine margin above and below.
coarse_lat = [6.0, 10.0]
coarse_lon = [-67.5, -22.5, 22.5, 67.5]
batch = make_batch_data(
(1, len(coarse_lat), len(coarse_lon)), coarse_lat, coarse_lon
)

rolled = model.with_rolled_lon(torch.tensor(coarse_lon, dtype=torch.float32))
# Anchoring on the cell center would leave no margin and raise here.
fine_coords = rolled.get_fine_coords_for_batch(batch)

# Each coarse cell is covered by exactly `factor` fine cells whose mean is the
# coarse-cell center -- i.e. the fine grid stayed aligned to the coarse cells.
recentered = fine_coords.lon.reshape(len(coarse_lon), factor).mean(dim=1).cpu()
assert torch.allclose(recentered, torch.tensor(coarse_lon), atol=1e-3)


def test_denoising_moe_predictor_with_rolled_lon_rolls_all_experts():
"""with_rolled_lon rolls every expert (keeping the shared-grid invariant)."""
from fme.downscaling.predictors.serial_denoising import DenoisingMoEPredictor

coarse_shape = (8, 16)
fine_shape = (16, 32)
full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape)

expert0 = _get_diffusion_model(
coarse_shape=coarse_shape,
downscale_factor=2,
full_fine_coords=full_fine_coords,
static_inputs=static_inputs,
)
expert1 = _get_diffusion_model(
coarse_shape=coarse_shape,
downscale_factor=2,
full_fine_coords=full_fine_coords,
static_inputs=static_inputs,
)
predictor = DenoisingMoEPredictor(
experts=[expert0, expert1],
sigma_ranges=[(0.0, 0.5), (0.5, 1.0)],
num_diffusion_generation_steps=2,
churn=0.0,
)

coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], dtype=torch.float32)
rolled = predictor.with_rolled_lon(coarse_lon)

# Every expert is a new (rolled) object; _primary stays _experts[0].
assert rolled._primary is rolled._experts[0]
for rolled_expert, original, source in zip(
rolled._experts, predictor._experts, [expert0, expert1]
):
assert rolled_expert is not original
# Coords are rolled...
assert rolled_expert.full_fine_coords.lon[0].item() < 0
# ...but the raw network weights are still shared (fresh wrapper).
assert next(rolled_expert.module.parameters()) is next(
source.module.parameters()
)
# The sigma dispatcher is rebuilt from the rolled experts, consistent with
# _experts (not left pointing at any pre-roll module).
for entry, rolled_expert in zip(rolled._dispatch_module._entries, rolled._experts):
assert entry[2] is rolled_expert.module

# No-roll case returns self
non_neg_lon = torch.tensor([0.0, 5.0, 10.0, 15.0], dtype=torch.float32)
assert predictor.with_rolled_lon(non_neg_lon) is predictor


def test_denoising_moe_predictor_rejects_mismatched_expert_grids():
"""Experts on different grids are rejected at construction (shared-grid)."""
from fme.downscaling.predictors.serial_denoising import DenoisingMoEPredictor
Expand Down