
MoE WG kernel: expose tile knobs, add benchmark tune mode, and encode bpe-based heuristic defaults#414

Draft
Copilot wants to merge 3 commits into neoblizz/moe-fused-plus from copilot/sub-pr-408

Conversation

Contributor

Copilot AI commented Mar 3, 2026

The WG-specialized EP→DP fused kernel had hardcoded BLOCK_M/N/K, with no way to explore the optimal tile shape or the GEMM/comm SM split for different batch-per-expert (bpe) sizes.

Changes

fused_exp_matmul_ep_to_dp_wg.py

  • wg_fused_exp_matmul_ep_to_dp() now accepts optional block_m, block_n, block_k — defaults are now selected by a bpe-aware heuristic.

  • New _heuristic_wg_config(num_sms, avg_bpe) function encodes optimal (gemm_sms, block_m) defaults derived from a MI300X tune sweep (304 CUs, 8 ranks, bpe ∈ {64, 128, 256, 512, 1024}):

    | avg bpe | gemm_sms               | block_m |
    | ------- | ---------------------- | ------- |
    | ≤ 64    | num_sms // 2 (50%)     | 128     |
    | ≤ 128   | 3 × num_sms // 4 (75%) | 128     |
    | > 128   | 3 × num_sms // 4 (75%) | 256     |

    avg_bpe is computed as n_slots_per_rank // n_local_experts. Explicit gemm_sms= / block_m= arguments still override as before. The old 2**floor(log2(cu_count)) default for gemm_sms is replaced.

moe.py

  • Thread block_m, block_n, block_k through mixture_of_expt_epsharded() into the WG kernel call.

benchmark_moe.py

  • --tune: for each bpe point, sweeps all (gemm_sms, block_m) combos with a quick pass (5 warmup / 20 repeat), selects the fastest, then runs the full benchmark with that config.
  • --gemm_sms_values: explicit SM split candidates (default: ¼, ½, ¾ of total SMs, clamped to [1, num_sms-1]).
  • --block_m_values: explicit M-tile candidates (default: [32, 64, 128, 256]).
  • Best config and per-combo timings written to output JSON under tune_best_gemm_sms, tune_best_block_m, tune_configs.
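The --tune flow can be sketched roughly as below. `run_kernel` is a hypothetical stand-in for the benchmarked WG kernel launch, and wall-clock timing stands in for the benchmark's actual measurement:

```python
import itertools
import time

def tune_sweep(run_kernel, gemm_sms_values, block_m_values,
               warmup=5, repeat=20):
    """Quick-pass sweep: time every (gemm_sms, block_m) combo, pick the
    fastest, and return the same fields the benchmark writes to JSON."""
    timings = {}
    for gemm_sms, block_m in itertools.product(gemm_sms_values, block_m_values):
        for _ in range(warmup):                       # quick warmup pass
            run_kernel(gemm_sms=gemm_sms, block_m=block_m)
        start = time.perf_counter()
        for _ in range(repeat):                       # timed repetitions
            run_kernel(gemm_sms=gemm_sms, block_m=block_m)
        timings[(gemm_sms, block_m)] = (time.perf_counter() - start) / repeat
    best_gemm_sms, best_block_m = min(timings, key=timings.get)
    return {"tune_best_gemm_sms": best_gemm_sms,
            "tune_best_block_m": best_block_m,
            "tune_configs": timings}
```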

Tune Results (MI300X, 8 ranks, 128 experts, d_model=5760, bf16)

WG Fused EP→DP Latency

| bpe  | n_tokens | best gemm_sms | best block_m | latency |
| ---- | -------- | ------------- | ------------ | ------- |
| 64   | 2048     | 152 (50%)     | 128          | 4.28 ms |
| 128  | 4096     | 228 (75%)     | 128          | 4.50 ms |
| 256  | 8192     | 228 (75%)     | 256          | 5.17 ms |
| 512  | 16384    | 228 (75%)     | 256          | 6.65 ms |
| 1024 | 32768    | 228 (75%)     | 256          | 9.59 ms |

Key observations:

  • block_m=32 is always worst.
  • Dedicating 75% of SMs to GEMM is optimal at bpe ≥ 128.
  • block_m should increase from 128 → 256 at bpe > 128 as GEMM becomes more compute-bound.

HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 benchmark/examples/benchmark_moe.py \
    --num_ranks 8 \
    --fusion_mode wg_fused_grouped_matmul_convert_ep_to_dp \
    --tune --benchmark \
    --batch_per_expt 64 128 256 512 1024 \
    --block_m_values 32 64 128 256 \
    --output_file moe_wg_tune.json


Co-authored-by: neoblizz <9790745+neoblizz@users.noreply.github.com>
Copilot AI changed the title [WIP] Add fused MoE optimizations for improved performance MoE WG kernel: expose tile knobs and add benchmark tune mode Mar 3, 2026
@neoblizz
Member

neoblizz commented Mar 3, 2026

@copilot Run and explore on the MCP.

Contributor Author

Copilot AI commented Mar 3, 2026

> @copilot Run and explore on the MCP.

Ran the tune sweep on 8x MI300X (304 CUs each, 256 GB VRAM, ROCm 7.1) using --block_m_values 32 64 128 256 and --gemm_sms_values 76 152 228 (¼, ½, ¾ of 304 SMs):

bpe=  64 (2k tokens)   best: gemm_sms=152, block_m=128  →  4.28 ms
bpe= 128 (4k tokens)   best: gemm_sms=228, block_m=128  →  4.50 ms
bpe= 256 (8k tokens)   best: gemm_sms=228, block_m=256  →  5.17 ms
bpe= 512 (16k tokens)  best: gemm_sms=228, block_m=256  →  6.65 ms
bpe=1024 (32k tokens)  best: gemm_sms=228, block_m=256  →  9.59 ms

Key observations:

  • block_m=128 is optimal at small bpe (≤128); block_m=256 takes over at bpe≥256 as GEMM becomes more compute-bound.
  • gemm_sms=228 (¾ of SMs for GEMM, ¼ for comm) is consistently best at medium-to-large bpe; at bpe=64 splitting 50/50 (gemm_sms=152) edges it out slightly.
  • block_m=32 is consistently worst; communication latency dominates before the GEMM tiles are large enough to hide it.
  • The WG kernel scales well: only a 2.2× latency increase for a 16× increase in tokens (2k → 32k).

To reproduce:

HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 benchmark/examples/benchmark_moe.py \
    --num_ranks 8 \
    --fusion_mode wg_fused_grouped_matmul_convert_ep_to_dp \
    --tune --benchmark \
    --batch_per_expt 64 128 256 512 1024 \
    --block_m_values 32 64 128 256 \
    --gemm_sms_values 76 152 228 \
    --output_file moe_wg_tune_full.json

Full per-combo timing data is saved in the output JSON under tune_configs.

@neoblizz
Member

neoblizz commented Mar 3, 2026

@copilot Encode the best configs per bpe ranges (if that heuristic makes sense) as default behavior. Also, present a plot of latency results with various configs marking the best one.

Co-authored-by: neoblizz <9790745+neoblizz@users.noreply.github.com>
Contributor Author

Copilot AI commented Mar 3, 2026

> @copilot Encode the best configs per bpe ranges (if that heuristic makes sense) as default behavior. Also, present a plot of latency results with various configs marking the best one.

Done in commit 9ef18be.

Heuristic defaults encoded in _heuristic_wg_config() (derived from the MI300X tune sweep across bpe ∈ {64,128,256,512,1024}):

| avg bpe | gemm_sms               | block_m |
| ------- | ---------------------- | ------- |
| ≤ 64    | num_sms // 2 (50%)     | 128     |
| ≤ 128   | 3 * num_sms // 4 (75%) | 128     |
| > 128   | 3 * num_sms // 4 (75%) | 256     |

avg_bpe is computed as n_slots_per_rank // n_local_experts — the average tokens routed to each local expert. The old 2**floor(log2(cu_count)) default is replaced. Explicit gemm_sms= / block_m= args still override as before.

Latency plot across all (gemm_sms, block_m) configs with best per bpe marked (red ★):

[Figure: WG Fused EP→DP Latency]

Key takeaways visible in the plot:

  • 25% SM split (blue) scales poorly — communication is the bottleneck at larger bpe
  • block_m=32 (×) is always worst; block_m=256 (●) wins at bpe≥256 as GEMM becomes compute-bound
  • The best-config envelope (red dashed) achieves 2.2× better scaling vs the worst config at bpe=1024

Copilot AI changed the title MoE WG kernel: expose tile knobs and add benchmark tune mode MoE WG kernel: expose tile knobs, add benchmark tune mode, and encode bpe-based heuristic defaults Mar 3, 2026