444 changes: 444 additions & 0 deletions tests/pytorch/test_mxfp8_2d_quantize.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions transformer_engine/common/cast/dispatch/quantize.cuh
@@ -85,7 +85,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
Tensor *dummy_workspace_tensor = nullptr;
mxfp8::quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
*input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor,
dummy_workspace_tensor, stream);
dummy_workspace_tensor, quant_config_cpp.mxfp8_2d_quantization, stream);
break;
}
case NVTE_NVFP4_1D_SCALING: {
@@ -223,7 +223,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
case NVTE_MXFP8_1D_SCALING: {
mxfp8::quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
*grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
stream);
quant_config_cpp.mxfp8_2d_quantization, stream);
break;
}
case NVTE_NVFP4_1D_SCALING: {
154 changes: 104 additions & 50 deletions transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh
@@ -45,7 +45,7 @@ constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING,
bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES, size_t CHUNK_DIM_Y,
size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK>
size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK, bool kIs2DBlockScaling>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
quantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_act_input,
@@ -163,6 +163,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];

// Shared memory to pass 2D block scales from colwise to rowwise pass
// THREADS_X = number of 32x32 blocks in X direction
__shared__ e8m0_t block_scales_2d[THREADS_X];

initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);

int parity = 0;
@@ -264,6 +268,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
}
}

if constexpr (kIs2DBlockScaling) {
#pragma unroll
for (int i = 16; i > 0; i /= 2) {
thread_amax = fmaxf(thread_amax, __shfl_xor_sync(0xffffffff, thread_amax, i));
}
}

// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
@@ -278,6 +289,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
}
scales_colwise[scale_idx] = biased_exponent;

// In 2D mode, save scale to shared memory for rowwise pass
// Each warp (processing one 32x32 block) writes one scale via lane 0
if constexpr (kIs2DBlockScaling && ROWWISE_SCALING) {
if (thread_lane == 0) {
block_scales_2d[threadIdx.x / THREADS_PER_WARP] = biased_exponent;
}
}

const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};

@@ -300,7 +319,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (ROWWISE_SCALING) {
const size_t shmem_offset_base_rowwise =
buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
thread_amax = 0.0f;
if constexpr (!kIs2DBlockScaling) {
thread_amax = 0.0f;
}
float in_compute_rowwise[SCALE_DIM_X];
Vec<IType, PACK_SIZE> in_cached[WAVES];

@@ -317,13 +338,17 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
if constexpr (!kIs2DBlockScaling) {
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
}
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
if constexpr (!kIs2DBlockScaling) {
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
@@ -342,25 +367,29 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
// Since the TMA data-alignment requirement is 16B (i.e. cols % 8 == 0 for BF16 elements),
// a single check in the column direction is sufficient to ensure the entire wave is inside the boundaries
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
if constexpr (!kIs2DBlockScaling) {
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
for (int e = 0; e < PACK_SIZE; ++e) {
thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
if constexpr (!kIs2DBlockScaling) {
if constexpr (!std::is_same_v<IType, float>) {
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
}
} else {
#pragma unroll
@@ -397,26 +426,41 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
if constexpr (!kIs2DBlockScaling) {
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds =
(row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
} else {
// If no activation, elt is 0 so we can safely do this
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
} else {
// If no activation, elt is 0 so we can safely do this
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
in_compute_rowwise[j] = elt;
}
}
}

// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
e8m0_t biased_exponent;
if constexpr (kIs2DBlockScaling && COLWISE_SCALING) {
// In 2D mode with both scaling directions, use scale from colwise pass
// Sync to ensure colwise writes to block_scales_2d are visible across warps
__syncthreads();
e8m0_t scale_from_shmem;
if (thread_lane < THREADS_X) {
scale_from_shmem = block_scales_2d[thread_lane];
}
// Broadcast: each thread gets scale from lane matching its tid_X_rowwise
biased_exponent = __shfl_sync(0xffffffff, scale_from_shmem, tid_X_rowwise);
Review comment on lines +455 to +460 (Contributor):

scale_from_shmem is potentially uninitialized for threads where thread_lane >= THREADS_X. While __shfl_sync only reads from lanes specified by tid_X_rowwise (which should be < THREADS_X), it's safer to initialize this variable.

Suggested change:
-  e8m0_t scale_from_shmem;
+  e8m0_t scale_from_shmem = 0;
   if (thread_lane < THREADS_X) {
     scale_from_shmem = block_scales_2d[thread_lane];
   }
   // Broadcast: each thread gets scale from lane matching its tid_X_rowwise
   biased_exponent = __shfl_sync(0xffffffff, scale_from_shmem, tid_X_rowwise);

} else {
biased_exponent = ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
}
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
size_t scale_idx;
@@ -556,7 +600,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)>
void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, // TODO (ksivamani)
Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
Tensor *output, Tensor *dbias, Tensor *workspace, const bool use_2d_quantization,
cudaStream_t stream) {
using namespace quantize_kernel;
checkCuDriverContext(stream);

@@ -642,7 +687,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,
with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,

if (specialized::hasSpec<IS_DBIAS, IS_DACT, IS_ACT, IType, OType>() &&
!WITH_GEMM_SWIZZLED_SCALES) {
!WITH_GEMM_SWIZZLED_SCALES && !use_2d_quantization) {
switch (scaling_type) {
case ScalingType::ROWWISE: {
using traits = specialized::CastTraits<IType, OType, true, false>;
@@ -774,11 +819,14 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,
}
}

if (use_2d_quantization) { scaling_type = ScalingType::BIDIMENSIONAL; }

switch (scaling_type) {
case ScalingType::ROWWISE: {
auto kernel = quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, false, WITH_GEMM_SWIZZLED_SCALES,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>;
auto kernel =
quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
true, false, WITH_GEMM_SWIZZLED_SCALES, CHUNK_DIM_Y,
CHUNK_DIM_X, THREADS_PER_CHUNK, false>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));

@@ -791,9 +839,10 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,
break;
}
case ScalingType::COLWISE: {
auto kernel = quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, false, true, WITH_GEMM_SWIZZLED_SCALES,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>;
auto kernel =
quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
false, true, WITH_GEMM_SWIZZLED_SCALES, CHUNK_DIM_Y,
CHUNK_DIM_X, THREADS_PER_CHUNK, false>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));

@@ -806,18 +855,23 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,
break;
}
case ScalingType::BIDIMENSIONAL: {
auto kernel = quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, true, WITH_GEMM_SWIZZLED_SCALES,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));

kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_2d_quantization, kIs2DBlockScaling,

auto kernel =
quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, true, WITH_GEMM_SWIZZLED_SCALES,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK,
kIs2DBlockScaling>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));

kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr,
noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError()););
break;
}
}
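
For orientation: in 2D mode the colwise pass reduces each thread's amax across the warp so that every 32x32 block yields a single amax, converts it to an E8M0 scale, and stashes that scale in block_scales_2d shared memory; the rowwise pass then reuses the same scale instead of deriving one from a 1x32 row segment. Below is a minimal NumPy sketch of those numerics for reference only — the helper names, the rounding inside e8m0_scale, and the omission of the final FP8-E4M3 rounding are assumptions for illustration, not the kernel's actual code path.

import numpy as np

E4M3_MAX = 448.0  # max normal value of FP8 E4M3

def e8m0_scale(amax: float) -> float:
    # Power-of-two scale derived from a block amax (sketch of float_to_e8m0 semantics;
    # the device code's rounding/bias handling may differ).
    if amax == 0.0:
        return 1.0
    return float(2.0 ** np.ceil(np.log2(amax / E4M3_MAX)))

def quantize_mxfp8_2d(x: np.ndarray, block: int = 32):
    # One shared E8M0 scale per block x block tile, reused by both the
    # rowwise and colwise outputs (which therefore agree up to layout).
    rows, cols = x.shape
    assert rows % block == 0 and cols % block == 0
    scales = np.empty((rows // block, cols // block), dtype=np.float32)
    q = np.empty_like(x, dtype=np.float32)
    for bi in range(rows // block):
        for bj in range(cols // block):
            tile = x[bi * block:(bi + 1) * block, bj * block:(bj + 1) * block]
            s = e8m0_scale(float(np.abs(tile).max()))
            scales[bi, bj] = s
            # The real kernel would additionally round tile / s to FP8 E4M3 here.
            q[bi * block:(bi + 1) * block, bj * block:(bj + 1) * block] = np.clip(
                tile / s, -E4M3_MAX, E4M3_MAX)
    return q, scales
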
4 changes: 3 additions & 1 deletion transformer_engine/common/common.h
@@ -411,6 +411,7 @@ struct QuantizationConfig {
bool nvfp4_2d_quantization = false;
bool stochastic_rounding = false;
bool use_fast_math = false;
bool mxfp8_2d_quantization = false;

static constexpr size_t attr_sizes[] = {
sizeof(uint8_t), // force_pow_2_scales
@@ -420,7 +421,8 @@ struct QuantizationConfig {
sizeof(NVTETensor), // rng_seed and offset
sizeof(uint8_t), // nvfp4_2d_quantization
sizeof(uint8_t), // stochastic_rounding
sizeof(uint8_t) // use_fast_math
sizeof(uint8_t), // use_fast_math
sizeof(uint8_t) // mxfp8_2d_quantization
};
};

transformer_engine/common/include/transformer_engine/transformer_engine.h
@@ -370,6 +370,8 @@ enum NVTEQuantizationConfigAttribute {
* inconsistently between kernels.
*/
kNVTEQuantizationConfigUseFastMath = 7,
/*! Whether to use 2D block scaling for MXFP8 */
kNVTEQuantizationConfigMXFP82DQuantization = 8,
kNVTEQuantizationConfigNumAttributes
};

Expand Down Expand Up @@ -1046,6 +1048,13 @@ class QuantizationConfigWrapper {
sizeof(val));
}

/*! \brief Set whether to use 2D block scaling for MXFP8 */
void set_mxfp8_2d_quantization(bool mxfp8_2d_quantization) {
const auto val = static_cast<uint8_t>(mxfp8_2d_quantization);
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigMXFP82DQuantization,
&val, sizeof(val));
}

private:
/*! \brief Wrapped NVTEQuantizationConfig. */
NVTEQuantizationConfig config_ = nullptr;
19 changes: 17 additions & 2 deletions transformer_engine/common/recipe/__init__.py
@@ -65,21 +65,25 @@ class QParams:
amax_epsilon: optional minimum value of abs max
random_hadamard_transform: whether to use random hadamard transform
stochastic_rounding: whether to use stochastic rounding
fp4_2d_quantization: whether to use 2D block scaling for NVFP4
mxfp8_2d_quantization: whether to use 2D block scaling for MXFP8
"""

power_2_scale: bool = False
amax_epsilon: float = 0.0
random_hadamard_transform: bool = False
stochastic_rounding: bool = False
fp4_2d_quantization: bool = False
mxfp8_2d_quantization: bool = False

def __repr__(self) -> str:
return (
f"Qparams(\npower_2_scale={self.power_2_scale},\n"
f"amax_epsilon={self.amax_epsilon},\n"
f"random_hadamard_transform={self.random_hadamard_transform},\n"
f"stochastic_rounding={self.stochastic_rounding},\n"
f"fp4_2d_quantization={self.fp4_2d_quantization}\n)"
f"fp4_2d_quantization={self.fp4_2d_quantization},\n"
f"mxfp8_2d_quantization={self.mxfp8_2d_quantization}\n)"
)


@@ -284,8 +288,13 @@ class MXFP8BlockScaling(Recipe):
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3
Controls the FP8 data format used during forward and backward
pass.
enable_2d_quantization : bool, default = False
If set to `True`, 2D block scaling is used for weight tensors.
"""

# Configuration envvars
enable_2d_quantization: bool = os.getenv("NVTE_MXFP8_ENABLE_2D_QUANTIZATION", "0") == "1"

margin: int = 0
fp8_format: Format = Format.E4M3
fp8_dpa: bool = False
@@ -294,11 +303,17 @@ class MXFP8BlockScaling(Recipe):
def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."

# Quantization params (same pattern as NVFP4BlockScaling)
self.fp8_quant_fwd_inp = QParams(mxfp8_2d_quantization=False)
self.fp8_quant_fwd_weight = QParams(mxfp8_2d_quantization=self.enable_2d_quantization)
self.fp8_quant_bwd_grad = QParams(mxfp8_2d_quantization=False)

def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"margin={self.margin}, "
f"format={str(self.fp8_format).split('.')[1]}"
f"format={str(self.fp8_format).split('.')[1]}, "
f"enable_2d_quantization={self.enable_2d_quantization}"
)


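
As a usage sketch (assuming the NVTE_MXFP8_ENABLE_2D_QUANTIZATION switch added above and the existing fp8_autocast / Linear APIs; MXFP8 also requires hardware support), enabling 2D weight quantization from Python might look like the following. The environment variable must be set before transformer_engine is imported, since enable_2d_quantization is read at class-definition time.

import os
os.environ["NVTE_MXFP8_ENABLE_2D_QUANTIZATION"] = "1"  # must precede the imports below

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling

recipe = MXFP8BlockScaling()  # weights get QParams(mxfp8_2d_quantization=True)

layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)
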