Fused MoE Optimizations #408

Open
neoblizz wants to merge 5 commits into main from neoblizz/moe-fused-plus
Conversation

@neoblizz (Member) commented Mar 1, 2026
Motivation

... WIP, don't merge.

Submission Checklist

Copilot AI review requested due to automatic review settings March 1, 2026 18:47
@neoblizz neoblizz requested review from BKP and mawad-amd as code owners March 1, 2026 18:47
Copilot AI (Contributor) left a comment
Pull request overview

This PR expands expert-sharded MoE “fusion modes” and introduces new fused Triton kernels, alongside Iris pointer-translation vectorization hints, to improve distributed MoE performance.

Changes:

  • Add new fusion modes and kernels: DP→EP gather+GEMM fusion and a WG-specialized EP→DP GEMM+scatter fusion.
  • Add optional hint plumbing through Iris load/store/copy and apply vectorization hints in MoE kernels.
  • Extend benchmarks/tests to exercise additional fusion modes and expose --gemm_sms tuning.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| tests/examples/test_expert_sharded_moe.py | Expands the parametrized fusion-mode test matrix. |
| iris/iris.py | Adds a `hint` constexpr to translation + load/store/copy/atomics to enable vectorization hints. |
| examples/31_expert_sharded_moe/moe.py | Adds new fusion modes, selection logic, and `gemm_sms` threading into the WG fused path. |
| examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp_wg.py | New WG-specialized persistent kernel overlapping GEMM and scatter via locks. |
| examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py | Adds an Iris store hint + extra kernel launch meta-params. |
| examples/31_expert_sharded_moe/fused_dp_to_ep_matmul.py | New fused DP→EP gather + expert matmul kernel (remote `iris.load` prologue). |
| examples/31_expert_sharded_moe/dispatch.py | Adds vectorization hints for offsets and `iris.store`. |
| examples/31_expert_sharded_moe/combine.py | Adds vectorization hints for offsets and `iris.store`. |
| benchmark/examples/benchmark_moe.py | Adds the WG fusion mode + `--gemm_sms` and enhances breakdown reporting. |

Comment on lines 172 to 182
```python
def mode_name(self) -> str:
    parts: list[str] = []
    if self.fuse_convert_dp_to_ep_grouped_matmul:
        parts.append("convert_dp_to_ep_grouped_matmul")
    if self.fuse_grouped_matmul_convert_ep_to_dp:
        parts.append("grouped_matmul_convert_ep_to_dp")
    if self.fuse_grouped_matmul_convert_ep_to_dp_wg:
        parts.append("wg_grouped_matmul_convert_ep_to_dp")
    if not parts:
        return "unfused"
    return "fused_" + "__".join(parts)
```

Copilot AI Mar 1, 2026

mode_name() produces "fused_wg_grouped_matmul_convert_ep_to_dp" for the WG mode, but from_mode_name(), the CLI, and the tests use "wg_fused_grouped_matmul_convert_ep_to_dp". This breaks round-tripping and can confuse logs and config serialization; align the string naming (either change mode_name()'s output for the WG flag, or accept both spellings in from_mode_name()).
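One way to resolve the round-tripping issue is to normalize both spellings at the parsing boundary. A minimal sketch of that approach — `normalize_mode_name` and `_WG_ALIASES` are hypothetical helper names, not code from the PR; only the two mode strings come from the review:

```python
# Hypothetical normalization helper for from_mode_name(). The two mode
# strings below are the ones the review flags as inconsistent; the
# function and set names are illustrative, not from the PR.
_WG_ALIASES = {
    "wg_fused_grouped_matmul_convert_ep_to_dp",  # CLI / tests spelling
    "fused_wg_grouped_matmul_convert_ep_to_dp",  # mode_name() spelling
}

def normalize_mode_name(name: str) -> str:
    # Map either spelling of the WG mode onto one canonical string so
    # that mode_name() output round-trips through from_mode_name().
    if name in _WG_ALIASES:
        return "wg_fused_grouped_matmul_convert_ep_to_dp"
    return name
```

Accepting both spellings keeps existing configs and logs valid while the canonical form is settled.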
```python
        fusion_config.fuse_grouped_matmul_convert_ep_to_dp_wg,
    ]
)
if n_fusions_active > 1:
```
Copilot AI Mar 1, 2026
This new restriction conflicts with the existing “combined” fusion mode name that from_mode_name() still constructs (fused_convert_dp_to_ep_grouped_matmul__grouped_matmul_convert_ep_to_dp) and that the benchmark CLI still offers. Either remove/disable the combined mode in from_mode_name() and CLI choices, or extend the implementation to support the combined mode (and adjust this guard accordingly) so user-selected modes don’t fail at runtime.

Suggested change

```diff
-if n_fusions_active > 1:
+# Allow the documented "combined" fusion mode
+combined_mode = (
+    fusion_config.fuse_convert_dp_to_ep_grouped_matmul
+    and fusion_config.fuse_grouped_matmul_convert_ep_to_dp
+    and not fusion_config.fuse_grouped_matmul_convert_ep_to_dp_wg
+)
+if n_fusions_active > 1 and not combined_mode:
```
```python
"fused_grouped_matmul_convert_ep_to_dp",
"fused_convert_dp_to_ep_grouped_matmul",
"wg_fused_grouped_matmul_convert_ep_to_dp",
"fused_convert_dp_to_ep_grouped_matmul__grouped_matmul_convert_ep_to_dp",
```
Copilot AI Mar 1, 2026
The CLI still advertises the combined fusion mode (fused_convert_dp_to_ep_grouped_matmul__grouped_matmul_convert_ep_to_dp), but the MoE pipeline now raises if more than one fusion flag is enabled. Remove this choice (and any corresponding parsing) or make it a supported mode; otherwise benchmark runs can error immediately.

Suggested change

```diff
-"fused_convert_dp_to_ep_grouped_matmul__grouped_matmul_convert_ep_to_dp",
```
Comment on lines +247 to +249
```python
num_sms = cu_count
if gemm_sms is None:
    gemm_sms = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1
```
Copilot AI Mar 1, 2026

With the current default, gemm_sms equals cu_count when cu_count is a power-of-two, which makes COMM_SMS = NUM_SMS - GEMM_SMS become 0 in the kernel. That results in no communication programs running (and/or invalid loop strides), so outputs won’t be scattered. Clamp gemm_sms to [1, num_sms - 1] (or choose a safer default split, e.g. num_sms // 2) and error clearly if the user passes an invalid value.

Suggested change

```diff
-num_sms = cu_count
-if gemm_sms is None:
-    gemm_sms = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1
+num_sms = max(int(cu_count), 1)
+# Split SMs between GEMM and communication paths.
+# Require at least 2 SMs so that we can run both roles.
+if num_sms < 2:
+    raise RuntimeError(
+        f"_wg_fused_exp_matmul_ep_to_dp_kernel requires at least 2 SMs, "
+        f"but found {num_sms} on device {device}."
+    )
+if gemm_sms is None:
+    # Use a safer default than the next power-of-two, which could equal num_sms.
+    # Start from an even split and clamp to [1, num_sms - 1].
+    default_gemm_sms = max(1, num_sms // 2)
+    gemm_sms = min(default_gemm_sms, num_sms - 1)
+else:
+    # Validate user-provided gemm_sms.
+    if not isinstance(gemm_sms, int):
+        raise TypeError(
+            f"gemm_sms must be an int in [1, num_sms - 1], got {type(gemm_sms)}."
+        )
+    if gemm_sms < 1 or gemm_sms >= num_sms:
+        raise ValueError(
+            f"gemm_sms must be in [1, num_sms - 1] (num_sms={num_sms}), "
+            f"but got gemm_sms={gemm_sms}."
+        )
```
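The failure mode the review describes is easy to reproduce in plain Python. A minimal sketch — the helper names `default_gemm_sms` and `clamped_gemm_sms` are hypothetical, but the first reproduces the default expression from the diff above:

```python
import math

def default_gemm_sms(cu_count: int) -> int:
    # Default from the PR: the largest power of two <= cu_count.
    return 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1

def clamped_gemm_sms(cu_count: int) -> int:
    # Hypothetical fix: clamp into [1, cu_count - 1] so that
    # COMM_SMS = cu_count - gemm_sms is always at least 1.
    return max(1, min(default_gemm_sms(cu_count), cu_count - 1))

# On a 64-CU device, cu_count is itself a power of two, so the old
# default assigns every SM to GEMM and leaves zero communication SMs.
```

On a non-power-of-two CU count (e.g. 72) the old default happens to leave SMs free, which is why the bug only surfaces on power-of-two devices.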

```python
tl.debug_barrier()
lock_idx = en_pair * max_m_tiles + pid_m
tl.store(locks_ptr + lock_idx, 1, cache_modifier=".wt")
```
Copilot AI Mar 1, 2026
The kernel uses a plain global tl.store to publish the lock, while other programs spin on tl.load(... volatile=True); this does not guarantee a proper release/acquire ordering between the y_buf write and lock visibility across CUs. Consider using an atomic publication for the lock (e.g., tl.atomic_xchg(..., sem="release")) and an acquire read on the consumer side (e.g., atomic read via tl.atomic_add(ptr, 0, sem="acquire")) to avoid reading stale y_buf data.

Suggested change

```diff
-tl.store(locks_ptr + lock_idx, 1, cache_modifier=".wt")
+tl.atomic_xchg(locks_ptr + lock_idx, 1, sem="release")
```
Comment on lines +192 to +203
```python
def wg_fused_exp_matmul_ep_to_dp(
    x_ep_local: torch.Tensor,
    w_ep_local: torch.Tensor,
    b_ep_local: torch.Tensor | None,
    expt_assignment,
    expt_map_local: torch.Tensor,
    expt_indx_flat: torch.Tensor,
    combine_indx: torch.Tensor,
    shmem,
    ragged_metadata: RaggedTensorMetadata | None = None,
    gemm_sms: int | None = None,
) -> torch.Tensor:
```
Copilot AI Mar 1, 2026
expt_map_local is part of the public signature but is not used anywhere in this new module (the kernel relies on expt_assignment.expt_bitmask instead). If it’s intentionally unused for signature compatibility, rename it to _expt_map_local (or add a brief comment) to make that explicit; otherwise, either use it or remove it from the signature.

```python
    from_rank (int): The rank ID from which to read the data.
    heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
    mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None.
    hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint).
```
Copilot AI Mar 1, 2026
Since hint is implemented as a tl.constexpr, callers can’t pass an arbitrary runtime value from within a Triton kernel—only compile-time constants are valid. Updating the docstrings for hint (here and the similar ones on store/copy/get/put/atomic_*) to explicitly state it must be compile-time constant would prevent misuse and hard-to-debug compilation errors.

Suggested change

```diff
-hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint).
+hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). When called from within a Triton kernel, this argument must be a compile-time constant (`tl.constexpr`). Defaults to None (no hint).
```
@neoblizz (Member, Author) commented Mar 3, 2026

@copilot for GPT-OSS, sweep the block sizes, comm SMs (for WG specialization), etc. to find the most optimal setup for larger bpe; I would like to understand the optimal tiling and configs.

Copilot AI (Contributor) commented Mar 3, 2026

@neoblizz I've opened a new pull request, #414, to work on those changes. Once the pull request is ready, I'll request review from you.
