Fix transformers_modeling_backend for pretrained dense models #3772
HosseinKaviani-H wants to merge 4 commits
Three base-backend bugs make the HF backend produce near-random loss or crash when building or loading real pretrained models (independent of SFT):

1. RoPE inv_freq left uninitialized after meta-init: HF rotary modules compute inv_freq in __init__, but meta-device init + to_empty() zeros it and _init_weights does not recompute it, silently disabling RoPE. Recompute it from each rotary module's rope_init_fn in init_states.
2. TitanDenseModelConfig overrides the HF config: non-None defaults (dim=4096, n_layers=32, n_heads=32, norm_eps=1e-5, rope_theta=10000) were injected over AutoConfig values, forcing the wrong architecture/hyperparameters (e.g. rope_theta 1e6 -> 10000 for Qwen3). Default these to None so the HF config wins; set them explicitly only to intentionally override (e.g. debugmodel).
3. Tied embeddings crash FSDP2: tok_embeddings and lm_head share a parameter when tie_word_embeddings is set, but were placed in separate fully_shard groups (FSDP2 forbids this on recent PyTorch). Group tied tok_embeddings/norm/lm_head into one FSDP unit.
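To illustrate item 1, here is a minimal pure-Python sketch of why a zeroed inv_freq silently disables RoPE. It assumes the standard RoPE formula inv_freq[i] = theta^(-2i/head_dim); `compute_inv_freq` and `rotate_pair` are hypothetical helper names, not the HF rope_init_fn or anything in this PR.

```python
import math

def compute_inv_freq(head_dim: int, rope_theta: float) -> list[float]:
    """Standard RoPE inverse frequencies: theta^(-2i/head_dim), i in [0, head_dim/2)."""
    return [rope_theta ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]

def rotate_pair(x0: float, x1: float, position: int, inv_freq: float) -> tuple[float, float]:
    """Rotate one (x0, x1) feature pair by the angle position * inv_freq."""
    angle = position * inv_freq
    c, s = math.cos(angle), math.sin(angle)
    return (x0 * c - x1 * s, x0 * s + x1 * c)

# Properly initialized buffer (rope_theta=1e6 is the Qwen3 value quoted above).
good = compute_inv_freq(head_dim=128, rope_theta=1e6)

# Assumption for illustration: the buffer is all zeros after meta-init + to_empty(),
# as the description above states.
zeroed = [0.0] * len(good)

# With a zero frequency the rotation angle is 0 at every position, so the
# pair is returned unchanged and positional information is lost.
print(rotate_pair(1.0, 0.0, position=5, inv_freq=zeroed[0]))  # (1.0, 0.0)

# With a real frequency the pair is rotated as a function of position.
print(rotate_pair(1.0, 0.0, position=5, inv_freq=good[0]))
```

Nothing errors out in the zeroed case, since the rotation is simply the identity, which is consistent with the failure showing up only as a bad loss.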
…add state-dict adapter

Completes loaded-weight support for pretrained HF dense models:

4. head_dim/intermediate_size were unconditionally derived from hidden_size/num_heads in update_from_config, overriding the HF config. Models like Qwen3 decouple head_dim (128) from hidden_size/num_heads (64), so loading mismatched the projection shapes. Only derive these when dim is explicitly overridden (e.g. debugmodel); otherwise keep the HF config values.
5. model_registry set state_dict_adapter=None, so HF safetensors keys were never mapped to TorchTitan FQNs and initial_load_in_hf loaded nothing. Add HFTransformerStateDictAdapter (model. prefix mapping + tied-embedding handling).

Verified: loading Qwen3-0.6B into the 'full' config now gives step-1 loss 3.62, matching native qwen3 (3.63) on the same checkpoint + data.
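As a rough illustration of item 5, the sketch below shows what a prefix mapping plus tied-embedding handling could look like on plain dictionaries. The function name, the key names, and the direction of the mapping (stripping a `model.` prefix) are assumptions for illustration only; the actual HFTransformerStateDictAdapter may differ.

```python
def map_hf_keys(hf_state: dict, tie_word_embeddings: bool, hf_prefix: str = "model.") -> dict:
    """Hypothetical key mapping: strip an assumed prefix and handle tied embeddings."""
    mapped = {}
    for key, value in hf_state.items():
        # Assumed direction: HF keys carry the prefix, target keys do not.
        if key.startswith(hf_prefix):
            key = key[len(hf_prefix):]
        mapped[key] = value

    # Assumption: a tied checkpoint stores only the embedding tensor, so the
    # output head has to reuse it rather than be loaded under its own key.
    if tie_word_embeddings and "lm_head.weight" not in mapped:
        mapped["lm_head.weight"] = mapped["embed_tokens.weight"]
    return mapped

# Toy "checkpoint": strings stand in for tensors.
hf_state = {
    "model.embed_tokens.weight": "E",
    "model.layers.0.self_attn.q_proj.weight": "Q0",
    "model.norm.weight": "N",
}
print(sorted(map_hf_keys(hf_state, tie_word_embeddings=True)))
```

Without any mapping, none of the checkpoint keys would match the model's parameter names, which matches the "loaded nothing" symptom described above.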
can you also compare the KL divergence between titan logits and HF logits after the first step? previously @shuhuayu compared the titan model vs the HF model and the logits difference was within 1e-7.
yeah, you can try a numerical test similar to
@acisseJZhong @shuhuayu I ran a logit-divergence test comparing the backend against plain HF from_pretrained on Qwen3-0.6B (let me know if you want me to push the script as well):
So the backend logits match HF within ~1e-7, in line with @shuhuayu's earlier result.
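For readers unfamiliar with this kind of check, here is a minimal pure-Python sketch of the two metrics discussed in this thread: KL divergence between the softmax distributions of two logit vectors, and their maximum absolute difference. It is not the script used for the result above; the helper names and the toy logits are made up for illustration.

```python
import math

def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax over a single logit vector."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def kl_divergence(p_logits: list[float], q_logits: list[float]) -> float:
    """KL(P || Q) where P and Q are the softmax distributions of the logits."""
    lp, lq = log_softmax(p_logits), log_softmax(q_logits)
    return sum(math.exp(a) * (a - b) for a, b in zip(lp, lq))

def max_abs_diff(a: list[float], b: list[float]) -> float:
    """Largest elementwise absolute difference between two logit vectors."""
    return max(abs(x - y) for x, y in zip(a, b))

# Toy logits standing in for one token position from each implementation.
reference = [2.0, -1.0, 0.5, 3.0]
candidate = [2.0, -1.0, 0.5, 3.0 + 1e-7]

print("max abs diff:", max_abs_diff(reference, candidate))
print("KL:", kl_divergence(reference, candidate))
```

In a real comparison both models would be fed the same input batch and the metrics aggregated over all positions.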
thanks. i think the script will be very similar to the existing example so it's fine to not include it.
Group fields into HF-mapped (default None so AutoConfig's value is kept) vs TorchTitan-only (concrete defaults, no HF equivalent so nothing is overridden), with comments explaining why. Addresses the review question on multiple_of.
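The "default None so the HF value is kept" pattern can be sketched as follows. `DenseOverrides` and `apply_overrides` are hypothetical stand-ins, not the real TitanDenseModelConfig, and the config is modeled as a plain dict with illustrative values (only rope_theta=1e6 for Qwen3 comes from the discussion above).

```python
from dataclasses import dataclass, fields
from typing import Optional

@dataclass
class DenseOverrides:
    """Hypothetical HF-mapped fields: None means 'keep the HF config value'."""
    dim: Optional[int] = None
    n_layers: Optional[int] = None
    rope_theta: Optional[float] = None

# Mapping from the hypothetical override fields to HF config attribute names.
HF_NAMES = {"dim": "hidden_size", "n_layers": "num_hidden_layers", "rope_theta": "rope_theta"}

def apply_overrides(hf_config: dict, overrides: DenseOverrides) -> dict:
    """Return a copy of hf_config with only the explicitly set fields replaced."""
    merged = dict(hf_config)
    for f in fields(overrides):
        value = getattr(overrides, f.name)
        if value is not None:  # unset fields never clobber the HF value
            merged[HF_NAMES[f.name]] = value
    return merged

# Illustrative values, not read from a real checkpoint.
hf_config = {"hidden_size": 1024, "num_hidden_layers": 28, "rope_theta": 1e6}

# No overrides: the HF config wins.
print(apply_overrides(hf_config, DenseOverrides()))

# Explicit override (debugmodel-style): only dim changes.
print(apply_overrides(hf_config, DenseOverrides(dim=256)))
```

With non-None defaults such as rope_theta=10000 in the dataclass, the first call would have silently replaced the checkpoint's 1e6, which is the bug described in item 2.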
Fix for #3775
Problem
Loading a real pretrained dense model (e.g. Qwen3-0.6B) into the HF backend produced near-random loss, and on recent PyTorch the backend crashed before training. The HF safetensors load correctly (verified byte-for-byte); the bugs are in how the backend builds and initializes the model, not in loading.
With the same Qwen3-0.6B checkpoint + data, step-1 loss was 11.08 (≈ ln(vocab), i.e. random) vs native TorchTitan qwen3's 3.63.
Root causes fixed
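1. RoPE inv_freq left uninitialized after meta-init: HF rotary modules compute inv_freq in __init__, but meta-device init + to_empty() zeros it and init_weights only re-inits parameters, silently disabling RoPE. Recompute it from each rotary module's rope_init_fn in init_states.
2. TitanDenseModelConfig overrides the HF config: non-None defaults (dim=4096, n_layers=32, n_heads=32, norm_eps=1e-5, rope_theta=10000) were injected over AutoConfig values, forcing the wrong architecture (e.g. rope_theta 1e6 → 10000 for Qwen3). Default these to None; set them explicitly only to intentionally override (e.g. debugmodel).
3. Tied embeddings crash FSDP2: tok_embeddings and lm_head share a parameter when tie_word_embeddings is set, but were placed in separate fully_shard groups (forbidden by recent FSDP2). Group tied tok_embeddings/norm/lm_head into one unit.
4. head_dim/intermediate_size were unconditionally derived from hidden_size/num_heads in update_from_config, overriding Qwen3's explicit head_dim=128 with 64 and mismatching checkpoint shapes. Only derive when dim is explicitly overridden; otherwise keep the HF config's values.
5. model_registry set state_dict_adapter=None, so HF safetensors keys were never mapped to TorchTitan FQNs and initial_load_in_hf loaded nothing. Add HFTransformerStateDictAdapter (model. prefix + tied-embedding handling).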
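The head_dim rule among the root causes above can be sketched as a small resolver. `resolve_head_dim` is a hypothetical helper, not the PR's update_from_config, and the config is a plain dict whose hidden_size and num_attention_heads are chosen only so that their ratio is 64, matching the numbers quoted in this PR.

```python
from typing import Optional

def resolve_head_dim(hf_config: dict, dim_override: Optional[int] = None) -> int:
    """Derive head_dim only when dim is explicitly overridden; otherwise trust the HF config."""
    n_heads = hf_config["num_attention_heads"]
    if dim_override is not None:
        # Explicit override (e.g. a debug model): derive from the new width.
        return dim_override // n_heads
    if "head_dim" in hf_config:
        # The checkpoint declares head_dim independently of hidden_size.
        return hf_config["head_dim"]
    # Fallback for configs that do not declare head_dim at all.
    return hf_config["hidden_size"] // n_heads

# Illustrative config: hidden_size / num_attention_heads = 64, but head_dim = 128.
hf_config = {"hidden_size": 1024, "num_attention_heads": 16, "head_dim": 128}

print(resolve_head_dim(hf_config))                    # 128, taken from the config
print(resolve_head_dim(hf_config, dim_override=256))  # 16, derived from the override
```

Always taking the derived value would give 64 here, so the attention projection shapes would not match the checkpoint's.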
Verification
Loading Qwen3-0.6B into the full config, same checkpoint/data/seed as native (--debug.seed 42 --debug.deterministic):