Conversation
Greptile Overview

Greptile Summary

This PR adds 2D block scaling support for MXFP8 quantization, where each 32x32 block of elements shares a single scaling factor (compared to 1D scaling, where each row/column has its own scale).

Key Changes:

Confidence Score: 4/5
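To make the scheme concrete, here is a minimal host-side sketch of 2D block scaling, assuming row-major fp32 input with dimensions divisible by 32; the names (`compute_block_scales_2d`, `e8m0_from_amax`) and the exact E8M0 rounding are illustrative, not this PR's implementation:

```cpp
// Hedged sketch, not the PR's code: one E8M0 scale per 32x32 block.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

constexpr int kBlock = 32;
constexpr float kFP8E4M3Max = 448.0f;  // largest representable E4M3 value

// Biased power-of-two exponent chosen so block_amax maps near the FP8 max.
uint8_t e8m0_from_amax(float amax) {
  if (amax == 0.0f) return 0;  // degenerate all-zero block
  int exp = static_cast<int>(std::floor(std::log2(amax / kFP8E4M3Max)));
  return static_cast<uint8_t>(std::clamp(exp + 127, 0, 254));
}

// One scale per 32x32 tile; rows and cols assumed multiples of 32 here.
std::vector<uint8_t> compute_block_scales_2d(const float* x, int rows, int cols) {
  std::vector<uint8_t> scales((rows / kBlock) * (cols / kBlock));
  for (int bi = 0; bi < rows / kBlock; ++bi) {
    for (int bj = 0; bj < cols / kBlock; ++bj) {
      float amax = 0.0f;  // shared by all 32*32 = 1024 elements of the tile
      for (int i = 0; i < kBlock; ++i)
        for (int j = 0; j < kBlock; ++j)
          amax = std::max(amax,
                          std::fabs(x[(bi * kBlock + i) * cols + bj * kBlock + j]));
      scales[bi * (cols / kBlock) + bj] = e8m0_from_amax(amax);
    }
  }
  return scales;
}
```

Under 2D scaling the rowwise and colwise outputs can share the same per-block scale, which is what the kernel's two-pass structure in the diagram below exploits.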
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant Recipe as MXFP8BlockScaling
    participant RecipeState as MXFP8BlockScalingRecipeState
    participant Quantizer as MXFP8Quantizer (Python)
    participant CPPQuantizer as MXFP8Quantizer (C++)
    participant Config as QuantizationConfig
    participant Dispatch as quantize_fwd_helper
    participant Kernel as quantize_mxfp8_kernel

    User->>Recipe: MXFP8BlockScaling(enable_2d_quantization=True)
    Recipe->>Recipe: __post_init__(): set QParams.mxfp8_2d_quantization
    Note over Recipe: fp8_quant_fwd_weight.mxfp8_2d_quantization = True<br/>fp8_quant_fwd_inp.mxfp8_2d_quantization = False<br/>fp8_quant_bwd_grad.mxfp8_2d_quantization = False

    User->>RecipeState: make_quantizers()
    RecipeState->>RecipeState: Check QParams for each tensor type
    RecipeState->>Quantizer: MXFP8Quantizer(with_2d_quantization=qparams.mxfp8_2d_quantization)
    RecipeState-->>User: Return list of quantizers

    User->>Quantizer: quantizer(input_tensor)
    Quantizer->>CPPQuantizer: quantize(input, out)
    CPPQuantizer->>Config: set_mxfp8_2d_quantization(with_2d_quantization)
    CPPQuantizer->>Dispatch: nvte_quantize_v2(input, output, config)
    Dispatch->>Kernel: Call quantize<...>(use_2d_quantization=config.mxfp8_2d_quantization)

    alt 2D Quantization Enabled (kIs2DBlockScaling=true)
        Kernel->>Kernel: Colwise pass: Compute 32x32 block amax via warp shuffle
        Note over Kernel: Each warp reduces across 32 threads<br/>to get single scale per 32x32 block
        Kernel->>Kernel: Store block scale to shared memory (block_scales_2d)
        Kernel->>Kernel: Quantize colwise data with block scale
        Kernel->>Kernel: __syncthreads() before rowwise pass
        Kernel->>Kernel: Rowwise pass: Load scale from shared memory
        Note over Kernel: Use __shfl_sync to broadcast scale<br/>from shared memory to all threads in warp
        Kernel->>Kernel: Quantize rowwise data with same block scale
    else 1D Quantization (kIs2DBlockScaling=false)
        Kernel->>Kernel: Colwise pass: Compute per-column scale
        Kernel->>Kernel: Rowwise pass: Compute per-row scale
    end

    Kernel-->>CPPQuantizer: Quantized tensor with scales
    CPPQuantizer-->>User: MXFP8Tensor with rowwise/colwise data and scales
```
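To ground the 2D branch above, here is a hedged, self-contained CUDA sketch, not the PR's `quantize_mxfp8_kernel`: it assumes a (32, 32) thread block, one fp32 element per thread, dimensions divisible by 32, and a plain float output standing in for the real FP8 conversion, but it follows the same steps the diagram names: warp-shuffle amax reduction, a block scale staged through shared memory, `__syncthreads()`, then a second use of the same scale.

```cuda
#include <cuda_runtime.h>

__global__ void quantize_2d_block_sketch(const float* in, float* out,
                                         unsigned char* scales, int cols) {
  __shared__ float warp_amax[32];  // one partial amax per warp (tile row)
  __shared__ float block_scale;    // single scale for the whole 32x32 tile

  const int row = blockIdx.y * 32 + threadIdx.y;
  const int col = blockIdx.x * 32 + threadIdx.x;
  const float v = in[row * cols + col];
  float amax = fabsf(v);

  // Stage 1: butterfly shuffle reduces |x| across the 32 lanes of each warp.
  for (int offset = 16; offset > 0; offset >>= 1)
    amax = fmaxf(amax, __shfl_xor_sync(0xffffffffu, amax, offset));
  if (threadIdx.x == 0) warp_amax[threadIdx.y] = amax;
  __syncthreads();

  // Stage 2: the first warp reduces the 32 per-warp partials to one amax and
  // derives a power-of-two (E8M0-style) scale, as MXFP8 requires.
  if (threadIdx.y == 0) {
    float a = warp_amax[threadIdx.x];
    for (int offset = 16; offset > 0; offset >>= 1)
      a = fmaxf(a, __shfl_xor_sync(0xffffffffu, a, offset));
    if (threadIdx.x == 0) {
      const int e = (int)floorf(log2f(fmaxf(a, 1e-30f) / 448.0f));  // 448: E4M3 max
      block_scale = exp2f((float)e);
      scales[blockIdx.y * gridDim.x + blockIdx.x] = (unsigned char)(e + 127);
    }
  }
  __syncthreads();  // all threads wait for block_scale, as in the diagram

  // Every element of the 32x32 tile is quantized with the same block scale.
  out[row * cols + col] = v / block_scale;
}
```

In the actual kernel the scale is computed in the colwise pass and re-read in the rowwise pass from `block_scales_2d` (broadcast with `__shfl_sync`, as in the excerpt reviewed below); the reduction-and-barrier structure is the same idea.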
```cuda
e8m0_t scale_from_shmem;
if (thread_lane < THREADS_X) {
  scale_from_shmem = block_scales_2d[thread_lane];
}
// Broadcast: each thread gets scale from lane matching its tid_X_rowwise
biased_exponent = __shfl_sync(0xffffffff, scale_from_shmem, tid_X_rowwise);
```
`scale_from_shmem` is potentially uninitialized for threads where `thread_lane >= THREADS_X`. While `__shfl_sync` only reads from the lanes specified by `tid_X_rowwise` (which should be < `THREADS_X`), it's safer to initialize this variable.
Suggested change:

```diff
-e8m0_t scale_from_shmem;
+e8m0_t scale_from_shmem = 0;
 if (thread_lane < THREADS_X) {
   scale_from_shmem = block_scales_2d[thread_lane];
 }
 // Broadcast: each thread gets scale from lane matching its tid_X_rowwise
 biased_exponent = __shfl_sync(0xffffffff, scale_from_shmem, tid_X_rowwise);
```
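For reference, a minimal standalone illustration of the broadcast pattern under discussion (a hypothetical demo kernel, not PR code): `__shfl_sync` only returns the value held by the source lane, so the other lanes' contents never propagate, but initializing every lane keeps all of the shuffle's inputs well-defined.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical demo: lane 3 holds the "scale"; every lane reads it by shuffle.
__global__ void broadcast_demo() {
  int v = 0;                               // initialized on every lane
  if (threadIdx.x == 3) v = 42;            // only the source lane sets data
  int b = __shfl_sync(0xffffffffu, v, 3);  // every lane receives 42
  if (threadIdx.x == 0) printf("lane 0 got %d\n", b);
}

int main() {
  broadcast_demo<<<1, 32>>>();
  cudaDeviceSynchronize();
  return 0;
}
```

The zero-initialization is essentially free, since the compiler can fold it into the predicated load, and it removes any dependence on an indeterminate register value.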
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: