Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/liger_kernel/ops/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@
_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)
# Min flattened rows to dispatch to the block-row path; lowered from the original
# 32768 (4096*8). At >=4096 rows the full fwd+bwd pass is faster on the block-row
# kernels for hidden<=512; forward-only can be marginally slower near the boundary.
_BLOCK_ROW_MIN_ROWS = 4096
# Max BLOCK_SIZE (i.e. hidden size) for the block-row path; raised from 256 to 512 since
# block-row also wins at hidden=512. Larger hidden sizes stay on the single-row path.
_BLOCK_ROW_MAX_BLOCK_SIZE = 512


@triton.jit
Expand Down Expand Up @@ -438,7 +445,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
kernel_args = {}
if X.device.type == "xpu":
set_large_grf_mode(kernel_args)
if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
if BLOCK_SIZE > _BLOCK_ROW_MAX_BLOCK_SIZE or n_rows < _BLOCK_ROW_MIN_ROWS or row_mode:
_rms_norm_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
Expand Down Expand Up @@ -519,7 +526,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
if X.device.type == "xpu":
set_large_grf_mode(kernel_args)

if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
if BLOCK_SIZE > _BLOCK_ROW_MAX_BLOCK_SIZE or n_rows < _BLOCK_ROW_MIN_ROWS or row_mode:
_rms_norm_backward_kernel[grid](
dY,
dY.stride(0),
Expand Down
95 changes: 95 additions & 0 deletions test/transformers/test_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import triton

from test.utils import assert_verbose_allclose
from test.utils import set_seed
Expand Down Expand Up @@ -184,6 +185,100 @@ def test_correctness(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_m
assert_verbose_allclose(h1.grad, h2.grad, atol=atol, rtol=rtol, max_print=20)


@pytest.mark.flaky(reruns=3, reruns_delay=2)
# These shapes have bs*sl >= _BLOCK_ROW_MIN_ROWS with a hidden size whose BLOCK_SIZE
# is <= _BLOCK_ROW_MAX_BLOCK_SIZE, which is what routes RMSNorm to the block-row
# forward/backward kernels. The standard test_correctness shapes never reach that
# row count, so without this the block-row kernels have no CI coverage.
@pytest.mark.parametrize(
"bs, sl, hd",
[
(16, 512, 512), # hidden=512, BLOCK_SIZE=512 (upper edge of the path)
(32, 256, 256),
(16, 512, 128), # small hidden
# non-power-of-2 hidden (BLOCK_SIZE=512 -> exercises col_mask) and n_rows
# not divisible by BLOCK_ROW=16 (4097 -> exercises the row-tail mask)
(1, 4097, 384),
],
)
# casting_mode="none" keeps the whole reduction in the input dtype; in bf16 over
# this many rows it loses too much precision on BOTH dispatch paths, so it is only
# exercised in fp32 here.
@pytest.mark.parametrize(
"reference, offset, casting_mode, dtype, atol, rtol",
[
(LlamaRMSNorm, 0.0, "llama", torch.float32, 1e-4, 1e-6),
(GemmaRMSNorm, 1.0, "gemma", torch.float32, 1e-4, 1e-6),
pytest.param(
BaseRMSNorm,
0.0,
"none",
torch.float32,
1e-4,
1e-6,
marks=pytest.mark.skipif(device == "npu", reason="Ascend NPU does not support this test"),
),
pytest.param(
LlamaRMSNorm,
0.0,
"llama",
torch.bfloat16,
2e-1,
2e-2,
marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
),
pytest.param(
GemmaRMSNorm,
1.0,
"gemma",
torch.bfloat16,
2e-1,
2e-2,
marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
),
],
)
@pytest.mark.parametrize("in_place", [True, False])
@pytest.mark.parametrize("elementwise_affine", [True, False])
def test_correctness_block_row(
bs, sl, hd, reference, offset, casting_mode, dtype, atol, rtol, in_place, elementwise_affine
):
from liger_kernel.ops import rms_norm as rms_norm_ops

# Guard so this test stays meaningful if the dispatch thresholds are ever retuned.
block_size = triton.next_power_of_2(hd)
assert bs * sl >= rms_norm_ops._BLOCK_ROW_MIN_ROWS, "shape no longer triggers the block-row path"
assert block_size <= rms_norm_ops._BLOCK_ROW_MAX_BLOCK_SIZE, "hidden size no longer uses the block-row path"

_tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype)
h1 = _tensor.clone().requires_grad_(True)
h2 = _tensor.clone().requires_grad_(True)
do = torch.randn(bs, sl, hd, device=device, dtype=dtype)

ref_rms = reference(hidden_size=hd, elementwise_affine=elementwise_affine).to(device).to(dtype)
ref_o = ref_rms(h1)
ref_o.backward(do, retain_graph=True)

triton_rms = (
LigerRMSNorm(
hidden_size=hd,
offset=offset,
casting_mode=casting_mode,
in_place=in_place,
elementwise_affine=elementwise_affine,
)
.to(device)
.to(dtype)
)
triton_o = triton_rms(h2)
triton_o.backward(do, retain_graph=True)

assert_verbose_allclose(ref_o, triton_o, atol=atol, rtol=rtol)
assert_verbose_allclose(h1.grad, h2.grad, atol=atol, rtol=rtol, max_print=20)
if elementwise_affine:
assert_verbose_allclose(ref_rms.weight.grad, triton_rms.weight.grad, atol=atol, rtol=rtol)


@pytest.mark.parametrize(
"bs, sl, hd",
[
Expand Down