[Common] Fuse pre-swizzling into grouped MXFP8 quantization kernel #2630
Open
Oleg-Goncharov wants to merge 21 commits into NVIDIA:main from
Conversation
Greptile Summary

This PR extends the grouped MXFP8 quantization kernel to support pre-swizzled scaling factors.

Implementation Details
The implementation maintains feature parity with the base kernel, supporting activations (GeLU, SiLU, ReLU), activation derivatives, and dbias computation.

Confidence Score: 4/5
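To make the quantization step concrete, here is a host-side reference sketch of quantizing one 32-element MXFP8 block: compute the block amax, convert it to an E8M0 scale exponent, then scale the values into the FP8 E4M3 range. This is an illustrative model, not the kernel's code; the function names and the round-up policy for the exponent are assumptions.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

constexpr float kE4M3Max = 448.0f;  // max normal value of FP8 E4M3

// Map a block amax to an E8M0 biased exponent (bias 127), rounding the
// exponent up so the scaled values fit in the E4M3 range (assumed policy).
uint8_t amax_to_e8m0(float amax) {
  if (amax <= 0.0f) return 0;  // degenerate block: smallest scale (assumption)
  int e = static_cast<int>(std::ceil(std::log2(amax / kE4M3Max)));
  e = std::max(-127, std::min(127, e));
  return static_cast<uint8_t>(e + 127);
}

// Scale one block by 2^-e; a real kernel would then encode each value as E4M3.
std::vector<float> quantize_block(const std::vector<float>& block,
                                  uint8_t* scale_out) {
  float amax = 0.0f;
  for (float v : block) amax = std::max(amax, std::fabs(v));
  const uint8_t biased = amax_to_e8m0(amax);
  *scale_out = biased;
  const float scale = std::ldexp(1.0f, static_cast<int>(biased) - 127);  // 2^e
  std::vector<float> out;
  for (float v : block) out.push_back(std::clamp(v / scale, -kE4M3Max, kE4M3Max));
  return out;
}
```

For example, a block whose amax is 896 gets exponent 1 (scale 2), so every scaled value lands within ±448.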
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant API as nvte_group_quantize
    participant Dispatch as group_quantize_fwd_helper
    participant Kernel as group_quantize_mxfp8_kernel
    participant Swizzle as gemm_swizzled_scale_idx
    User->>API: Call nvte_group_quantize(input, output, stream)
    API->>Dispatch: group_quantize_fwd_helper<IS_ACT, OP>()
    Dispatch->>Dispatch: Check scaling_mode (MXFP8_1D_SCALING)
    Dispatch->>Kernel: mxfp8::group_quantize(input, output, ...)
    Kernel->>Kernel: Read with_gemm_swizzled_scales from output->with_gemm_swizzled_scales
    Kernel->>Kernel: Instantiate kernel with WITH_GEMM_SWIZZLED_SCALES template parameter
    alt Multiple tensors (not single tensor)
        Kernel->>Kernel: Launch update_tma_descriptors kernel
        Kernel->>Kernel: Update tensor map descriptors per tensor
    end
    Kernel->>Kernel: Launch group_quantize_mxfp8_kernel<<<grid, block>>>
    loop For each tile in tensor
        Kernel->>Kernel: Load data via TMA
        Kernel->>Kernel: Compute activations (if IS_ACT or IS_DACT)
        alt Colwise Scaling
            Kernel->>Kernel: Compute column-wise amax
            Kernel->>Kernel: Convert to E8M0 scaling factor
            alt WITH_GEMM_SWIZZLED_SCALES
                Kernel->>Swizzle: gemm_swizzled_scale_idx(x, y, num_tiles)
                Swizzle-->>Kernel: Return swizzled index
            else No swizzling
                Kernel->>Kernel: Use compact index (y * stride + x)
            end
            Kernel->>Kernel: Store scale at computed index
            Kernel->>Kernel: Apply scale and quantize to MXFP8
        end
        alt Rowwise Scaling
            Kernel->>Kernel: Compute row-wise amax
            Kernel->>Kernel: Convert to E8M0 scaling factor
            alt WITH_GEMM_SWIZZLED_SCALES
                Kernel->>Swizzle: gemm_swizzled_scale_idx(y, x, num_tiles)
                Swizzle-->>Kernel: Return swizzled index
            else No swizzling
                Kernel->>Kernel: Use compact index (y * stride + x)
            end
            Kernel->>Kernel: Store scale at computed index
            Kernel->>Kernel: Apply scale and quantize to MXFP8
        end
        Kernel->>Kernel: Store quantized data via TMA
    end
    alt IS_DBIAS
        Kernel->>Kernel: Reduce dbias along columns
    end
    Kernel-->>User: Return quantized output with swizzled scales
```
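The swizzled-versus-compact branch in the diagram can be sketched on the host as follows. This is an illustrative sketch, not the kernel's implementation: the compact index is plain row-major, while the swizzled index here assumes the 128x4 scale-factor tiling documented for cuBLAS block-scaled GEMM; the exact layout used by the kernel may differ.

```cpp
#include <cassert>
#include <cstddef>

// Compact layout: plain row-major, "y * stride + x" in the diagram.
size_t compact_scale_idx(size_t r, size_t c, size_t stride) {
  return r * stride + c;
}

// Assumed GEMM-swizzled layout: scale factors grouped into 128-row x 4-column
// tiles (512 scales per tile), with a fixed permutation inside each tile.
size_t gemm_swizzled_scale_idx(size_t r, size_t c, size_t num_scale_cols) {
  const size_t kTileRows = 128, kTileCols = 4;
  const size_t tile = (r / kTileRows) * (num_scale_cols / kTileCols) + (c / kTileCols);
  const size_t within = (r % 32) * 16 + ((r % kTileRows) / 32) * 4 + (c % kTileCols);
  return tile * (kTileRows * kTileCols) + within;
}
```

The point of the fusion is that the kernel computes this index directly when storing each scale, so no separate permutation pass over the scale buffer is needed before the GEMM.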
Description
This PR fuses pre-swizzling into the grouped MXFP8 quantization kernel so that scaling factors are stored in the layout expected by GEMM. It builds on PR #2586 ([Common] MXFP8 kernel for grouped tensors) and can be merged after that PR lands.
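Why fusing helps can be modeled with a small host-side sketch: writing each scale at its swizzled position during quantization produces exactly the same buffer as a compact write followed by a separate swizzle kernel, but in one pass instead of two. Everything here is hypothetical illustration; `swizzle_idx` is a stand-in for the kernel's `gemm_swizzled_scale_idx`, and the `(r + c) % 251` values merely stand in for real E8M0 scales.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for the kernel's swizzled index (assumed 128x4 tiling).
size_t swizzle_idx(size_t r, size_t c, size_t cols) {
  const size_t tile = (r / 128) * (cols / 4) + (c / 4);
  return tile * 512 + (r % 32) * 16 + ((r % 128) / 32) * 4 + (c % 4);
}

// Fused: each scale is stored directly at its swizzled position.
std::vector<uint8_t> fused(size_t rows, size_t cols) {
  std::vector<uint8_t> out(rows * cols);
  for (size_t r = 0; r < rows; ++r)
    for (size_t c = 0; c < cols; ++c)
      out[swizzle_idx(r, c, cols)] = static_cast<uint8_t>((r + c) % 251);
  return out;
}

// Unfused: pass 1 stores scales compactly, pass 2 permutes them.
std::vector<uint8_t> two_pass(size_t rows, size_t cols) {
  std::vector<uint8_t> compact(rows * cols), out(rows * cols);
  for (size_t r = 0; r < rows; ++r)
    for (size_t c = 0; c < cols; ++c)
      compact[r * cols + c] = static_cast<uint8_t>((r + c) % 251);
  for (size_t r = 0; r < rows; ++r)  // the extra pass the fusion eliminates
    for (size_t c = 0; c < cols; ++c)
      out[swizzle_idx(r, c, cols)] = compact[r * cols + c];
  return out;
}
```

Both paths yield identical buffers; the fused path simply skips the second read-write pass over the scale tensor.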
Type of change
Changes

GroupedTensor

Checklist: