Roll seam-crossing longitudes in the downscaling data layer#1236
Conversation
…1234) First in a 5-PR stack adding support for longitude domains that cross the 0/360 prime meridian in downscaling. This standalone hardening PR moves expert grid-compatibility validation into the predictor constructor so every construction path is protected, not just the config-build path: only the primary expert's coordinates are used for input prep and output coords, so an expert built on a mismatched grid would otherwise silently downscale onto the wrong grid. Changes: - `fme.downscaling.predictors.serial_denoising`: move `_validate_experts_compatible` from `DenoisingMoEConfig.build` into `DenoisingMoEPredictor.__init__`, so it holds for `build`, `from_state`, and future callers (e.g. `with_rolled_lon`). - `fme.downscaling.test_models`: add `test_denoising_moe_predictor_rejects_mismatched_expert_grids`, constructing the predictor directly with mismatched-grid experts and asserting it raises. - [x] Tests added - [ ] If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated Base: `main` ### Stack | PR | Head → Base | Title | |----|-------------|-------| | [#1234](#1234) | `refactor/moe-validate-experts-init` → `main` | Validate expert grid compatibility in `DenoisingMoEPredictor.__init__` | | [#1235](#1235) | `feature/lon-roll-primitives` → PR1 | Add longitude roll primitives | | [#1236](#1236) | `feature/lon-roll-data-layer` → PR2 | Roll seam-crossing longitudes in the data layer | | [#1237](#1237) | `feature/lon-roll-model` → PR3 | Add with_rolled_lon to models | | [#1238](#1238) | `feature/lon-roll-integration` → PR4 | Roll the model in inference/predict/evaluator |
12baad6 to
1df468b
Compare
) PR 2 of 5 in the prime-meridian longitude stack. Adds the pure coordinate/data rolling utilities needed to re-express a global grid in a seam-crossing domain's convention. These have no production callers yet — later PRs wire them into the data and model layers — so they are reviewable in isolation with full unit coverage. The interval-based roll only triggers when an interval actually crosses the seam (`start < 0` or `stop > 360`), so in-range intervals are a no-op and non-global grids are left untouched. Primitives overview (PR #1235) These primitives are always used as a pair: find_roll_anchor (or find_roll_anchor_from_interval) computes the roll amount once; callers pass it to all subsequent roll_lon_coords and roll_lon_data so coordinates and field tensors shift by the same amount. Two downstream pathways use them: - Dataset load — rolls each loaded grid into the user's configured lon_extent convention (PR #1236) - Model setup — rolls the model's fine grid to match the incoming coarse batch's convention (PR #1237) Changes: - `fme.downscaling.data.utils`: add `ClosedInterval.finite_values`, `_requires_lon_roll`, `coords_require_lon_roll`, `find_roll_anchor`, `find_roll_anchor_from_interval`, `roll_lon_coords`, `roll_lon_data`, and private helpers `_validate_rollable_lon` and `_validate_monotonic_lon`. - `roll_lon_coords` (1-D coordinate tensor) and `roll_lon_data` (N-D field tensor) form a parallel pair: both apply the same roll amount, but `roll_lon_coords` also remaps values to keep the result monotonically increasing, while `roll_lon_data` is a pure cyclic shift. Callers pre-compute the roll amount once via `find_roll_anchor` and pass it to both. - `roll_latlon_coords` is not included here; it operates on a `LatLonCoordinates` struct rather than a raw tensor and belongs in the PR that first uses it. - `fme.downscaling.data` (`__init__`): export the new roll helpers. - `fme.downscaling.data.test_utils`: unit tests for roll amounts, seam-crossing conventions, round-trip invertibility, non-global/non-uniform rejection, and invalid input validation. - [x] Tests added - [ ] If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated Base: `refactor/moe-validate-experts-init` (PR 1) ### Stack | PR | Head → Base | Title | |----|-------------|-------| | [#1234](#1234) | `refactor/moe-validate-experts-init` → `main` | Validate expert grid compatibility in `DenoisingMoEPredictor.__init__` | | [#1235](#1235) | `feature/lon-roll-primitives` → PR1 | Add longitude roll primitives | | [#1236](#1236) | `feature/lon-roll-data-layer` → PR2 | Roll seam-crossing longitudes in the data layer | | [#1237](#1237) | `feature/lon-roll-model` → PR3 | Add with_rolled_lon to models | | [#1238](#1238) | `feature/lon-roll-integration` → PR4 | Roll the model in inference/predict/evaluator |
Apply the roll primitives so the data layer can subset longitude domains that cross the 0/360 seam: - HorizontalSubsetDataset now rolls its data and coordinates to the requested interval's convention instead of raising NotImplementedError on wraparound; in-range intervals resolve to a zero roll and behave as before. - StaticInputs.roll rolls static fields and their lon coordinates to match. - BatchItemDatasetAdapter exposes latlon_coordinates, and GriddedData / PairedGriddedData carry coarse_latlon_coords (populated in config) so the coarse grid convention is available to consumers. Adds tests for seam-crossing subsetting (both negative and >360 conventions) and StaticInputs.roll.
adjust_fine_coord_range received unrolled (0-360) longitude tensors, so for an interval like (-16, 30) the coarse_min snapped to ~0 (first 0-360 coord >= -16) instead of -16. This made the computed fine extent too narrow (0-30° instead of -16-30°), causing a scale-factor mismatch error when the paired dataset validated fine vs coarse dimensions. Fix: roll both the coarse and fine lon tensors into the interval's convention before calling adjust_fine_coord_range. The fine anchor is placed one half-coarse-spacing before lon_start so that adjust_fine_coord_range can access the fine half-cells below the first coarse grid point. For non-crossing domains (coarse_roll=0) the roll is a no-op and behaviour is unchanged.
7455589 to
806b5cd
Compare
|
Claude flagged an edge case where longitude intervals crossing 180 should require a roll in the case where lon range is (-180, 180) but not if the max lon is 360. A check on max longitude coord could differentiate these scenarios. |
AnnaKwa
left a comment
There was a problem hiding this comment.
Mostly LGTM, could you expand on the data loader test and if possible do a bit of refactoring so the tests of HorizontalSubsetDataset are more clearly associated with the code being tested?
| assert ( | ||
| batch.coarse.data["var0"].shape[-1] * scale_factor | ||
| == batch.fine.data["var0"].shape[-1] | ||
| ) |
There was a problem hiding this comment.
Can you expand this test to have it also check that the coordinates and data values are correct for a rolled data batch and coords? If the data values are just the longitudes this should be straightforward.
| full_fine_coord=rolled_fine_lon, | ||
| ) | ||
|
|
||
| dataset_fine_subset = HorizontalSubsetDataset( |
There was a problem hiding this comment.
It looks like HorizontalSubsetDataset is what is tested in the new tests; is there a small refactor that can be done here to isolate the code that produces this object into a function or method? That would make it a lot clearer what parts of the code are being tested in the additions to test_datasets.py, and easier to debug if needed in the future.
| variable_metadata=variable_metadata, | ||
| all_times=all_times, | ||
| fine_coords=get_latlon_coords_from_properties(properties_fine), | ||
| coarse_extent_latlon_coords=dataset_coarse_subset.latlon_coordinates, |
There was a problem hiding this comment.
Is this used in a later PR?
There was a problem hiding this comment.
Yes, I'll move that to the PR it's used in.
…to wt/roll-lon-data-layer
| # mod 360 (e.g. original 337.5 -> coord -22.5); see the value-level check in | ||
| # test_build_aligned_subset_pair_preserves_scale_factor_across_seam. | ||
| for grid in (batch.coarse, batch.fine): | ||
| lon = grid.latlon_coordinates.lon[0].cpu() # batch members are identical |
There was a problem hiding this comment.
Can you also check that the longitude values are consistent with the original lon_extent?
PR 4 of 5 in the prime-meridian longitude stack (PRs 1–3 now merged to main). Lets a model re-express its grid in a seam-crossing coarse domain's longitude convention while sharing the trained network weights, so a single checkpoint can generate over a domain expressed west of 0 or east of 360. Changes: - `fme.downscaling.models.DiffusionModel.with_rolled_lon`: rebuild the model through its constructor with `full_fine_coords` and `static_inputs` rolled to match the coarse grid, anchored on the western coarse-cell edge so the fine grid stays aligned to whole coarse cells; returns `self` when no roll is needed. Inference-only (rebuilding re-wraps the module under torch distributed). - `fme.downscaling.predictors.serial_denoising.DenoisingMoEPredictor.with_rolled_lon`: roll every expert (preserving the shared-grid invariant) and rebuild so the sigma dispatcher is reconstructed from the rolled experts. - `fme.downscaling.data` exports `roll_lon_coords` for the model layer. - `fme.downscaling.test_models`: tests for no-roll passthrough, coord shifting with shared weights (including value-level checks that coords and static data roll together, and that a double roll is a no-op), and coarse-cell alignment for a seam-crossing domain. MoE rolling tests live in `test_serial_denoising` next to the existing grid-validation test. - Test cleanup: shared `cell_centered_coordinate` helper in `test_utils` replaces per-file midpoint-coordinate constructions (`test_models`, `test_config`); removed a test and helper in `test_models`/`test_serial_denoising` duplicated from #1234. - [x] Tests added - [ ] If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated Base: `main` (PRs 1–3 of the stack merged) ### Stack | PR | Head → Base | Title | Status | |----|-------------|-------|--------| | [#1234](#1234) | `refactor/moe-validate-experts-init` → `main` | Validate expert grid compatibility in `DenoisingMoEPredictor.__init__` | merged | | [#1235](#1235) | `feature/lon-roll-primitives` → `main` | Add longitude roll primitives | merged | | [#1236](#1236) | `feature/lon-roll-data-layer` → `main` | Roll seam-crossing longitudes in the data layer | merged | | [#1237](#1237) | `feature/lon-roll-model` → `main` | Add with_rolled_lon to models | this PR | | [#1238](#1238) | `feature/lon-roll-integration` → PR4 | Roll the model in inference/predict/evaluator | open |
PR 3 of 5 in the prime-meridian longitude stack. Applies the roll primitives (PR 2) in the data layer so a longitude interval that crosses the 0/360 seam can be subset instead of raising
NotImplementedError. In-range intervals resolve to a zero roll and behave exactly as before.Changes:
fme.downscaling.data.datasets.HorizontalSubsetDataset: roll data and coordinates into the requested interval's convention rather than raising on wraparound.fme.downscaling.data.config: extract_build_aligned_subset_pair, which rolls coarse and fine lon coords into the extent's convention (_roll_lons_to_extent_convention) beforeadjust_fine_coord_range, so fine/coarse subselection stays aligned across the seam.fme.downscaling.data.static.StaticInputs.roll: roll static fields and their lon coordinates to match.fme.downscaling.data.test_config,fme.downscaling.data.test_datasets,fme.downscaling.data.test_static: tests for seam-crossing subsetting (negative and >360 conventions), fine/coarse scale-factor preservation across the seam (even and odd downscale factors), end-to-end paired loader with a seam-crossing extent, andStaticInputs.roll.Note: surfacing the coarse grid convention on
GriddedData/PairedGriddedData(coarse_latlon_coords) was deferred to the integration PR after review discussion.Base:
feature/lon-roll-primitives(PR 2)Stack
refactor/moe-validate-experts-init→mainDenoisingMoEPredictor.__init__feature/lon-roll-primitives→ PR1feature/lon-roll-data-layer→ PR2feature/lon-roll-model→ PR3feature/lon-roll-integration→ PR4