Skip to content

perf: extend RMSNorm block-row dispatch to hidden<=512, lower row threshold#1238

Open
alifurkanstahl wants to merge 6 commits into
linkedin:mainfrom
alifurkanstahl:rmsnorm-block-row-hidden-512
Open

perf: extend RMSNorm block-row dispatch to hidden<=512, lower row threshold#1238
alifurkanstahl wants to merge 6 commits into
linkedin:mainfrom
alifurkanstahl:rmsnorm-block-row-hidden-512

Conversation

@alifurkanstahl

@alifurkanstahl alifurkanstahl commented May 25, 2026

Copy link
Copy Markdown

Summary

Extend the RMSNorm block-row dispatch path to cover hidden_size <= 512 and lower the row count at which it engages.

Previously the block-row forward/backward kernels were only selected when BLOCK_SIZE <= 256 and n_rows >= 32768 (4096 * 8); everything else used the single-row path. This PR:

  • raises the ceiling to BLOCK_SIZE <= 512 (_BLOCK_ROW_MAX_BLOCK_SIZE = 512), so hidden_size = 512 now uses the block-row kernels;
  • lowers the row threshold to n_rows >= 4096 (_BLOCK_ROW_MIN_ROWS = 4096);
  • replaces the inline magic numbers (256, 4096 * 8) in both rms_norm_forward and rms_norm_backward with two named constants.

This is a dispatch-only change: the block-row kernels themselves are unchanged. The goal is to expose the existing path to small-hidden, many-row training workloads where the current single-row path leaves occupancy/launch efficiency on the table.

Scope: this targets the small-hidden (≤ 512), many-row regime — small models and sub-modules. Mainstream LLM hidden sizes (≥ 2048) are unaffected and keep the existing single-row path.

Details

  • The dispatch guard is purely BLOCK_SIZE / n_rows / row_mode based — not device-gated. The dispatch is intentionally not CUDA-gated because the existing block-row backward already has backend-specific sm/core-count handling (cuda / xpu / npu). Perf was measured on SM120 only (RTX 5060 Ti); CI should validate correctness on the other supported backends, and backend-specific perf can be retuned if needed. Happy to make the lower threshold CUDA-only if maintainers prefer avoiding heuristic changes on non-CUDA backends before backend-specific perf data is available.
  • The dW partial accumulators remain fp32, so the numerical precision path is unchanged; correctness is covered by the existing allclose checks across fp32/bf16 and the Llama/Gemma/Base RMSNorm variants.
  • The full-pass win is driven by the backward pass (−51 … −54% at h=512); forward-only is marginal and shrinks toward the ceiling, so this should not be read as a general inference/forward speedup. 4096 is the lowest row count where every pass wins cleanly.

Benchmark

NVIDIA GeForce RTX 5060 Ti · bfloat16 · Triton do_bench median, 7 repeats · BLOCK_ROW = 16. Each cell = single-row ms → block-row ms (delta%), negative = faster. h = 768 (BLOCK_SIZE = 1024 > 512) is a control that stays on the single-row path.

Full forward + backward:

rows h=64 h=128 h=256 h=512 h=768 (ctrl)
4096 0.075→0.024 (-67.9%) 0.080→0.034 (-56.7%) 0.088→0.047 (-46.9%) 0.110→0.059 (-46.5%) 0.131→0.131 (+0.0%)
8192 0.133→0.034 (-74.1%) 0.141→0.051 (-63.7%) 0.159→0.077 (-52.0%) 0.198→0.106 (-46.5%) 0.247→0.247 (-0.0%)
16384 0.248→0.051 (-79.4%) 0.264→0.088 (-66.8%) 0.293→0.131 (-55.3%) 0.387→0.233 (-39.7%) 0.485→0.484 (-0.2%)
32768 0.477→0.086 (-82.0%) 0.506→0.158 (-68.7%) 0.584→0.258 (-55.9%) 0.779→0.452 (-42.0%) 0.946→0.945 (-0.1%)

Backward only (the dominant contributor):

rows h=64 h=128 h=256 h=512 h=768 (ctrl)
4096 0.069→0.021 (-69.8%) 0.071→0.030 (-57.4%) 0.080→0.037 (-54.0%) 0.096→0.047 (-51.2%) 0.105→0.104 (-0.9%)
8192 0.122→0.028 (-76.8%) 0.131→0.047 (-64.2%) 0.145→0.061 (-58.1%) 0.168→0.080 (-52.7%) 0.188→0.188 (+0.0%)
16384 0.229→0.045 (-80.5%) 0.243→0.079 (-67.4%) 0.265→0.100 (-62.2%) 0.315→0.145 (-53.9%) 0.353→0.354 (+0.2%)
32768 0.438→0.073 (-83.3%) 0.462→0.133 (-71.3%) 0.508→0.172 (-66.2%) 0.602→0.274 (-54.4%) 0.688→0.688 (-0.0%)

Forward-only is small and shrinks near the ceiling (h=512: -8.1% at 4096 rows → -1.8% at 32768) — not a headline. The control column confirms the single-row path is untouched.

Testing Done

Added test_correctness_block_row to test/transformers/test_rms_norm.py. The existing test_correctness shapes never reach n_rows >= 4096, so the block-row kernels previously had no coverage. The new test forces shapes onto the block-row path and cross-checks against Llama/Gemma/Base references over fp32 and bf16, with in_place ∈ {True, False} and elementwise_affine ∈ {True, False}. It includes a non-power-of-2 hidden size (exercises the column mask) and a row count not divisible by BLOCK_ROW (exercises the row-tail mask), plus guard asserts (triton.next_power_of_2(hd) vs the dispatch constants) so the test fails loudly rather than silently going no-op if the thresholds are ever retuned.

test/transformers/test_rms_norm.py: 144 passed, 8 xfailed (the 8 xfails are pre-existing casting_mode="none" bf16 large-row cases, unrelated to this change).

Note: the full make test run also surfaces one unrelated, pre-existing failure — test_liger_group_norm[...-16-48-12-8192] — a tight fp32 tolerance in the GroupNorm kernel at hidden=8192. This PR does not touch GroupNorm (the diff is limited to rms_norm.py), and the failure reproduces on a clean main checkout (decb1b7) on this GPU.

Test logs
$ python -m pytest test/transformers/test_rms_norm.py
============ 144 passed, 8 xfailed, 2 warnings in 86.63s (0:01:26) =============

$ make checkstyle
All checks passed!
273 files already formatted
All checks passed!
273 files left unchanged
  • Hardware Type: NVIDIA GeForce RTX 5060 Ti
  • run make test — RMSNorm passes; one unrelated pre-existing GroupNorm failure reproduced on clean main
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@alifurkanstahl alifurkanstahl changed the title perf: extend RMSNorm block-row dispatch to hidden≤512, lower row threshold to 4096 perf: extend RMSNorm block-row dispatch to hidden<=512, lower row threshold May 26, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant