MoE WG kernel: expose tile knobs, add benchmark tune mode, and encode bpe-based heuristic defaults #414
Conversation
Co-authored-by: neoblizz <9790745+neoblizz@users.noreply.github.com>
@copilot Run and explore on the MCP.
Ran the tune sweep on 8× MI300X (304 CUs each, 256 GB VRAM, ROCm 7.1).

To reproduce:

```shell
HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 benchmark/examples/benchmark_moe.py \
  --num_ranks 8 \
  --fusion_mode wg_fused_grouped_matmul_convert_ep_to_dp \
  --tune --benchmark \
  --batch_per_expt 64 128 256 512 1024 \
  --block_m_values 32 64 128 256 \
  --gemm_sms_values 76 152 228 \
  --output_file moe_wg_tune_full.json
```

Full per-combo timing data is saved in the output JSON.
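The per-combo data in the output JSON can be inspected with a few lines of Python. This is a minimal sketch; the exact schema is defined by `benchmark_moe.py`, and the `tune_configs` entry layout and `latency_us` field name used here are assumptions for illustration.

```python
import json

def best_config(path):
    """Pick the fastest (gemm_sms, block_m) combo from a tune-output JSON.

    Assumes each entry in "tune_configs" records the swept pair and its
    measured latency (hypothetical field names, not the benchmark's exact
    schema).
    """
    with open(path) as f:
        results = json.load(f)
    configs = results["tune_configs"]
    # The winner is simply the combo with the lowest measured latency.
    return min(configs, key=lambda c: c["latency_us"])
```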
@copilot Encode the best configs per bpe ranges (if that heuristic makes sense) as default behavior. Also, present a plot of latency results with various configs marking the best one.
Co-authored-by: neoblizz <9790745+neoblizz@users.noreply.github.com>
Done. Heuristic `(gemm_sms, block_m)` defaults per bpe range are now encoded in `_heuristic_wg_config()`.
[Plot: latency across all (gemm_sms, block_m) configs, with the best config per bpe marked with a red ★]
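A minimal sketch of how such a bpe-range heuristic could look. The thresholds and tile values below are assumptions drawn from the tune observations in this PR (50% SM split below bpe 128, 75% at and above, `block_m` 128→256 past bpe 128); the function name and signature mirror `_heuristic_wg_config(num_sms, avg_bpe)` but this is an illustration, not the actual implementation.

```python
def heuristic_wg_config(num_sms: int, avg_bpe: int) -> tuple[int, int]:
    """Return a (gemm_sms, block_m) default for a given average
    batch-per-expert, following the tune-sweep trends (illustrative
    thresholds, not the kernel's exact code)."""
    if avg_bpe < 128:
        # Small bpe: comm-bound, so give GEMM only half the SMs.
        return num_sms // 2, 128
    if avg_bpe == 128:
        # From bpe >= 128, dedicating 75% of SMs to GEMM wins.
        return 3 * num_sms // 4, 128
    # Large bpe: GEMM is compute-bound; also grow the M tile.
    return 3 * num_sms // 4, 256
```

On a 304-CU MI300X this yields `(152, 128)` at bpe=64 and `(228, 256)` at bpe=512, i.e. the ½ and ¾ SM splits that appear as the swept `--gemm_sms_values` candidates.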
The WG-specialized EP→DP fused kernel had hardcoded `BLOCK_M/N/K` with no way to explore the optimal tile shape or GEMM/comm SM split for different batch-per-expert (bpe) sizes.

### Changes

**`fused_exp_matmul_ep_to_dp_wg.py`**
- `wg_fused_exp_matmul_ep_to_dp()` now accepts optional `block_m`, `block_n`, `block_k` — defaults are now selected by a bpe-aware heuristic.
- New `_heuristic_wg_config(num_sms, avg_bpe)` function encodes optimal `(gemm_sms, block_m)` defaults derived from a MI300X tune sweep (304 CUs, 8 ranks, bpe ∈ {64, 128, 256, 512, 1024}):

  | avg_bpe | gemm_sms | block_m |
  |---|---|---|
  | < 128 | `num_sms // 2` (50%) | 128 |
  | 128 | `3 * num_sms // 4` (75%) | 128 |
  | > 128 | `3 * num_sms // 4` (75%) | 256 |

- `avg_bpe` is computed as `n_slots_per_rank // n_local_experts`. Explicit `gemm_sms=` / `block_m=` arguments still override as before. The old `2**floor(log2(cu_count))` default for `gemm_sms` is replaced.

**`moe.py`**
- Plumbs `block_m`, `block_n`, `block_k` through `mixture_of_expt_epsharded()` into the WG kernel call.

**`benchmark_moe.py`**
- `--tune`: for each bpe point, sweeps all `(gemm_sms, block_m)` combos with a quick pass (5 warmup / 20 repeat), selects the fastest, then runs the full benchmark with that config.
- `--gemm_sms_values`: explicit SM-split candidates (default: ¼, ½, ¾ of total SMs, clamped to `[1, num_sms - 1]`).
- `--block_m_values`: explicit M-tile candidates (default: `[32, 64, 128, 256]`).
- Output JSON gains `tune_best_gemm_sms`, `tune_best_block_m`, and `tune_configs`.

### Tune Results (MI300X, 8 ranks, 128 experts, d_model=5760, bf16)
Key observations:
- `block_m=32` is always worst; dedicating 75% of SMs to GEMM is optimal at bpe ≥ 128; `block_m` should increase from 128 → 256 at bpe > 128 as GEMM becomes more compute-bound.

To reproduce:

```shell
HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 benchmark/examples/benchmark_moe.py \
  --num_ranks 8 \
  --fusion_mode wg_fused_grouped_matmul_convert_ep_to_dp \
  --tune --benchmark \
  --batch_per_expt 64 128 256 512 1024 \
  --block_m_values 32 64 128 256 \
  --output_file moe_wg_tune.json
```
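The `--tune` quick pass described above amounts to a small grid sweep: time every `(gemm_sms, block_m)` combo briefly and keep the fastest. The sketch below uses the PR's 5-warmup / 20-repeat counts; the `run_once(gemm_sms, block_m)` hook is a hypothetical stand-in for one kernel launch, not `benchmark_moe.py`'s real API.

```python
import itertools
import statistics
import time

def quick_tune(run_once, gemm_sms_values, block_m_values,
               warmup=5, repeat=20):
    """Sweep all (gemm_sms, block_m) combos with a short timing pass and
    return the fastest pair. `run_once` is an assumed callable that
    executes the kernel once with the given config."""
    best = None
    for gemm_sms, block_m in itertools.product(gemm_sms_values, block_m_values):
        for _ in range(warmup):
            run_once(gemm_sms, block_m)  # discard warmup iterations
        times = []
        for _ in range(repeat):
            t0 = time.perf_counter()
            run_once(gemm_sms, block_m)
            times.append(time.perf_counter() - t0)
        med = statistics.median(times)  # median is robust to outliers
        if best is None or med < best[0]:
            best = (med, gemm_sms, block_m)
    return best[1], best[2]
```

The selected pair would then be fed back into the full benchmark run, mirroring the `--tune --benchmark` flow.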