From 65ca0d6519494ab1336bcee2642e9ae3d8da2181 Mon Sep 17 00:00:00 2001 From: tubo213 Date: Thu, 5 Feb 2026 19:56:47 +0900 Subject: [PATCH 1/3] Add Context Parallel (CP) support for RMSNorm This enables RMSNorm to work efficiently with DTensor inputs that are sharded on the sequence dimension (Context Parallel), in addition to the existing Tensor Parallel support. Key changes: - Add _is_hidden_dim_sharded() helper to detect TP vs CP sharding - For CP inputs, compute locally without full_tensor() gathering - All-reduce dW gradient in backward for CP to aggregate across devices - Add try-except for Shard import for older PyTorch compatibility - Add tests for Context Parallel DTensor inputs --- src/liger_kernel/ops/rms_norm.py | 97 +++++++++++++++++++++++++++--- test/transformers/test_rms_norm.py | 97 ++++++++++++++++++++++++++++++ 2 files changed, 185 insertions(+), 9 deletions(-) diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py index e5cab72ea..268c77a05 100644 --- a/src/liger_kernel/ops/rms_norm.py +++ b/src/liger_kernel/ops/rms_norm.py @@ -17,6 +17,14 @@ import triton import triton.language as tl +try: + from torch.distributed.tensor import Shard + + _DTENSOR_AVAILABLE = True +except ImportError: + _DTENSOR_AVAILABLE = False + Shard = None + from liger_kernel.ops.utils import calculate_settings from liger_kernel.ops.utils import compare_version from liger_kernel.ops.utils import ensure_contiguous @@ -25,6 +33,30 @@ from liger_kernel.ops.utils import torch_to_triton_dtype from liger_kernel.utils import is_npu_available + +def _is_hidden_dim_sharded(dtensor: "torch.distributed.tensor.DTensor") -> bool: + """ + Check if the DTensor is sharded on the hidden dimension (last dimension). + + This is used to determine whether we need to gather the full tensor for RMSNorm + computation (Tensor Parallel case) or can compute locally (Context Parallel case). + + Args: + dtensor: A DTensor instance to check. + + Returns: + True if the tensor is sharded on the hidden (last) dimension (TP case), + False otherwise (CP case - can compute locally). + """ + if not _DTENSOR_AVAILABLE or Shard is None: + return False + hidden_dim = dtensor.ndim - 1 # Last dimension is the hidden dimension + for placement in dtensor.placements: + if isinstance(placement, Shard) and placement.dim == hidden_dim: + return True + return False + + if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available(): try: # typical import path with dispatch available @@ -609,12 +641,25 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row X: (B, T, H) or (BxT, H) W: (H,) """ + # Track DTensor metadata for potential reconstruction in backward + ctx.is_dtensor_input = False + ctx.dtensor_device_mesh = None + ctx.dtensor_placements = None + if isinstance(X, torch.distributed.tensor.DTensor): - # Input tensor is output of a tensor parallel module and - # needs to be gathered to a local tensor to compute - # RMSE layer norm on each TP worker. - # TODO: support CP. - X = X.full_tensor() + if _is_hidden_dim_sharded(X): + # Tensor Parallel (TP): hidden dimension is sharded across devices. + # RMSNorm requires the full hidden dimension to compute the RMS, + # so we need to gather the full tensor. + X = X.full_tensor() + else: + # Context Parallel (CP): sequence dimension is sharded. + # RMSNorm computes independently for each sequence position, + # so we can compute locally without gathering. + ctx.is_dtensor_input = True + ctx.dtensor_device_mesh = X.device_mesh + ctx.dtensor_placements = X.placements + X = X.to_local() Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode) ctx.offset = offset @@ -628,6 +673,15 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row ctx.save_for_backward(X, W, RSTD) else: ctx.save_for_backward(X, RSTD) + + # If input was a CP DTensor, wrap output back into DTensor + if ctx.is_dtensor_input: + Y = torch.distributed.tensor.DTensor.from_local( + Y, + device_mesh=ctx.dtensor_device_mesh, + placements=ctx.dtensor_placements, + ) + return Y @staticmethod @@ -643,12 +697,37 @@ def backward(ctx, dY): W = None if isinstance(dY, torch.distributed.tensor.DTensor): - # Gradients are output of a tensor parallel module and - # needs to be gathered to a local tensor for computing RMSE layer. - # TODO: support CP. - dY = dY.full_tensor() + if ctx.is_dtensor_input: + # Context Parallel (CP): sequence dimension is sharded. + # We can compute gradients locally for each sequence position. + dY = dY.to_local() + else: + # Tensor Parallel (TP): hidden dimension is sharded. + # Need to gather the full gradient tensor. + dY = dY.full_tensor() dX, dW = rms_norm_backward( dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode ) + + # If input was a CP DTensor, handle output accordingly + if ctx.is_dtensor_input: + # Wrap dX back into DTensor with the same placements + dX = torch.distributed.tensor.DTensor.from_local( + dX, + device_mesh=ctx.dtensor_device_mesh, + placements=ctx.dtensor_placements, + ) + + # For dW, we need to all-reduce across the CP process group + # since each device only computed gradients for its local sequence positions, + # but the weight is shared across all positions. + if dW is not None and _DTENSOR_AVAILABLE and Shard is not None: + for i, placement in enumerate(ctx.dtensor_placements): + if isinstance(placement, Shard): + # Get the process group for this mesh dimension + pg = ctx.dtensor_device_mesh.get_group(mesh_dim=i) + torch.distributed.all_reduce(dW, group=pg) + break + return dX, dW, None, None, None, None, None diff --git a/test/transformers/test_rms_norm.py b/test/transformers/test_rms_norm.py index 8c36472bb..371e589ea 100644 --- a/test/transformers/test_rms_norm.py +++ b/test/transformers/test_rms_norm.py @@ -309,3 +309,100 @@ def test_dtensor_rms_norm(world_size, bs, sl, hd, dtype, atol, rtol, offset, cas nprocs=world_size, join=True, ) + + +def _test_dtensor_rms_norm_context_parallel( + rank, world_size, bs, sl, hd, dtype, atol, rtol, offset, casting_mode, file_name +): + """ + Test RMSNorm with Context Parallel (CP) - sequence dimension sharding. + + Unlike Tensor Parallel (TP) which shards on hidden dimension, CP shards on + sequence dimension. RMSNorm can compute locally for CP since each position + is independent, avoiding the need for full_tensor() gathering. + """ + torch.distributed.init_process_group( + backend=infer_comm_backend(), + init_method=f"file://{file_name}", + rank=rank, + world_size=world_size, + ) + device = f"{infer_device()}:{rank}" if infer_device() != "cpu" else "cpu" + device_mesh = torch.distributed.device_mesh.init_device_mesh( + infer_device(), mesh_shape=(world_size,), mesh_dim_names=("cp",) + ) + + # Create a tensor and shard on sequence dimension (dim=1) for CP + # sl must be divisible by world_size for even sharding + t = torch.randn(bs, sl, hd, device=device, dtype=dtype, requires_grad=True) + dt = torch.distributed.tensor.distribute_tensor( + t, + device_mesh=device_mesh, + placements=[torch.distributed.tensor.Shard(1)], # Shard on sequence dim for CP + ) + + # Weight is replicated across all devices + w = torch.randn(hd, device=device, dtype=dtype, requires_grad=True) + w1 = w.detach().clone().requires_grad_(True) + w2 = w.detach().clone().requires_grad_(True) + + # Forward pass: compare DTensor (CP) result with regular tensor result + y1 = liger_rms_norm(X=dt, W=w1, eps=1e-6, offset=offset, casting_mode=casting_mode) + y2 = liger_rms_norm(X=t, W=w2, eps=1e-6, offset=offset, casting_mode=casting_mode) + + # y1 is a DTensor sharded on sequence dim, y2 is a regular tensor + # Compare the full tensors + torch.testing.assert_close(y1.full_tensor(), y2, atol=atol, rtol=rtol) + + # Backward pass + grad = torch.randn_like(y2) + dgrad = torch.distributed.tensor.distribute_tensor( + grad, + device_mesh=device_mesh, + placements=[torch.distributed.tensor.Shard(1)], # Same sharding as output + ) + + y1.backward(dgrad) + y2.backward(grad) + + # Check weight gradients: should match after all-reduce in backward + torch.testing.assert_close(w1.grad, w2.grad, atol=atol, rtol=rtol) + + # Check input gradients: dt.grad is a DTensor, t.grad is a regular tensor + torch.testing.assert_close(dt.grad.full_tensor(), t.grad, atol=atol, rtol=rtol) + + +@pytest.mark.xfail( + torch.cuda.device_count() < 4, + reason="Pending multi-GPU host support. This test requires at least 4 GPUs.", +) +@pytest.mark.parametrize( + "world_size, bs, sl, hd", + [ + (2, 2, 8, 16), # sl=8 divisible by world_size=2 + (4, 2, 16, 32), # sl=16 divisible by world_size=4 + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-4, 1e-6), + (torch.bfloat16, 2e-1, 2e-2), + ], +) +@pytest.mark.parametrize( + "offset, casting_mode", + [ + (0.0, "llama"), + (1.0, "gemma"), + ], +) +def test_dtensor_rms_norm_context_parallel(world_size, bs, sl, hd, dtype, atol, rtol, offset, casting_mode): + """Test RMSNorm with Context Parallel (sequence dimension sharding).""" + with tempfile.NamedTemporaryFile() as f: + mp.spawn( + _test_dtensor_rms_norm_context_parallel, + args=(world_size, bs, sl, hd, dtype, atol, rtol, offset, casting_mode, f.name), + nprocs=world_size, + join=True, + ) From 833d1e0be7e99c75a2e620db4af00ddfd5b631f9 Mon Sep 17 00:00:00 2001 From: tubo213 Date: Thu, 5 Feb 2026 23:03:55 +0900 Subject: [PATCH 2/3] Improve CP RMSNorm test coverage and cleanup Add non-power-of-2 test dimensions (batch=3, seq=6, hidden=17) to catch edge cases, and add proper process group cleanup with destroy_process_group(). --- test/transformers/test_rms_norm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/transformers/test_rms_norm.py b/test/transformers/test_rms_norm.py index 371e589ea..12a826650 100644 --- a/test/transformers/test_rms_norm.py +++ b/test/transformers/test_rms_norm.py @@ -371,6 +371,8 @@ def _test_dtensor_rms_norm_context_parallel( # Check input gradients: dt.grad is a DTensor, t.grad is a regular tensor torch.testing.assert_close(dt.grad.full_tensor(), t.grad, atol=atol, rtol=rtol) + torch.distributed.destroy_process_group() + @pytest.mark.xfail( torch.cuda.device_count() < 4, @@ -381,6 +383,7 @@ def _test_dtensor_rms_norm_context_parallel( [ (2, 2, 8, 16), # sl=8 divisible by world_size=2 (4, 2, 16, 32), # sl=16 divisible by world_size=4 + (2, 3, 6, 17), # weird shapes: non-power-of-2 batch, seq, hidden dims ], ) @pytest.mark.parametrize( From 0e0c5cd2278dba2d67fe6b9aa3e94ddb83eace99 Mon Sep 17 00:00:00 2001 From: tubo213 Date: Thu, 5 Feb 2026 23:11:40 +0900 Subject: [PATCH 3/3] Fix dW all-reduce to handle multi-dimensional meshes Remove the `break` statement so that dW is all-reduced across all sharded mesh dimensions, not just the first one. This ensures correct weight gradients when using multi-dimensional meshes (e.g., batch + sequence sharding). --- src/liger_kernel/ops/rms_norm.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py index 268c77a05..3ea3c74f7 100644 --- a/src/liger_kernel/ops/rms_norm.py +++ b/src/liger_kernel/ops/rms_norm.py @@ -719,15 +719,14 @@ def backward(ctx, dY): placements=ctx.dtensor_placements, ) - # For dW, we need to all-reduce across the CP process group + # For dW, we need to all-reduce across all sharded mesh dimensions # since each device only computed gradients for its local sequence positions, - # but the weight is shared across all positions. + # but the weight is shared across all positions. For multi-dimensional meshes + # (e.g., batch + sequence sharding), we must reduce across each sharded dim. if dW is not None and _DTENSOR_AVAILABLE and Shard is not None: for i, placement in enumerate(ctx.dtensor_placements): if isinstance(placement, Shard): - # Get the process group for this mesh dimension pg = ctx.dtensor_device_mesh.get_group(mesh_dim=i) torch.distributed.all_reduce(dW, group=pg) - break return dX, dW, None, None, None, None, None