
[Fix] Support fp32 param preservation during FSDP save and load #1626

Closed
HAOCHENYE wants to merge 1 commit into gh/HAOCHENYE/19/base from gh/HAOCHENYE/19/head

Conversation

Collaborator

@HAOCHENYE HAOCHENYE commented Mar 24, 2026

Stack from ghstack (oldest at bottom):


  • Add fp32_keys_pattern to HFSaveCfg to specify params that should
    be saved in fp32 regardless of the global save dtype
  • Add _fully_shard() to BaseModel which wraps fully_shard() and
    distributes matched params as Replicate DTensors so FSDP ignores them
  • Replace all direct fully_shard() call sites with self._fully_shard()
  • Add _get_save_dtype() to select fp32 per-param at save time
  • Fix load path in _load_same_hf_param to skip shard-offset logic for
    Replicate DTensors (only apply for Shard-placed params)
  • Fix world_mesh property setter bug: use _world_mesh directly
    instead of assigning through the read-only property
  • Fix gradient reduce in MoE.scale_and_reduce_grad to correctly
    identify the Replicate mesh dimension via DTensor placements
  • Preserve requires_grad when distributing params in EP and fp32 paths
  • Configure Qwen3_5_VLTextMoEConfig with fp32 patterns for
    linear_attn.norm.weight and linear_attn.A_log
  • Add test_save_hf_with_mtp to verify round-trip weight preservation
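The per-param dtype selection described above can be sketched in pure Python. Note this is a hypothetical re-implementation based only on the bullet points: the function name `get_save_dtype`, the regex semantics of `fp32_keys_pattern`, and the string dtype values are all assumptions, not the actual code in this PR.

```python
import re

def get_save_dtype(param_name, fp32_keys_pattern, default_dtype="bfloat16"):
    """Pick the dtype a parameter should be saved in.

    Parameters whose name matches any pattern in ``fp32_keys_pattern``
    are kept in fp32 regardless of the global save dtype (hypothetical
    sketch of the ``_get_save_dtype`` behavior described above).
    """
    for pattern in fp32_keys_pattern:
        if re.search(pattern, param_name):
            return "float32"
    return default_dtype

# Patterns mirroring the Qwen3_5_VLTextMoEConfig bullet above.
patterns = [r"linear_attn\.norm\.weight", r"linear_attn\.A_log"]
```

With these patterns, a name like `layers.0.linear_attn.A_log` resolves to `float32`, while ordinary MLP weights fall back to the global save dtype.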
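The load-path fix in `_load_same_hf_param` (apply shard-offset arithmetic only to `Shard`-placed params, never to `Replicate` ones) can be illustrated with stand-in placement classes. The real code operates on `torch.distributed` DTensor placements; the classes and the even-chunking scheme below are simplifying assumptions for illustration only.

```python
class Shard:
    """Stand-in for a DTensor Shard placement along one dim."""
    def __init__(self, dim):
        self.dim = dim

class Replicate:
    """Stand-in for a DTensor Replicate placement."""
    pass

def local_range(length, placement, rank, world_size):
    """Return the [start, stop) slice of a parameter's sharded dim
    that this rank holds.

    Offset arithmetic applies only to Shard placements; a Replicate
    placement holds the full parameter on every rank, which is the
    skip behavior the bullet above describes (sketch, assuming even
    chunking with no remainder).
    """
    if isinstance(placement, Shard):
        chunk = length // world_size
        start = rank * chunk
        return start, start + chunk
    # Replicate: every rank holds the whole parameter.
    return 0, length
```

For example, with `world_size=4` a length-8 `Shard(0)` parameter gives rank 1 the slice `[2, 4)`, whereas a `Replicate` parameter gives every rank `[0, 8)`.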

[ghstack-poisoned]
