feat(ws1): NativeEmbeddingOp pure-PyTorch ground-truth reference + numerical contract tests#169
Conversation
WS1 ground-truth token-embedding op for issue RL-Align#108 (Qwen3-8B input embedding table, vocab=151936 x hidden=4096, tie_word_embeddings=false): - NativeEmbeddingOp: out = weight[token_ids], a lossless row gather exposing the forward / forward_fp32 dual-path contract (fp32 ground truth + dtype-behavior path); pure function, no in-place mutation. - register PYTORCH_NATIVE_EMBEDDING in OpBackend and the cuda/rocm/cpu priority maps. - tests/test_embedding.py: bitwise correctness vs direct indexing, dtype paths, non-int64 id tolerance, Axis-A batch invariance (slice + padding), purity, sparse gradient flow to weight, registry dispatch, and a GPU-only real-shape smoke test (vocab=151936, boundary ids). - docs/operators/embedding.md + nav/index wiring.
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (5)
✅ Files skipped from review due to trivial changes (3)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 WalkthroughWalkthroughAdds a pure-PyTorch ChangesNativeEmbeddingOp
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
| *, | ||
| output_dtype: torch.dtype, | ||
| ) -> torch.Tensor: | ||
| out = F.embedding(token_ids.long(), weight.float()) |
There was a problem hiding this comment.
I’m worried about this weight.float() path: it upcasts the entire embedding table before gathering, so real fp16/bf16 Qwen3-size weights will allocate an extra multi-GB fp32 copy for a tiny lookup. Since this is the only registered embedding backend today, this could make the default fallback OOM in normal GPU use.
There was a problem hiding this comment.
@inaniloquentee Thanks for the advice!
The gather now runs in the weight's native dtype and only the gathered rows are upcast (F.embedding(token_ids.long(), weight).to(output_dtype)), so there's no longer a multi-GB fp32 copy of the full vocab table for a tiny lookup. Since a row gather is lossless (pure indexing, no arithmetic), this is bitwise-identical to the previous path — all 11 tests in tests/test_embedding.py still pass.
Gathering with weight.float() upcast the entire vocab table to fp32 before the lookup, allocating a multi-GB fp32 copy of the Qwen3-8B embedding table just for a tiny row gather and risking OOM on the default fallback path. A row gather is lossless (pure indexing, no arithmetic), so gather in the weight's native dtype and upcast only the gathered rows -- bitwise-identical to the previous path. All 11 tests in tests/test_embedding.py still pass.
KJLdefeated
left a comment
There was a problem hiding this comment.
Overall good to me. After fix the requests, I am happy to approve.
| def test_embedding_gradient_flows_to_weight(): | ||
| op = NativeEmbeddingOp() | ||
| token_ids = _rand_ids((2, 4), seed=7, vocab=10) # small vocab -> some unused rows | ||
| weight = _rand_weight(vocab=10, seed=7).requires_grad_(True) | ||
| op.forward_fp32(token_ids, weight).sum().backward() | ||
|
|
||
| assert torch.isfinite(weight.grad).all() | ||
| used = torch.unique(token_ids).tolist() | ||
| unused = torch.tensor([i for i in range(10) if i not in used]) | ||
| if len(unused): | ||
| assert torch.equal(weight.grad[unused], torch.zeros_like(weight.grad[unused])) |
There was a problem hiding this comment.
The reduction in the backward pass: ∂L/∂weight is a scatter-add, and every repeated token id (padding, common tokens, anything) accumulates into the same weight.grad row. On CUDA, F.embedding backward does that accumulation with atomic adds, and atomic-add order is nondeterministic, so the weight gradient is not bitwise reproducible on GPU whenever ids collide. PyTorch documents embedding backward as a nondeterministic CUDA op for exactly this reason.
Either set it in the op's golden path / a fixture, or state explicitly in embedding.md under Known Limitations that the GPU backward is bitwise-reproducible only with deterministic algorithms enabled.
# forward_fp32 is the backward golden source — its gradient must be
# reproducible. F.embedding backward is an atomic scatter-add on CUDA
# (nondeterministic under id collisions) unless this is set.
torch.use_deterministic_algorithms(True)…nism Backward (∂L/∂weight) is a scatter-add; repeated token ids accumulate into the same grad row, and on CUDA that uses atomic adds (nondeterministic order), so the weight gradient is not bitwise reproducible when ids collide. Since forward_fp32 is the backward golden source: - docs: document the limitation under Known Limitations (embedding.md) - test: enable torch.use_deterministic_algorithms(True) in the gradient test and assert grad is bitwise identical across two independent backward passes
All 11 tests in tests/test_embedding.py still pass. Thanks for the careful review! |
|
LGTM. |
KJLdefeated
left a comment
There was a problem hiding this comment.
Overall good to me. After fix the requests, I am happy to approve.
|
Please resolve CI error and the code conflicts first. Thanks. |
Flink-ddd
left a comment
There was a problem hiding this comment.
This PR elegantly resolves both the OOM hazard (by deferring the upcast) and the non-deterministic scatter-add edge case on CUDA. The implementation and the test coverage are rock solid.
Since the major functional edge cases have been cleared by the team, there is only one architectural standardization left to align this with the rest of the WS1 reference suite.
| is *independent* from the lm_head weight (``tie_word_embeddings=false``). | ||
| """ | ||
|
|
||
| def __init__(self) -> None: |
There was a problem hiding this comment.
same, Add torch.nn.Module to the class inheritance and initialize super().init().
Delete the def call(self, ...):
Summary
Adds the pure-PyTorch ground-truth reference op for the token embedding — the
input layer of the WS1 batch-invariant forward chain — built on top of the numerical
contract defined in #108. Ships the op, its registry wiring, docs, and an 11-case test
suite that pins down both alignment axes (Axis-A bitwise batch invariance, Axis-B
per-dtype path), plus a GPU-only smoke test at the real Qwen3-8B table dims.
Refs #108
Terminology
This PR uses the WS1 alignment vocabulary from #108:
on how many tokens share the batch (batch size, slicing, padding). Asserted bitwise
(
torch.equal). This is what keeps train-time (large batch) and sample-time(small batch / dynamic padding) numerics identical so the policy ratio doesn't drift.
lossless row gather — no reduction, no fp32 accumulation — so the dtype output is
bitwise equal to direct indexing at every dtype. No tolerance window is needed.
Motivation / Context
#108 establishes the ground-truth harness and numerical contract for the WS1
batch-invariant forward chain. The first stage of the Qwen3-8B stack maps integer token
ids to their hidden-state rows:
This PR provides the deterministic fp32 reference path that downstream kernels
(Triton / CUDA / ROCm) will be validated against. For Qwen3-8B the table is the input
embedding
[vocab=151936, hidden=4096]and is independent from the lm_head weight(
tie_word_embeddings=false) — the two weights are not shared.Changes
rl_engine/kernels/ops/pytorch/linear/embedding.py—NativeEmbeddingOpforward()— native-dtype gather, cast the gathered rows back to the weight dtype (Axis-B path)forward_fp32()— native-dtype gather, upcast the result to fp32 (ground-truth / backward golden source)weightrl_engine/kernels/registry.py— registerPYTORCH_NATIVE_EMBEDDINGinOpBackendand add
embeddingdispatch to the cuda / rocm / cpu priority mapstests/test_embedding.py— 11 tests (details below)docs/operators/embedding.md+ nav / index wiringHow this satisfies the #108 contract
forward_fp32()gathers in fp32; tests use fixed-seedtorch.Generatorso outputs are reproducibletorch.equal); Axis-B dtype output is a lossless gather, asserted bitwise against direct indexing (a gather needs no tolerance window)vocab=151936, hidden=4096), exercising boundary ids0andvocab-1; skips when CUDA / GPU memory is unavailableTest Environment
Test Results
The 11 tests cover:
token_ids.shape + (hidden,)).long())token_idsnorweightmutated in place)weight(fp32 autograd = backward golden source), includingsparse-grad: rows never indexed stay exactly zero
embedding→NativeEmbeddingOpvocab=151936, hidden=4096, boundary ids)Checklist
Summary by CodeRabbit
Summary by CodeRabbit
Release Notes
New Features
Documentation
Tests