diff --git a/tests/pytorch/test_mxfp8_2d_quantize.py b/tests/pytorch/test_mxfp8_2d_quantize.py new file mode 100644 index 0000000000..0f70b10b50 --- /dev/null +++ b/tests/pytorch/test_mxfp8_2d_quantize.py @@ -0,0 +1,444 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Unit tests for MXFP8 2D block scaling quantization. +MXFP8 2D scaling: 32x32 blocks share a single scaling factor, rowwise and colwise scales are identical. +""" + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import MXFP8Quantizer +from transformer_engine.common.recipe import MXFP8BlockScaling, QParams + + +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) + +# MXFP8 constants +MXFP8_BLOCK_SIZE = 32 +FP8_E4M3_MAX = 448.0 + + +def float_to_e8m0(amax: torch.Tensor) -> torch.Tensor: + """ + Convert absolute maximum values to E8M0 biased exponent (scale inverse). + + This mimics the GPU implementation in ptx::float_to_e8m0: + 1. Compute val = amax / FP8_MAX (same as amax * max_norm_rcp) + 2. Extract the biased exponent from the IEEE754 FP32 representation + 3. Round up if there's any mantissa (ceil behavior) + + E8M0 format: 8-bit unsigned integer representing 2^(value - 127) + """ + # Compute val = amax / FP8_MAX (same as GPU: amax * max_norm_rcp) + val = amax.to(torch.float32) / FP8_E4M3_MAX + + # Reinterpret float32 bits as int32 + val_u32 = val.view(torch.int32) + + # Extract biased exponent (bits 30:23) - GPU does: (val_u32 >> 23) and truncates to uint8 + exponent = ((val_u32 >> 23) & 0xFF).to(torch.int32) + + # Extract mantissa (bits 22:0) + mantissa = val_u32 & 0x7FFFFF + + # Round up condition from GPU: + # if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) + round_up = (mantissa > 0) & (exponent != 254) & ~((exponent == 0) & (mantissa <= 0x400000)) + exponent = exponent + round_up.to(torch.int32) + + # Handle special cases (GPU handles these before the main logic) + # val == 0 -> return 0 + exponent = torch.where(val == 0, torch.zeros_like(exponent), exponent) + + return exponent.to(torch.uint8) + + +def e8m0_to_scale_inv(e8m0: torch.Tensor) -> torch.Tensor: + """Convert E8M0 biased exponent back to scale inverse (float).""" + return torch.pow(2.0, e8m0.to(torch.float32) - 127) + + +def quantize_mxfp8_2d_reference( + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Reference implementation of MXFP8 2D block scaling quantization. + + For 2D scaling, each 32x32 block shares a single E8M0 scale factor. 
+ + Args: + x: Input tensor of shape (M, N), assumes M and N are multiples of 32 + + Returns: + qx_rowwise: Quantized data in row-major order + scale_rowwise: E8M0 scale inverses for rowwise (shape: M x ceil(N/32)) + qx_colwise: Quantized data in column-major order + scale_colwise: E8M0 scale inverses for colwise (shape: ceil(M/32) x N) + """ + M, N = x.shape + device = x.device + dtype = x.dtype + + # Pad to multiples of 32 if needed + pad_M = (MXFP8_BLOCK_SIZE - M % MXFP8_BLOCK_SIZE) % MXFP8_BLOCK_SIZE + pad_N = (MXFP8_BLOCK_SIZE - N % MXFP8_BLOCK_SIZE) % MXFP8_BLOCK_SIZE + if pad_M > 0 or pad_N > 0: + x = torch.nn.functional.pad(x, (0, pad_N, 0, pad_M), mode="constant", value=0.0) + + M_padded, N_padded = x.shape + num_block_rows = M_padded // MXFP8_BLOCK_SIZE + num_block_cols = N_padded // MXFP8_BLOCK_SIZE + + # Reshape to expose 32x32 blocks + x_blocks = x.view(num_block_rows, MXFP8_BLOCK_SIZE, num_block_cols, MXFP8_BLOCK_SIZE).permute( + 0, 2, 1, 3 + ) # (num_block_rows, num_block_cols, 32, 32) + + # Compute amax for each 32x32 block + block_amax = torch.amax( + torch.abs(x_blocks.to(torch.float32)), dim=(-1, -2) + ) # (num_block_rows, num_block_cols) + + # Convert to E8M0 scale inverse + block_scale_e8m0 = float_to_e8m0(block_amax) # (num_block_rows, num_block_cols) + block_scale_inv = e8m0_to_scale_inv(block_scale_e8m0) # (num_block_rows, num_block_cols) + + # Expand scale to match input dimensions for quantization + # For rowwise: each row in a block uses the same scale, scale shape is (M, num_block_cols) + scale_rowwise = block_scale_e8m0.repeat_interleave( + MXFP8_BLOCK_SIZE, dim=0 + ) # (M_padded, num_block_cols) + + # For colwise: each column in a block uses the same scale, scale shape is (num_block_rows, N) + scale_colwise = block_scale_e8m0.repeat_interleave( + MXFP8_BLOCK_SIZE, dim=1 + ) # (num_block_rows, N_padded) + + # Compute scale inverse for quantization (broadcast over 32x32 blocks) + scale_inv_expanded = block_scale_inv.unsqueeze(-1).unsqueeze( + -1 + ) # (num_block_rows, num_block_cols, 1, 1) + scale_inv_expanded = scale_inv_expanded.expand(-1, -1, MXFP8_BLOCK_SIZE, MXFP8_BLOCK_SIZE) + + # Quantize: x_quantized = round(x / scale_inv) clamped to FP8 range + x_blocks_float = x_blocks.to(torch.float32) + x_scaled = x_blocks_float / scale_inv_expanded + + # Convert to FP8 (using PyTorch's float8_e4m3fn) + x_quantized = x_scaled.to(torch.float8_e4m3fn) + + # Reshape back to original layout + # Rowwise: (M_padded, N_padded) + qx_rowwise = x_quantized.permute(0, 2, 1, 3).reshape(M_padded, N_padded) + + # Colwise: same data but transposed for column-major access + qx_colwise = x_quantized.permute(0, 2, 1, 3).reshape(M_padded, N_padded) + + # Remove padding from outputs + qx_rowwise = qx_rowwise[:M, :N] + qx_colwise = qx_colwise[:M, :N] + scale_rowwise = scale_rowwise[:M, :] + scale_colwise = scale_colwise[:, :N] + + return qx_rowwise, scale_rowwise, qx_colwise, scale_colwise + + +def check_mxfp8_2d_quantization_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + use_cpp_allocator: bool, +) -> None: + """ + Test MXFP8 2D quantization against CPU reference implementation. + + Verifies: + 1. scales match reference + 2. 32x32 blocks share the same scale + 3. 
rowwise and colwise quantized data match reference + """ + fp8_dtype = tex.DType.kFloat8E4M3 + + device = "cuda" + seed = 42 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Create input tensor + x = torch.randn((M, N), dtype=x_dtype, device=device) + + # GPU Quantization using MXFP8Quantizer with 2D scaling + quantizer = MXFP8Quantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=True, + with_2d_quantization=True, + ) + + if use_cpp_allocator: + x_mxfp8 = quantizer(x) + else: + x_mxfp8 = quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False) + x_mxfp8 = quantizer.update_quantized(x, x_mxfp8) + + # Extract GPU results + assert x_mxfp8._rowwise_data is not None + assert x_mxfp8._columnwise_data is not None + assert x_mxfp8._rowwise_scale_inv is not None + assert x_mxfp8._columnwise_scale_inv is not None + + gpu_qx_rowwise = x_mxfp8._rowwise_data + gpu_scale_rowwise = x_mxfp8._rowwise_scale_inv + gpu_qx_colwise = x_mxfp8._columnwise_data + gpu_scale_colwise = x_mxfp8._columnwise_scale_inv + + # Reference Quantization + ref_qx_rowwise, ref_scale_rowwise, ref_qx_colwise, ref_scale_colwise = ( + quantize_mxfp8_2d_reference(x) + ) + + num_block_rows = (M + MXFP8_BLOCK_SIZE - 1) // MXFP8_BLOCK_SIZE + num_block_cols = (N + MXFP8_BLOCK_SIZE - 1) // MXFP8_BLOCK_SIZE + + # GPU scales may have padding, compare valid portion + gpu_scale_rowwise_valid = gpu_scale_rowwise[:M, :num_block_cols] + gpu_scale_colwise_valid = gpu_scale_colwise[:num_block_rows, :N] + + # 1. Verify scales match reference + torch.testing.assert_close( + gpu_scale_rowwise_valid, + ref_scale_rowwise, + atol=0, + rtol=0, + ) + + # 2. Verify 32x32 blocks share the same scale + for bi in range(num_block_rows): + for bj in range(num_block_cols): + row_start = bi * MXFP8_BLOCK_SIZE + row_end = min((bi + 1) * MXFP8_BLOCK_SIZE, M) + col_start = bj * MXFP8_BLOCK_SIZE + col_end = min((bj + 1) * MXFP8_BLOCK_SIZE, N) + + # All rows in block should have same scale for this column block + block_rowwise_scales = gpu_scale_rowwise[row_start:row_end, bj] + assert torch.all( + block_rowwise_scales == block_rowwise_scales[0] + ), f"2D mode: Block ({bi},{bj}) rowwise scales should be identical" + + # All columns in block should have same scale for this row block + block_colwise_scales = gpu_scale_colwise[bi, col_start:col_end] + assert torch.all( + block_colwise_scales == block_colwise_scales[0] + ), f"2D mode: Block ({bi},{bj}) colwise scales should be identical" + + # Rowwise and colwise scales should match + assert block_rowwise_scales[0] == block_colwise_scales[0], ( + f"2D mode: Block ({bi},{bj}) rowwise and colwise scales should be equal, " + f"got rowwise={block_rowwise_scales[0]}, colwise={block_colwise_scales[0]}" + ) + + # 3. 
Verify rowwise and colwise quantized data match reference + # Convert FP8 tensors to uint8 for bitwise comparison + gpu_qx_rowwise_uint8 = gpu_qx_rowwise.view(torch.uint8)[:M, :N] + gpu_qx_colwise_uint8 = gpu_qx_colwise.view(torch.uint8)[:M, :N] + ref_qx_rowwise_uint8 = ref_qx_rowwise.view(torch.uint8) + + torch.testing.assert_close( + gpu_qx_rowwise_uint8, + ref_qx_rowwise_uint8, + atol=0, + rtol=0, + ) + + torch.testing.assert_close( + gpu_qx_colwise_uint8, + ref_qx_rowwise_uint8, + atol=0, + rtol=0, + ) + + +@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) +@pytest.mark.parametrize( + "M, N", + [ + # Full tile cases (multiples of 32) + (64, 64), + (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 288), + (320, 320), + (352, 256), + # Larger sizes + (2048, 2048), + (1024, 2048), + (2048, 1024), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +def test_mxfp8_2d_quantization_versus_reference( + M: int, + N: int, + x_dtype: torch.dtype, + use_cpp_allocator: bool, +) -> None: + """Test MXFP8 2D quantization against reference implementation.""" + check_mxfp8_2d_quantization_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + use_cpp_allocator=use_cpp_allocator, + ) + + +# ============================================================================ +# Recipe Configuration Tests +# ============================================================================ + + +class TestMXFP8BlockScalingRecipe: + """Tests for MXFP8BlockScaling recipe configuration.""" + + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_default_recipe_has_qparams(self): + """Test that default MXFP8BlockScaling has QParams attributes.""" + mxfp8_recipe = MXFP8BlockScaling() + + # Verify QParams attributes exist + assert hasattr(mxfp8_recipe, "fp8_quant_fwd_inp") + assert hasattr(mxfp8_recipe, "fp8_quant_fwd_weight") + assert hasattr(mxfp8_recipe, "fp8_quant_bwd_grad") + + # Verify they are QParams instances + assert isinstance(mxfp8_recipe.fp8_quant_fwd_inp, QParams) + assert isinstance(mxfp8_recipe.fp8_quant_fwd_weight, QParams) + assert isinstance(mxfp8_recipe.fp8_quant_bwd_grad, QParams) + + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_default_2d_quantization_disabled(self): + """Test that 2D quantization is disabled by default.""" + mxfp8_recipe = MXFP8BlockScaling() + + # By default, 2D quantization should be disabled + assert mxfp8_recipe.enable_2d_quantization is False + + # QParams should reflect this + assert mxfp8_recipe.fp8_quant_fwd_inp.mxfp8_2d_quantization is False + assert mxfp8_recipe.fp8_quant_fwd_weight.mxfp8_2d_quantization is False + assert mxfp8_recipe.fp8_quant_bwd_grad.mxfp8_2d_quantization is False + + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_2d_quantization_enabled_only_for_weight(self): + """Test that when 2D quantization is enabled, it only applies to weight.""" + # Create recipe with 2D quantization enabled + mxfp8_recipe = MXFP8BlockScaling(enable_2d_quantization=True) + + # enable_2d_quantization should be True + assert mxfp8_recipe.enable_2d_quantization is True + + # Only weight should have 2D quantization enabled + assert mxfp8_recipe.fp8_quant_fwd_inp.mxfp8_2d_quantization is False + assert mxfp8_recipe.fp8_quant_fwd_weight.mxfp8_2d_quantization is True + assert 
mxfp8_recipe.fp8_quant_bwd_grad.mxfp8_2d_quantization is False + + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_qparams_default_values(self): + """Test that QParams have correct default values for MXFP8.""" + mxfp8_recipe = MXFP8BlockScaling() + + # Check default values for all QParams + for qparams in [ + mxfp8_recipe.fp8_quant_fwd_inp, + mxfp8_recipe.fp8_quant_fwd_weight, + mxfp8_recipe.fp8_quant_bwd_grad, + ]: + # These should use defaults for MXFP8 + assert qparams.power_2_scale is False # MXFP8 uses E8M0, inherently power of 2 + assert qparams.amax_epsilon == 0.0 + assert qparams.random_hadamard_transform is False + assert qparams.stochastic_rounding is False + assert qparams.fp4_2d_quantization is False # Not applicable to MXFP8 + assert qparams.mxfp8_2d_quantization is False # Default is False + + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_recipe_repr_includes_2d_quantization(self): + """Test that recipe __repr__ includes 2D quantization status.""" + mxfp8_recipe_disabled = MXFP8BlockScaling(enable_2d_quantization=False) + mxfp8_recipe_enabled = MXFP8BlockScaling(enable_2d_quantization=True) + + repr_disabled = repr(mxfp8_recipe_disabled) + repr_enabled = repr(mxfp8_recipe_enabled) + + assert "enable_2d_quantization=False" in repr_disabled + assert "enable_2d_quantization=True" in repr_enabled + + +@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) +def test_mxfp8_quantizer_respects_2d_flag(): + """Test that MXFP8Quantizer correctly uses the 2D quantization flag from recipe.""" + # Test with 2D disabled + quantizer_1d = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + with_2d_quantization=False, + ) + assert quantizer_1d.with_2d_quantization is False + + # Test with 2D enabled + quantizer_2d = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + with_2d_quantization=True, + ) + assert quantizer_2d.with_2d_quantization is True + + +@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) +def test_mxfp8_recipe_state_creates_correct_quantizers(): + """Test that MXFP8BlockScalingRecipeState creates quantizers with correct 2D settings.""" + from transformer_engine.pytorch.quantization import MXFP8BlockScalingRecipeState + + # Test with 2D disabled + recipe_1d = MXFP8BlockScaling(enable_2d_quantization=False) + state_fwd_1d = MXFP8BlockScalingRecipeState( + recipe=recipe_1d, + mode="forward", + num_quantizers=3, # input, weight, output + ) + quantizers_1d = state_fwd_1d.make_quantizers() + + # All quantizers should have 2D disabled + for idx, q in enumerate(quantizers_1d): + assert q.with_2d_quantization is False, f"Quantizer {idx} should have 2D disabled" + + # Test with 2D enabled + recipe_2d = MXFP8BlockScaling(enable_2d_quantization=True) + state_fwd_2d = MXFP8BlockScalingRecipeState( + recipe=recipe_2d, + mode="forward", + num_quantizers=3, + ) + quantizers_2d = state_fwd_2d.make_quantizers() + + # Only weight (idx % 3 == 1) should have 2D enabled + for idx, q in enumerate(quantizers_2d): + if idx % 3 == 1: # weight + assert q.with_2d_quantization is True, f"Weight quantizer {idx} should have 2D enabled" + else: # input or output + assert ( + q.with_2d_quantization is False + ), f"Non-weight quantizer {idx} should have 2D disabled" diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index a02e7f4f07..f1e966f9e0 100644 --- 
a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -85,7 +85,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, Tensor *dummy_workspace_tensor = nullptr; mxfp8::quantize( *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - dummy_workspace_tensor, stream); + dummy_workspace_tensor, quant_config_cpp.mxfp8_2d_quantization, stream); break; } case NVTE_NVFP4_1D_SCALING: { @@ -223,7 +223,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens case NVTE_MXFP8_1D_SCALING: { mxfp8::quantize( *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); + quant_config_cpp.mxfp8_2d_quantization, stream); break; } case NVTE_NVFP4_1D_SCALING: { diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 70a68132ad..da9e5ffd75 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -45,7 +45,7 @@ constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 template + size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK, bool kIs2DBlockScaling> __global__ void __launch_bounds__(THREADS_PER_CHUNK) quantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_act_input, @@ -163,6 +163,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma nv_diag_suppress static_var_with_dynamic_init __shared__ alignas(8) uint64_t mbar[STAGES]; + // Shared memory to pass 2D block scales from colwise to rowwise pass + // THREADS_X = number of 32x32 blocks in X direction + __shared__ e8m0_t block_scales_2d[THREADS_X]; + initialize_barriers(mbar, is_master_thread); int parity = 0; @@ -264,6 +268,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } + if constexpr (kIs2DBlockScaling) { +#pragma unroll + for (int i = 16; i > 0; i /= 2) { + thread_amax = fmaxf(thread_amax, __shfl_xor_sync(0xffffffff, thread_amax, i)); + } + } + // 2. 
Compute E8M0 scaling factor const e8m0_t biased_exponent = ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); @@ -278,6 +289,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } scales_colwise[scale_idx] = biased_exponent; + // In 2D mode, save scale to shared memory for rowwise pass + // Each warp (processing one 32x32 block) writes one scale via lane 0 + if constexpr (kIs2DBlockScaling && ROWWISE_SCALING) { + if (thread_lane == 0) { + block_scales_2d[threadIdx.x / THREADS_PER_WARP] = biased_exponent; + } + } + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; @@ -300,7 +319,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (ROWWISE_SCALING) { const size_t shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; - thread_amax = 0.0f; + if constexpr (!kIs2DBlockScaling) { + thread_amax = 0.0f; + } float in_compute_rowwise[SCALE_DIM_X]; Vec in_cached[WAVES]; @@ -317,13 +338,17 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; // Load elements in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (!kIs2DBlockScaling) { #pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } } } - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + if constexpr (!kIs2DBlockScaling) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } } else if constexpr (IS_CACHED_ACT_OP) { // ensures that all writes to cache made in the section above are visible to all threads __syncthreads(); @@ -342,25 +367,29 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries - if (!out_of_bounds) { - if constexpr (std::is_same_v) { + if constexpr (!kIs2DBlockScaling) { + if (!out_of_bounds) { + if constexpr (std::is_same_v) { #pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { #pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], - in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } } } } } - if constexpr (!std::is_same_v) { - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + if constexpr (!kIs2DBlockScaling) { + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } } } else { #pragma unroll @@ -397,17 +426,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (!std::is_same_v) { elt = static_cast(static_cast(elt)); } - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = - (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - if (!out_of_bounds) { + if constexpr (!kIs2DBlockScaling) { + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this thread_amax = fmaxf(thread_amax, fabsf(elt)); } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); } in_compute_rowwise[j] = elt; } @@ -415,8 +447,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } // 2. 
Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + e8m0_t biased_exponent; + if constexpr (kIs2DBlockScaling && COLWISE_SCALING) { + // In 2D mode with both scaling directions, use scale from colwise pass + // Sync to ensure colwise writes to block_scales_2d are visible across warps + __syncthreads(); + e8m0_t scale_from_shmem; + if (thread_lane < THREADS_X) { + scale_from_shmem = block_scales_2d[thread_lane]; + } + // Broadcast: each thread gets scale from lane matching its tid_X_rowwise + biased_exponent = __shfl_sync(0xffffffff, scale_from_shmem, tid_X_rowwise); + } else { + biased_exponent = ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + } const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; const int stage_scales_offset_X = scales_offset_X_rowwise; size_t scale_idx; @@ -556,7 +600,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) template void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, // TODO (ksivamani) - Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + Tensor *output, Tensor *dbias, Tensor *workspace, const bool use_2d_quantization, + cudaStream_t stream) { using namespace quantize_kernel; checkCuDriverContext(stream); @@ -642,7 +687,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, if (specialized::hasSpec() && - !WITH_GEMM_SWIZZLED_SCALES) { + !WITH_GEMM_SWIZZLED_SCALES && !use_2d_quantization) { switch (scaling_type) { case ScalingType::ROWWISE: { using traits = specialized::CastTraits; @@ -774,11 +819,14 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, } } + if (use_2d_quantization) { scaling_type = ScalingType::BIDIMENSIONAL; } + switch (scaling_type) { case ScalingType::ROWWISE: { - auto kernel = quantize_mxfp8_kernel; + auto kernel = + quantize_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); @@ -791,9 +839,10 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::COLWISE: { - auto kernel = quantize_mxfp8_kernel; + auto kernel = + quantize_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); @@ -806,18 +855,23 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } case ScalingType::BIDIMENSIONAL: { - auto kernel = quantize_mxfp8_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - kernel<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_2d_quantization, kIs2DBlockScaling, + + auto kernel = + quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, + noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError());); 
break; } } diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 970b7aef6c..59d9072f8b 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -411,6 +411,7 @@ struct QuantizationConfig { bool nvfp4_2d_quantization = false; bool stochastic_rounding = false; bool use_fast_math = false; + bool mxfp8_2d_quantization = false; static constexpr size_t attr_sizes[] = { sizeof(uint8_t), // force_pow_2_scales @@ -420,7 +421,8 @@ struct QuantizationConfig { sizeof(NVTETensor), // rng_seed and offset sizeof(uint8_t), // nvfp4_2d_quantization sizeof(uint8_t), // stochastic_rounding - sizeof(uint8_t) // use_fast_math + sizeof(uint8_t), // use_fast_math + sizeof(uint8_t) // mxfp8_2d_quantization }; }; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index ae41f238a4..fc9a8959b3 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -370,6 +370,8 @@ enum NVTEQuantizationConfigAttribute { * inconsistently between kernels. */ kNVTEQuantizationConfigUseFastMath = 7, + /*! Whether to use 2D block scaling for MXFP8 */ + kNVTEQuantizationConfigMXFP82DQuantization = 8, kNVTEQuantizationConfigNumAttributes }; @@ -1046,6 +1048,13 @@ class QuantizationConfigWrapper { sizeof(val)); } + /*! \brief Set whether to use 2D block scaling for MXFP8 */ + void set_mxfp8_2d_quantization(bool mxfp8_2d_quantization) { + const auto val = static_cast(mxfp8_2d_quantization); + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigMXFP82DQuantization, + &val, sizeof(val)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 64ee2a5a16..fee3f2c81e 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -65,6 +65,8 @@ class QParams: amax_epsilon: optional minimum value of abs max random_hadamard_transform: whether to use random hadamard transform stochastic_rounding: whether to use stocastic rounding + fp4_2d_quantization: whether to use 2D block scaling for NVFP4 + mxfp8_2d_quantization: whether to use 2D block scaling for MXFP8 """ power_2_scale: bool = False @@ -72,6 +74,7 @@ class QParams: random_hadamard_transform: bool = False stochastic_rounding: bool = False fp4_2d_quantization: bool = False + mxfp8_2d_quantization: bool = False def __repr__(self) -> str: return ( @@ -79,7 +82,8 @@ def __repr__(self) -> str: f"amax_epsilon={self.amax_epsilon},\n" f"random_hadamard_transform={self.random_hadamard_transform},\n" f"stochastic_rounding={self.stochastic_rounding},\n" - f"fp4_2d_quantization={self.fp4_2d_quantization}\n)" + f"fp4_2d_quantization={self.fp4_2d_quantization},\n" + f"mxfp8_2d_quantization={self.mxfp8_2d_quantization}\n)" ) @@ -284,8 +288,13 @@ class MXFP8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. + enable_2d_quantization : bool, default = False + If set to `True`, 2D block scaling is used for weight tensors. 
""" + # Configuration envvars + enable_2d_quantization: bool = os.getenv("NVTE_MXFP8_ENABLE_2D_QUANTIZATION", "0") == "1" + margin: int = 0 fp8_format: Format = Format.E4M3 fp8_dpa: bool = False @@ -294,11 +303,17 @@ class MXFP8BlockScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + # Quantization params (same pattern as NVFP4BlockScaling) + self.fp8_quant_fwd_inp = QParams(mxfp8_2d_quantization=False) + self.fp8_quant_fwd_weight = QParams(mxfp8_2d_quantization=self.enable_2d_quantization) + self.fp8_quant_bwd_grad = QParams(mxfp8_2d_quantization=False) + def __repr__(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " - f"format={str(self.fp8_format).split('.')[1]}" + f"format={str(self.fp8_format).split('.')[1]}, " + f"enable_2d_quantization={self.enable_2d_quantization}" ) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 06971443dd..f3ea3d1e4f 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1059,6 +1059,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: bool_to_uint8(config_.use_fast_math, buf); break; + case kNVTEQuantizationConfigMXFP82DQuantization: + bool_to_uint8(config_.mxfp8_2d_quantization, buf); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -1114,6 +1117,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: uint8_to_bool(buf, config_.use_fast_math); break; + case kNVTEQuantizationConfigMXFP82DQuantization: + uint8_to_bool(buf, config_.mxfp8_2d_quantization); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index bc22e03097..cee08ba96f 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -264,6 +264,8 @@ class Float8BlockQuantizer : public Quantizer { class MXFP8Quantizer : public Quantizer { public: DType dtype; + // 2D block scaling + bool with_2d_quantization; explicit MXFP8Quantizer(const py::handle& quantizer); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 1c968e276d..b537625875 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -853,6 +853,7 @@ std::vector Float8BlockQuantizer::get_scale_shape(const std::vectordtype = quantizer.attr("dtype").cast(); + this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); } void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} @@ -1059,6 +1060,7 @@ void MXFP8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, if (noop_flag) { quant_config.set_noop_tensor(noop_flag->data()); } + quant_config.set_mxfp8_2d_quantization(this->with_2d_quantization); NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); }); diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index eba547afb0..6d63901628 100644 --- a/transformer_engine/pytorch/quantization.py +++ 
b/transformer_engine/pytorch/quantization.py @@ -1169,7 +1169,35 @@ def make_quantizers(self) -> list: # TODO(ksivamani); Find better design for this, adding here to avoid circular import. from .tensor.mxfp8_tensor import MXFP8Quantizer - return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] + if self.mode == "forward": + + def _make_quantizer(idx: int) -> MXFP8Quantizer: + qparams = ( + self.recipe.fp8_quant_fwd_weight + if idx % 3 == 1 + else self.recipe.fp8_quant_fwd_inp + ) + return MXFP8Quantizer( + fp8_dtype=self.dtype, + rowwise=True, + columnwise=True, + with_2d_quantization=qparams.mxfp8_2d_quantization, + ) + + return [_make_quantizer(idx) for idx in range(self.num_quantizers)] + + if self.mode == "backward": + return [ + MXFP8Quantizer( + fp8_dtype=self.dtype, + rowwise=True, + columnwise=True, + with_2d_quantization=self.recipe.fp8_quant_bwd_grad.mxfp8_2d_quantization, + ) + for _ in range(self.num_quantizers) + ] + + raise ValueError(f"Unknown mode: {self.mode}") class Float8BlockScalingRecipeState(RecipeState): diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 8dd2255d89..9000592b3c 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -35,15 +35,20 @@ class MXFP8Quantizer(Quantizer): dtype: TE_DType + """2D block scaling, only applicable for weights.""" + with_2d_quantization: bool + def __init__( self, fp8_dtype: TE_DType, *, rowwise: bool = True, columnwise: bool = True, + with_2d_quantization: bool = False, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) self.dtype = fp8_dtype + self.with_2d_quantization = with_2d_quantization def copy(self) -> MXFP8Quantizer: """Create shallow copy""" @@ -52,6 +57,7 @@ def copy(self) -> MXFP8Quantizer: fp8_dtype=self.dtype, rowwise=self.rowwise_usage, columnwise=self.columnwise_usage, + with_2d_quantization=self.with_2d_quantization, ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm
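
Usage note (not part of the patch): a minimal sketch of how the 2D path added above is exercised, assembled from the unit test and recipe changes in this diff. The MXFP8Quantizer call, the enable_2d_quantization keyword, and the NVTE_MXFP8_ENABLE_2D_QUANTIZATION environment variable all come from the patch; the te.Linear / te.fp8_autocast wiring at the end is the usual Transformer Engine pattern and is shown here as an assumption, not something this diff touches.

    import torch
    import transformer_engine.pytorch as te
    import transformer_engine_torch as tex
    from transformer_engine.pytorch import MXFP8Quantizer
    from transformer_engine.common.recipe import MXFP8BlockScaling

    # Requires MXFP8-capable hardware; the tests guard on te.is_mxfp8_available().

    # Direct quantizer use (as in test_mxfp8_2d_quantize.py): each 32x32 block
    # shares one E8M0 scale, so rowwise and colwise scale factors are identical.
    quantizer = MXFP8Quantizer(
        fp8_dtype=tex.DType.kFloat8E4M3,
        rowwise=True,
        columnwise=True,
        with_2d_quantization=True,
    )
    x = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")
    x_mxfp8 = quantizer(x)  # populates _rowwise_data / _columnwise_data and their scale_invs

    # Recipe-level opt-in: only the weight quantizers pick up the 2D flag;
    # input and grad quantizers keep 1D (32x1) scaling.
    recipe = MXFP8BlockScaling(enable_2d_quantization=True)
    # Equivalent default via environment: NVTE_MXFP8_ENABLE_2D_QUANTIZATION=1

    # Assumed standard TE usage (outside the scope of this diff):
    linear = te.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = linear(torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda"))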