Skip to content

Fix transformers_modeling_backend for pretrained dense models#3772

Open
HosseinKaviani-H wants to merge 4 commits into
pytorch:mainfrom
HosseinKaviani-H:hf-backend-rope-fix
Open

Fix transformers_modeling_backend for pretrained dense models#3772
HosseinKaviani-H wants to merge 4 commits into
pytorch:mainfrom
HosseinKaviani-H:hf-backend-rope-fix

Conversation

@HosseinKaviani-H

@HosseinKaviani-H HosseinKaviani-H commented Jun 24, 2026

Copy link
Copy Markdown
Contributor

Fix for #3775

Problem

Loading a real pretrained dense model (e.g. Qwen3-0.6B) into the HF backend produced near-random loss, and on recent PyTorch the backend crashed before training. The HF safetensors load correctly (verified byte-for-byte): the bugs are in how the backend builds/initializes the model, not in loading.

With the same Qwen3-0.6B checkpoint + data, step-1 loss was 11.08 (≈ ln(vocab), i.e. random) vs native TorchTitan qwen3's 3.63.

Root causes fixed

  1. RoPE inv_freq left uninitialized after meta-device init. HF rotary modules compute inv_freq in init; meta init + to_empty()
    zeros it and init_weights only re-inits parameters, silently disabling RoPE. Recompute it from each rotary module's rope_init_fn in
    init_states.
  2. TitanDenseModelConfig overrode the HF config. Non-None defaults (dim=4096, n_layers=32, n_heads=32, norm_eps=1e-5,
    rope_theta=10000) were injected over AutoConfig values, forcing the wrong architecture (e.g. rope_theta 1e6 → 10000 for Qwen3).
    Default these to None; set explicitly only to intentionally override (e.g. debugmodel).
  3. Tied embeddings crashed FSDP2. tok_embeddings/lm_head share a parameter under tie_word_embeddings but were in separate fully_shard
    groups (forbidden by recent FSDP2). Group tied tok_embeddings/norm/lm_head into one unit.
  4. head_dim/intermediate_size derived instead of read from the HF config. update_from_config set head_dim = hidden_size // num_heads,
    overriding Qwen3's explicit head_dim=128 with 64 and mismatching checkpoint shapes. Only derive when dim is explicitly overridden;
    otherwise keep the HF config's values.
  5. No state-dict adapter. model_registry set state_dict_adapter=None, so HF weight names were never mapped and initial_load_in_hf
    loaded nothing. Add HFTransformerStateDictAdapter (model. prefix + tied-embedding handling).

Verification

Loading Qwen3-0.6B into the full config, same checkpoint/data/seed as native (--debug.seed 42 --debug.deterministic):

  • Forward parity (step 1, on loaded weights): native 3.631 vs HF 3.623 — identical within bf16 tolerance (was 11.08 before).
  • Training parity (matched LR 3e-4): the curves track step-for-step with no instability:
step native HF
1 3.63 3.62
2 4.06 4.14
5 4.08 4.28
  • (Residual drift is expected — two different model implementations diverge as updates accumulate.)
  • The backend previously crashed at the FSDP tied-param check on any tied model (incl. debugmodel); it now trains.

Three base-backend bugs that make the HF backend produce near-random loss /
crash when building or loading real pretrained models (independent of SFT):

1. RoPE inv_freq left uninitialized after meta-init: HF rotary modules compute
   inv_freq in __init__, but meta-device init + to_empty() zeros it and
   _init_weights does not recompute it, silently disabling RoPE. Recompute it
   from each rotary module's rope_init_fn in init_states.

2. TitanDenseModelConfig overrides the HF config: non-None defaults (dim=4096,
   n_layers=32, n_heads=32, norm_eps=1e-5, rope_theta=10000) were injected over
   AutoConfig values, forcing the wrong architecture/hyperparameters (e.g.
   rope_theta 1e6 -> 10000 for Qwen3). Default these to None so the HF config
   wins; set explicitly only to intentionally override (e.g. debugmodel).

3. Tied embeddings crash FSDP2: tok_embeddings and lm_head share a parameter
   when tie_word_embeddings is set, but were placed in separate fully_shard
   groups (FSDP2 forbids this on recent PyTorch). Group tied
   tok_embeddings/norm/lm_head into one FSDP unit.
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 24, 2026
@pytorch-bot

pytorch-bot Bot commented Jun 24, 2026

Copy link
Copy Markdown

Workflows were awaiting approval. CI has now been triggered for the ciflow labels on this PR.

…add state-dict adapter

Completes loaded-weight support for pretrained HF dense models:

4. head_dim/intermediate_size were unconditionally derived from
   hidden_size/num_heads in update_from_config, overriding the HF config.
   Models like Qwen3 decouple head_dim (128) from hidden_size/num_heads (64),
   so loading mismatched the projection shapes. Only derive these when dim is
   explicitly overridden (e.g. debugmodel); otherwise keep the HF config values.

5. model_registry set state_dict_adapter=None, so HF safetensors keys were never
   mapped to TorchTitan FQNs and initial_load_in_hf loaded nothing. Add
   HFTransformerStateDictAdapter (model. prefix mapping + tied-embedding handling).

Verified: loading Qwen3-0.6B into the 'full' config now gives step-1 loss 3.62,
matching native qwen3 (3.63) on the same checkpoint + data.
@acisseJZhong

Copy link
Copy Markdown
Contributor

Forward parity (step 1, on loaded weights): native 3.631 vs HF 3.623 — identical within bf16 tolerance (was 11.08 before).

can you also compare the kl divergence between titan logits and HF logits after the first step? previously @shuhuayu compared titan model vs HF model and the logits difference is within 1e-7.

@shuhuayu

Copy link
Copy Markdown
Contributor

Forward parity (step 1, on loaded weights): native 3.631 vs HF 3.623 — identical within bf16 tolerance (was 11.08 before).

can you also compare the kl divergence between titan logits and HF logits after the first step? previously @shuhuayu compared titan model vs HF model and the logits difference is within 1e-7.

yeah, you can try a numerical test similar to scripts/checkpoint_conversion/numerical_tests_qwen3_5.py to check the logits divergence to have more confidence in the numerics.

Comment thread torchtitan/experiments/transformers_modeling_backend/__init__.py Outdated
@HosseinKaviani-H

Copy link
Copy Markdown
Contributor Author

@acisseJZhong @shuhuayu Ran a logit-divergence test (let me know if you want me to push the script as well) -backend vs plain HF from_pretrained on Qwen3-0.6B:

  • Full sequence: KL ≈ 9e-7, cosine = 1.0, top-1 = 100%
  • Last token: KL ≈ 9e-8, max_diff ≈ 2.6e-5

So the backend logits match HF within ~1e-7, in line with @shuhuayu's earlier result.

@shuhuayu

Copy link
Copy Markdown
Contributor

@acisseJZhong @shuhuayu Ran a logit-divergence test (let me know if you want me to push the script as well) -backend vs plain HF from_pretrained on Qwen3-0.6B:

  • Full sequence: KL ≈ 9e-7, cosine = 1.0, top-1 = 100%
  • Last token: KL ≈ 9e-8, max_diff ≈ 2.6e-5

So the backend logits match HF within ~1e-7, in line with @shuhuayu's earlier result.

thanks. i think the script will be very similar to the existing example so it's fine to not include it.

Comment thread torchtitan/experiments/transformers_modeling_backend/model.py
Comment thread torchtitan/experiments/transformers_modeling_backend/parallelize.py Outdated
Comment thread torchtitan/experiments/transformers_modeling_backend/__init__.py Outdated
Comment thread torchtitan/experiments/transformers_modeling_backend/parallelize.py
Hossein Kavianihamedani added 2 commits June 24, 2026 23:11
Group fields into HF-mapped (default None so AutoConfig's value is kept) vs
TorchTitan-only (concrete defaults, no HF equivalent so nothing is overridden),
with comments explaining why -- addresses review question on multiple_of.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants