perf: extend RMSNorm block-row dispatch to hidden<=512, lower row threshold #1238
Open
alifurkanstahl wants to merge 6 commits into
Summary
Extend the RMSNorm block-row dispatch path to cover hidden_size <= 512 and lower the row count at which it engages.

Previously the block-row forward/backward kernels were only selected when BLOCK_SIZE <= 256 and n_rows >= 32768 (4096 * 8); everything else used the single-row path. This PR:

- raises the block-size ceiling to BLOCK_SIZE <= 512 (_BLOCK_ROW_MAX_BLOCK_SIZE = 512), so hidden_size = 512 now uses the block-row kernels;
- lowers the row threshold to n_rows >= 4096 (_BLOCK_ROW_MIN_ROWS = 4096);
- replaces the hard-coded literals (256, 4096 * 8) in both rms_norm_forward and rms_norm_backward with two named constants.

This is a dispatch-only change: the block-row kernels themselves are unchanged. The goal is to expose the existing path to small-hidden, many-row training workloads where the current single-row path leaves occupancy and launch efficiency on the table.
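The selection rule described above can be sketched in plain Python. This is an illustration only, not the code in rms_norm.py: the two constant names and the old literals come from this description, while `uses_block_row`, `uses_block_row_old`, and the local `next_power_of_2` helper (a stand-in for `triton.next_power_of_2`) are hypothetical names. It also assumes that `row_mode=True` forces the single-row path.

```python
# Constant names as given in this PR description.
_BLOCK_ROW_MAX_BLOCK_SIZE = 512
_BLOCK_ROW_MIN_ROWS = 4096


def next_power_of_2(n: int) -> int:
    """Pure-Python stand-in for triton.next_power_of_2 (hypothetical helper)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def uses_block_row(hidden_size: int, n_rows: int, row_mode: bool = False) -> bool:
    """New rule: block-row path when BLOCK_SIZE <= 512 and n_rows >= 4096.

    Assumption: row_mode=True always selects the single-row path.
    """
    block_size = next_power_of_2(hidden_size)
    return (
        not row_mode
        and block_size <= _BLOCK_ROW_MAX_BLOCK_SIZE
        and n_rows >= _BLOCK_ROW_MIN_ROWS
    )


def uses_block_row_old(hidden_size: int, n_rows: int, row_mode: bool = False) -> bool:
    """Previous rule with the hard-coded literals 256 and 4096 * 8."""
    block_size = next_power_of_2(hidden_size)
    return not row_mode and block_size <= 256 and n_rows >= 4096 * 8


if __name__ == "__main__":
    for hidden, rows in [(256, 32768), (512, 4096), (512, 2048), (768, 32768)]:
        print(hidden, rows, uses_block_row_old(hidden, rows), uses_block_row(hidden, rows))
```

Under this sketch, hidden_size = 768 rounds up to BLOCK_SIZE = 1024 and therefore stays on the single-row path under both rules, which is why it works as a control.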
Scope: this targets the small-hidden (≤ 512), many-row regime, i.e. small models and sub-modules. Mainstream LLM hidden sizes (≥ 2048) are unaffected and keep the existing single-row path.

Details

- The dispatch is BLOCK_SIZE/n_rows/row_mode based, not device-gated. It is intentionally not CUDA-gated because the existing block-row backward already has backend-specific sm/core-count handling (cuda/xpu/npu). Perf was measured on SM120 only (RTX 5060 Ti); CI should validate correctness on the other supported backends, and backend-specific perf can be retuned if needed. Happy to make the lower threshold CUDA-only if maintainers prefer avoiding heuristic changes on non-CUDA backends before backend-specific perf data is available.
- The dW partial accumulators remain fp32, so the numerical precision path is unchanged; correctness is covered by the existing allclose checks across fp32/bf16 and the Llama/Gemma/Base RMSNorm variants.
- 4096 is the lowest row count where every pass wins cleanly.

Benchmark
NVIDIA GeForce RTX 5060 Ti · bfloat16 · Triton do_bench median, 7 repeats · BLOCK_ROW = 16. Each cell = single-row ms → block-row ms (delta%), negative = faster. h = 768 (BLOCK_SIZE = 1024 > 512) is a control that stays on the single-row path.

Full forward + backward:
Backward only (the dominant contributor):
Forward-only is small and shrinks near the ceiling (h=512: -8.1% at 4096 rows → -1.8% at 32768) — not a headline. The control column confirms the single-row path is untouched.
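For reading the cells, a minimal sketch of the delta convention, assuming delta% is the change of the block-row time relative to the single-row time (so a negative value means block-row is faster). The helper name `delta_percent` is hypothetical, and the timings in the usage lines are made-up placeholders, not measurements from this PR.

```python
def delta_percent(single_row_ms: float, block_row_ms: float) -> float:
    """Relative change of block-row vs single-row time, in percent.

    Assumed convention: negative means the block-row path is faster.
    """
    if single_row_ms <= 0:
        raise ValueError("single_row_ms must be positive")
    return (block_row_ms - single_row_ms) / single_row_ms * 100.0


if __name__ == "__main__":
    # Placeholder timings for illustration only (not benchmark results).
    print(f"{delta_percent(1.0, 0.9):+.1f}%")  # block-row faster
    print(f"{delta_percent(1.0, 1.2):+.1f}%")  # block-row slower
```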
Testing Done
Added test_correctness_block_row to test/transformers/test_rms_norm.py. The existing test_correctness shapes never reach n_rows >= 4096, so the block-row kernels previously had no coverage. The new test forces shapes onto the block-row path and cross-checks against Llama/Gemma/Base references over fp32 and bf16, with in_place ∈ {True, False} and elementwise_affine ∈ {True, False}. It includes a non-power-of-2 hidden size (exercises the column mask) and a row count not divisible by BLOCK_ROW (exercises the row-tail mask), plus guard asserts (triton.next_power_of_2(hd) vs the dispatch constants) so the test fails loudly rather than silently becoming a no-op if the thresholds are ever retuned.

test/transformers/test_rms_norm.py: 144 passed, 8 xfailed (the 8 xfails are pre-existing casting_mode="none" bf16 large-row cases, unrelated to this change).

Note: the full make test run also surfaces one unrelated, pre-existing failure: test_liger_group_norm[...-16-48-12-8192], a tight fp32 tolerance in the GroupNorm kernel at hidden=8192. This PR does not touch GroupNorm (the diff is limited to rms_norm.py), and the failure reproduces on a clean main checkout (decb1b7) on this GPU.

Test logs

- make test: RMSNorm passes; one unrelated pre-existing GroupNorm failure reproduced on clean main
- make checkstyle to ensure code style
- make test-convergence to ensure convergence
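The guard asserts mentioned under Testing Done could look roughly like the following sketch. It is not the test added in this PR: `check_block_row_shape`, the local `next_power_of_2` stand-in for `triton.next_power_of_2`, and the example shapes are hypothetical; the constants and BLOCK_ROW = 16 are the values quoted in this description.

```python
_BLOCK_ROW_MAX_BLOCK_SIZE = 512  # dispatch constant named in this PR
_BLOCK_ROW_MIN_ROWS = 4096       # dispatch constant named in this PR
BLOCK_ROW = 16                   # value quoted in the benchmark setup


def next_power_of_2(n: int) -> int:
    """Pure-Python stand-in for triton.next_power_of_2 (hypothetical helper)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def check_block_row_shape(bs: int, sl: int, hd: int) -> dict:
    """Fail loudly if a test shape would not reach the block-row path.

    Returns which masks the shape exercises once the guards pass.
    """
    n_rows = bs * sl
    block_size = next_power_of_2(hd)
    assert block_size <= _BLOCK_ROW_MAX_BLOCK_SIZE, (
        f"hd={hd} gives BLOCK_SIZE={block_size}, above the block-row ceiling"
    )
    assert n_rows >= _BLOCK_ROW_MIN_ROWS, (
        f"n_rows={n_rows} is below the block-row row threshold"
    )
    return {
        "column_mask": block_size != hd,        # non-power-of-2 hidden size
        "row_tail": n_rows % BLOCK_ROW != 0,    # rows not divisible by BLOCK_ROW
    }


if __name__ == "__main__":
    # Hypothetical shapes chosen for illustration.
    print(check_block_row_shape(4, 1025, 500))
    print(check_block_row_shape(8, 512, 512))
```

If the thresholds are retuned so that a shape no longer qualifies, the asserts raise instead of letting the test pass against the single-row kernels.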