[PyTorch] Add grouped linear op and experimental fusion for grouped MLP #2622
timmoon10 wants to merge 52 commits into NVIDIA:main from
Conversation
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order. Signed-off-by: Tim Moon <tmoon@nvidia.com>
The test is too permissive, since it should still be failing: the weights are not properly interleaved yet. Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch L1
Greptile Overview

Greptile Summary: Adds a grouped linear operation and an experimental fused grouped MLP for Mixture-of-Experts models.

Key changes:

Issues previously reported have been addressed.

Confidence Score: 4/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant User
participant Sequential as te_ops.Sequential
participant GroupedLinear1 as GroupedLinear (FC1)
participant ScaledSwiGLU
participant GroupedLinear2 as GroupedLinear (FC2)
participant FusedOp as ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8
participant CuTeKernel as CuDNN CuTe DSL Kernel
Note over User,CuTeKernel: Regular Path (No Fusion)
User->>Sequential: forward(input, split_sizes, scales)
Sequential->>GroupedLinear1: forward(input, split_sizes)
GroupedLinear1->>GroupedLinear1: Split input by split_sizes
GroupedLinear1->>GroupedLinear1: Quantize to MXFP8 if enabled
GroupedLinear1->>GroupedLinear1: general_grouped_gemm(weights, inputs)
GroupedLinear1-->>Sequential: fc1_output
Sequential->>ScaledSwiGLU: forward(fc1_output, scales)
ScaledSwiGLU->>ScaledSwiGLU: Remove gate interleaving if needed
ScaledSwiGLU->>ScaledSwiGLU: Compute SwiGLU activation
ScaledSwiGLU->>ScaledSwiGLU: Apply post-scaling (output * scales)
ScaledSwiGLU-->>Sequential: swiglu_output
Sequential->>GroupedLinear2: forward(swiglu_output, split_sizes)
GroupedLinear2->>GroupedLinear2: Split input by split_sizes
GroupedLinear2->>GroupedLinear2: Quantize to MXFP8 if enabled
GroupedLinear2->>GroupedLinear2: general_grouped_gemm(weights, inputs)
GroupedLinear2-->>Sequential: final_output
Sequential-->>User: final_output
Note over User,CuTeKernel: Fused Path (MXFP8 + SM100+)
User->>Sequential: forward(input, split_sizes, scales)
Sequential->>FusedOp: fuser_forward(input, split_sizes, scales)
FusedOp->>FusedOp: Quantize FC1 inputs to MXFP8
FusedOp->>FusedOp: Pack FC1 data/scales with gate swapping
FusedOp->>CuTeKernel: grouped_gemm_swiglu_wrapper_sm100()
Note right of CuTeKernel: Single kernel:<br/>FC1 GEMM + SwiGLU + post-scale
CuTeKernel-->>FusedOp: FC2 inputs (MXFP8, row+col quantized)
FusedOp->>FusedOp: Unpack FC2 inputs and undo gate swap
FusedOp->>FusedOp: Construct MXFP8Tensor objects
FusedOp->>FusedOp: general_grouped_gemm(FC2 weights, FC2 inputs)
FusedOp-->>Sequential: final_output
Sequential-->>User: final_output
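For intuition, the regular path's grouped GEMM is equivalent to splitting the token dimension by split_sizes and running one GEMM per group. The sketch below is a plain-PyTorch reference, not the actual general_grouped_gemm implementation; the function name and the per-group weight layout are illustrative assumptions.

```python
import torch

def grouped_linear_reference(x, split_sizes, weights):
    """Plain-PyTorch reference for the regular (unfused) path, illustrative only.

    x:           (num_tokens, in_features) input, already permuted so tokens of a group are contiguous
    split_sizes: tokens per group, summing to num_tokens
    weights:     one (out_features, in_features) weight tensor per group
    """
    outputs = []
    for x_i, w_i in zip(torch.split(x, split_sizes, dim=0), weights):
        outputs.append(x_i @ w_i.t())  # one GEMM per group
    return torch.cat(outputs, dim=0)
```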
quantizer.optimize_for_gemm = True
fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers)

# Pack data tensors
Maybe a silly question: is this packing and unpacking code just for verification, or will it be in the final version?
I'm working on getting rid of the concatenations, but the permutes are no-ops. The kernel API expects tensors with non-contiguous dims: https://github.com/NVIDIA/cudnn-frontend/blob/main/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py#L240-L245
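A minimal illustration of why the permutes are effectively free: torch.permute only rewrites strides and returns a view of the same storage, which is exactly the non-contiguous layout the kernel API expects.

```python
import torch

x = torch.randn(4, 8, 16)
y = x.permute(0, 2, 1)               # view with shape (4, 16, 8); no data is copied

assert y.data_ptr() == x.data_ptr()  # same underlying storage
assert not y.is_contiguous()         # strides now describe a non-contiguous layout
```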
)

# Fused kernel for FC1 + SwiGLU + post-scale
fc1_kernel_out = self.grouped_gemm_swiglu_kernel()(
After SwiGLU, the output usually needs to be multiplied by permuted_probs. Is this weighted SwiGLU supported?
Yes, the probs are passed into the kernel here: https://github.com/timmoon10/TransformerEngine/blob/46294be478f6551e2cf251283adc7529ddb2964e/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py#L264
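For reference, the post-scaling the kernel fuses is the usual "weighted SwiGLU" pattern: apply SwiGLU to the FC1 output, then multiply each token's activations by its routing probability. The sketch below is only a reference under the assumption that the gate and linear halves are split (not interleaved) along the last dimension; the fused kernel handles the interleaved/swapped gate layout internally.

```python
import torch
import torch.nn.functional as F

def scaled_swiglu_reference(fc1_out: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
    """Reference for SwiGLU followed by a per-token post-scale (illustrative only)."""
    gate, linear = fc1_out.chunk(2, dim=-1)  # assumes [gate, linear] halves along the last dim
    out = F.silu(gate) * linear              # SwiGLU activation
    return out * probs.unsqueeze(-1)         # weight by permuted_probs (post-scale)
```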
Review suggestions from @greptile-apps Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Review suggestion from @greptile-apps Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
accumulate_into_main_grad = not getattr(
    weight_param, "overwrite_main_grad", False
)
accumulate_into_main_grad is reassigned inside the loop, so the last group's setting applies to all groups in the GEMM call on line 576. If different weight groups have different overwrite_main_grad settings, this leads to incorrect gradient-accumulation behavior. Either check consistency across groups or use per-group flags.
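One way to address this (a sketch, not the PR's code; weight_params and the helper name are hypothetical) is to gather the per-group flags first and require them to agree before the single grouped GEMM call:

```python
def resolve_accumulate_flag(weight_params):
    """Hypothetical helper: require all groups to agree on overwrite_main_grad."""
    flags = [not getattr(p, "overwrite_main_grad", False) for p in weight_params]
    if len(set(flags)) != 1:
        raise ValueError(
            "All weight groups must use the same overwrite_main_grad setting "
            "when wgrad is computed with a single grouped GEMM call"
        )
    return flags[0]
```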
Description
This PR adds a grouped linear op, which can be used in the grouped MLP block in Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.
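A rough usage sketch of the unfused path, based on the ops named in the sequence diagram above. The constructor arguments, tensor sizes, and forward signature below are assumptions inferred from that diagram, not the actual API.

```python
import torch
import transformer_engine.pytorch.ops as te_ops

# Hypothetical sizes for illustration only.
num_groups, hidden, ffn_hidden = 8, 1024, 4096

grouped_mlp = te_ops.Sequential(
    te_ops.GroupedLinear(num_groups, hidden, 2 * ffn_hidden),  # FC1 (gate + linear), assumed signature
    te_ops.ScaledSwiGLU(),                                      # SwiGLU + post-scale by routing probs
    te_ops.GroupedLinear(num_groups, ffn_hidden, hidden),       # FC2, assumed signature
)

x = torch.randn(4096, hidden, device="cuda", dtype=torch.bfloat16)
split_sizes = [512] * num_groups          # tokens routed to each group/expert
probs = torch.rand(4096, device="cuda")   # per-token routing probabilities

# Matches the forward(input, split_sizes, scales) flow shown in the sequence diagram.
y = grouped_mlp(x, split_sizes, probs)
```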
Type of change
Changes
Checklist: