Skip to content

feat(ws1): NativeEmbeddingOp pure-PyTorch ground-truth reference + numerical contract tests#169

Open
maxiaosong1124 wants to merge 4 commits into
RL-Align:mainfrom
maxiaosong1124:feat/ws1-embedding-pytorch-op
Open

feat(ws1): NativeEmbeddingOp pure-PyTorch ground-truth reference + numerical contract tests#169
maxiaosong1124 wants to merge 4 commits into
RL-Align:mainfrom
maxiaosong1124:feat/ws1-embedding-pytorch-op

Conversation

@maxiaosong1124

@maxiaosong1124 maxiaosong1124 commented Jun 22, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds the pure-PyTorch ground-truth reference op for the token embedding — the
input layer of the WS1 batch-invariant forward chain — built on top of the numerical
contract defined in #108. Ships the op, its registry wiring, docs, and an 11-case test
suite that pins down both alignment axes (Axis-A bitwise batch invariance, Axis-B
per-dtype path), plus a GPU-only smoke test at the real Qwen3-8B table dims.

Refs #108

Terminology

This PR uses the WS1 alignment vocabulary from #108:

  • Axis-A — batch invariance (reproducibility). A token's output row must not depend
    on how many tokens share the batch (batch size, slicing, padding). Asserted bitwise
    (torch.equal). This is what keeps train-time (large batch) and sample-time
    (small batch / dynamic padding) numerics identical so the policy ratio doesn't drift.
  • Axis-B — accuracy. The low-precision (bf16 / fp16) forward path. Embedding is a
    lossless row gather — no reduction, no fp32 accumulation — so the dtype output is
    bitwise equal to direct indexing at every dtype. No tolerance window is needed.

Motivation / Context

#108 establishes the ground-truth harness and numerical contract for the WS1
batch-invariant forward chain. The first stage of the Qwen3-8B stack maps integer token
ids to their hidden-state rows:

hidden = embedding_table[token_ids]

This PR provides the deterministic fp32 reference path that downstream kernels
(Triton / CUDA / ROCm) will be validated against. For Qwen3-8B the table is the input
embedding [vocab=151936, hidden=4096] and is independent from the lm_head weight
(tie_word_embeddings=false) — the two weights are not shared.

Changes

  • rl_engine/kernels/ops/pytorch/linear/embedding.pyNativeEmbeddingOp
    • forward() — native-dtype gather, cast the gathered rows back to the weight dtype (Axis-B path)
    • forward_fp32() — native-dtype gather, upcast the result to fp32 (ground-truth / backward golden source)
    • Formula: out = weight[token_ids] (via F.embedding(token_ids.long(), weight))
    • Pure function — inputs never mutated in place; output dtype follows weight
  • rl_engine/kernels/registry.py — register PYTORCH_NATIVE_EMBEDDING in OpBackend
    and add embedding dispatch to the cuda / rocm / cpu priority maps
  • tests/test_embedding.py — 11 tests (details below)
  • docs/operators/embedding.md + nav / index wiring

How this satisfies the #108 contract

#108 requirement How it's met here
Deterministic reference path forward_fp32() gathers in fp32; tests use fixed-seed torch.Generator so outputs are reproducible
Per-dtype tolerance policy (bitwise vs tight-tolerance) Axis-A asserted bitwise (torch.equal); Axis-B dtype output is a lossless gather, asserted bitwise against direct indexing (a gather needs no tolerance window)
Batch-config sweep / validation helper Batch-invariance checks compute on the full batch, then assert sliced/padded rows are bitwise identical to their full-batch counterparts
Realistic shapes covered GPU-only smoke test at the real Qwen3-8B table dims (vocab=151936, hidden=4096), exercising boundary ids 0 and vocab-1; skips when CUDA / GPU memory is unavailable

Test Environment

OS Ubuntu (kernel 5.15.0-124-generic)
Python 3.12.3
PyTorch 2.8.0+cu128
CUDA / cuDNN 12.8 / 9.10.02 (driver 580.76.05)
GPU NVIDIA vGPU-32GB

Test Results

python -m pytest tests/test_embedding.py -v
cf77909025a5752054fb23be11a2ff4e

The 11 tests cover:

  • correctness vs direct indexing across fp32 / bf16 / fp16, asserted bitwise
  • output shape (token_ids.shape + (hidden,))
  • non-int64 id tolerance (cast via .long())
  • Axis-A batch invariance — slice + padding variants, asserted bitwise
  • purity (neither token_ids nor weight mutated in place)
  • gradient flow to weight (fp32 autograd = backward golden source), including
    sparse-grad: rows never indexed stay exactly zero
  • registry dispatch resolves embeddingNativeEmbeddingOp
  • GPU-only real-shape smoke test (Qwen3-8B vocab=151936, hidden=4096, boundary ids)

Checklist

  • Pure-PyTorch reference, no custom extension required
  • Covered at the real Qwen3-8B table dims (vocab=151936, hidden=4096)
  • Axis-A bitwise batch invariance enforced
  • Axis-B lossless-gather dtype path tested (bitwise, no tolerance window)
  • Registered in OpBackend + cuda/rocm/cpu priority maps
  • All 11 tests pass locally

Summary by CodeRabbit

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced a Token Embedding operator that gathers embedding vectors from integer token IDs with CPU/CUDA/ROCm support via a native PyTorch reference backend.
    • Added dtype-aware forward behavior with lossless, bit-exact gather semantics and defined forward vs. fp32 reference paths.
  • Documentation

    • Added operator docs and navigation, including shape/dtype rules and a note on GPU gradient reproducibility for repeated token IDs.
  • Tests

    • Added coverage for correctness across float32/bfloat16/float16, dtype/shape contracts, non-mutation, batch/padding invariance, and gradient expectations (with CUDA smoke validation).

WS1 ground-truth token-embedding op for issue RL-Align#108 (Qwen3-8B input
embedding table, vocab=151936 x hidden=4096, tie_word_embeddings=false):
- NativeEmbeddingOp: out = weight[token_ids], a lossless row gather
  exposing the forward / forward_fp32 dual-path contract (fp32 ground
  truth + dtype-behavior path); pure function, no in-place mutation.
- register PYTORCH_NATIVE_EMBEDDING in OpBackend and the cuda/rocm/cpu
  priority maps.
- tests/test_embedding.py: bitwise correctness vs direct indexing, dtype
  paths, non-int64 id tolerance, Axis-A batch invariance (slice +
  padding), purity, sparse gradient flow to weight, registry dispatch,
  and a GPU-only real-shape smoke test (vocab=151936, boundary ids).
- docs/operators/embedding.md + nav/index wiring.
@coderabbitai

coderabbitai Bot commented Jun 22, 2026

Copy link
Copy Markdown

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f4608ac8-4faa-4fa2-9a60-f2676716c54d

📥 Commits

Reviewing files that changed from the base of the PR and between 2de30a7 and fab4da6.

📒 Files selected for processing (5)
  • docs/.nav.yml
  • docs/operators/README.md
  • docs/operators/embedding.md
  • rl_engine/kernels/registry.py
  • tests/test_embedding.py
✅ Files skipped from review due to trivial changes (3)
  • docs/operators/README.md
  • docs/.nav.yml
  • docs/operators/embedding.md
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/test_embedding.py
  • rl_engine/kernels/registry.py

📝 Walkthrough

Walkthrough

Adds a pure-PyTorch NativeEmbeddingOp, registers it as the "embedding" backend on cuda/rocm/cpu, expands the test suite for gather behavior and gradients, and adds operator docs plus navigation links.

Changes

NativeEmbeddingOp

Layer / File(s) Summary
NativeEmbeddingOp and registry wiring
rl_engine/kernels/ops/pytorch/linear/__init__.py, rl_engine/kernels/ops/pytorch/linear/embedding.py, rl_engine/kernels/registry.py
Adds NativeEmbeddingOp with shared gather logic and forward / forward_fp32 dtype paths, adds OpBackend.PYTORCH_NATIVE_EMBEDDING, and routes "embedding" dispatch on cuda, rocm, and cpu to that backend.
Embedding behavior tests
tests/test_embedding.py
Covers bitwise gather correctness, shape and dtype rules, token id casting, batch and padding invariance, immutability, gradient flow, registry lookup, and a CUDA-only real-shape smoke test gated by free GPU memory.
Operator docs and navigation
docs/.nav.yml, docs/operators/README.md, docs/operators/embedding.md
Adds the embedding.md operator page, links it from docs navigation and the operator index, and documents the operator contract, dispatch behavior, accuracy notes, and limitations.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Poem

🐇 I hop through tokens, row by row,
Gathered just right, with a tidy glow.
Docs and tests now lead the way,
While registry bells softly say: yay!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 16.67% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately summarizes the main change: a pure-PyTorch NativeEmbeddingOp reference plus numerical contract tests.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

*,
output_dtype: torch.dtype,
) -> torch.Tensor:
out = F.embedding(token_ids.long(), weight.float())

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m worried about this weight.float() path: it upcasts the entire embedding table before gathering, so real fp16/bf16 Qwen3-size weights will allocate an extra multi-GB fp32 copy for a tiny lookup. Since this is the only registered embedding backend today, this could make the default fallback OOM in normal GPU use.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@inaniloquentee Thanks for the advice!
The gather now runs in the weight's native dtype and only the gathered rows are upcast (F.embedding(token_ids.long(), weight).to(output_dtype)), so there's no longer a multi-GB fp32 copy of the full vocab table for a tiny lookup. Since a row gather is lossless (pure indexing, no arithmetic), this is bitwise-identical to the previous path — all 11 tests in tests/test_embedding.py still pass.

Gathering with weight.float() upcast the entire vocab table to fp32 before
the lookup, allocating a multi-GB fp32 copy of the Qwen3-8B embedding table
just for a tiny row gather and risking OOM on the default fallback path.

A row gather is lossless (pure indexing, no arithmetic), so gather in the
weight's native dtype and upcast only the gathered rows -- bitwise-identical
to the previous path. All 11 tests in tests/test_embedding.py still pass.

@KJLdefeated KJLdefeated left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall good to me. After fix the requests, I am happy to approve.

Comment thread tests/test_embedding.py Outdated
Comment on lines +119 to +129
def test_embedding_gradient_flows_to_weight():
op = NativeEmbeddingOp()
token_ids = _rand_ids((2, 4), seed=7, vocab=10) # small vocab -> some unused rows
weight = _rand_weight(vocab=10, seed=7).requires_grad_(True)
op.forward_fp32(token_ids, weight).sum().backward()

assert torch.isfinite(weight.grad).all()
used = torch.unique(token_ids).tolist()
unused = torch.tensor([i for i in range(10) if i not in used])
if len(unused):
assert torch.equal(weight.grad[unused], torch.zeros_like(weight.grad[unused]))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reduction in the backward pass: ∂L/∂weight is a scatter-add, and every repeated token id (padding, common tokens, anything) accumulates into the same weight.grad row. On CUDA, F.embedding backward does that accumulation with atomic adds, and atomic-add order is nondeterministic, so the weight gradient is not bitwise reproducible on GPU whenever ids collide. PyTorch documents embedding backward as a nondeterministic CUDA op for exactly this reason.

Either set it in the op's golden path / a fixture, or state explicitly in embedding.md under Known Limitations that the GPU backward is bitwise-reproducible only with deterministic algorithms enabled.

# forward_fp32 is the backward golden source — its gradient must be
# reproducible. F.embedding backward is an atomic scatter-add on CUDA
# (nondeterministic under id collisions) unless this is set.
torch.use_deterministic_algorithms(True)

…nism

Backward (∂L/∂weight) is a scatter-add; repeated token ids accumulate into
the same grad row, and on CUDA that uses atomic adds (nondeterministic order),
so the weight gradient is not bitwise reproducible when ids collide. Since
forward_fp32 is the backward golden source:

- docs: document the limitation under Known Limitations (embedding.md)
- test: enable torch.use_deterministic_algorithms(True) in the gradient test
  and assert grad is bitwise identical across two independent backward passes
@maxiaosong1124

Copy link
Copy Markdown
Collaborator Author

Overall good to me. After fix the requests, I am happy to approve.
@KJLdefeated Good catch — you're right that forward_fp32 is the backward golden source, so its gradient has to be reproducible, and embedding backward is a nondeterministic atomic scatter-add on CUDA whenever token ids collide (padding / common tokens accumulate into the same weight.grad row). Addressed in fab4da6 with both parts you suggested:

  1. Test guard — test_embedding_gradient_flows_to_weight now wraps the backward in torch.use_deterministic_algorithms(True) (restored via try/finally) and asserts the weight gradient is bitwise identical across two independent backward passes — so the golden-source reproducibility claim is actually enforced, not just isfinite. I kept the flag in the test/fixture rather than the op's forward, since toggling that global is a process-wide side effect a reference op shouldn't own.
  2. Docs — added a Known Limitations bullet in embedding.md spelling out the scatter-add / CUDA atomic-add nondeterminism and that reproducible GPU gradients require torch.use_deterministic_algorithms(True), with CPU backward always deterministic.

All 11 tests in tests/test_embedding.py still pass. Thanks for the careful review!

@KJLdefeated

Copy link
Copy Markdown
Collaborator

LGTM.

@KJLdefeated KJLdefeated left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall good to me. After fix the requests, I am happy to approve.

@Flink-ddd

Copy link
Copy Markdown
Collaborator

Please resolve CI error and the code conflicts first. Thanks.

@Flink-ddd Flink-ddd left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR elegantly resolves both the OOM hazard (by deferring the upcast) and the non-deterministic scatter-add edge case on CUDA. The implementation and the test coverage are rock solid.

Since the major functional edge cases have been cleared by the team, there is only one architectural standardization left to align this with the rest of the WS1 reference suite.

is *independent* from the lm_head weight (``tie_word_embeddings=false``).
"""

def __init__(self) -> None:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same, Add torch.nn.Module to the class inheritance and initialize super().init().
Delete the def call(self, ...):

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants