Conversation
Pull request overview
This PR expands expert-sharded MoE “fusion modes” and introduces new fused Triton kernels, alongside Iris pointer-translation vectorization hints, to improve distributed MoE performance.
Changes:
- Add new fusion modes and kernels: DP→EP gather+GEMM fusion and a WG-specialized EP→DP GEMM+scatter fusion.
- Add optional `hint` plumbing through Iris load/store/copy and apply vectorization hints in MoE kernels.
- Extend benchmarks/tests to exercise additional fusion modes and expose `--gemm_sms` tuning.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| tests/examples/test_expert_sharded_moe.py | Expands parametrized fusion-mode test matrix. |
| iris/iris.py | Adds hint constexpr to translation + load/store/copy/atomics to enable vectorization hints. |
| examples/31_expert_sharded_moe/moe.py | Adds new fusion modes, selection logic, and gemm_sms threading into WG fused path. |
| examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp_wg.py | New WG-specialized persistent kernel overlapping GEMM and scatter via locks. |
| examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py | Adds Iris store hint + extra kernel launch meta-params. |
| examples/31_expert_sharded_moe/fused_dp_to_ep_matmul.py | New fused DP→EP gather + expert matmul kernel (remote iris.load prologue). |
| examples/31_expert_sharded_moe/dispatch.py | Adds vectorization hints for offsets and iris.store. |
| examples/31_expert_sharded_moe/combine.py | Adds vectorization hints for offsets and iris.store. |
| benchmark/examples/benchmark_moe.py | Adds WG fusion mode + --gemm_sms and enhances breakdown reporting. |
```python
def mode_name(self) -> str:
    parts: list[str] = []
    if self.fuse_convert_dp_to_ep_grouped_matmul:
        parts.append("convert_dp_to_ep_grouped_matmul")
    if self.fuse_grouped_matmul_convert_ep_to_dp:
        parts.append("grouped_matmul_convert_ep_to_dp")
    if self.fuse_grouped_matmul_convert_ep_to_dp_wg:
        parts.append("wg_grouped_matmul_convert_ep_to_dp")
    if not parts:
        return "unfused"
    return "fused_" + "__".join(parts)
```
mode_name() produces "fused_wg_grouped_matmul_convert_ep_to_dp" for the WG mode, but from_mode_name() / CLI / tests use "wg_fused_grouped_matmul_convert_ep_to_dp". This breaks round-tripping and can confuse logs/config serialization; align the string naming (either change mode_name() output for the WG flag, or accept both spellings in from_mode_name()).
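One way to accept both spellings is a small normalization step in the parser. A minimal sketch, assuming the spellings quoted in this review; `normalize_mode_name` and `WG_ALIASES` are hypothetical names, not part of the PR:

```python
# Hypothetical sketch: treat both WG-mode spellings as the same mode so that
# mode_name() output and the CLI/test spelling both round-trip.
WG_ALIASES = {
    "fused_wg_grouped_matmul_convert_ep_to_dp",   # mode_name() output
    "wg_fused_grouped_matmul_convert_ep_to_dp",   # CLI / tests spelling
}

def normalize_mode_name(name: str) -> str:
    """Map either WG spelling onto the canonical mode_name() form."""
    if name in WG_ALIASES:
        return "fused_wg_grouped_matmul_convert_ep_to_dp"
    return name
```

A `from_mode_name()` that calls `normalize_mode_name` first would then accept either spelling without changing the serialized form.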
```python
        fusion_config.fuse_grouped_matmul_convert_ep_to_dp_wg,
    ]
)
if n_fusions_active > 1:
```
This new restriction conflicts with the existing “combined” fusion mode name that from_mode_name() still constructs (fused_convert_dp_to_ep_grouped_matmul__grouped_matmul_convert_ep_to_dp) and that the benchmark CLI still offers. Either remove/disable the combined mode in from_mode_name() and CLI choices, or extend the implementation to support the combined mode (and adjust this guard accordingly) so user-selected modes don’t fail at runtime.
Suggested change:

```diff
-if n_fusions_active > 1:
+# Allow the documented "combined" fusion mode
+combined_mode = (
+    fusion_config.fuse_convert_dp_to_ep_grouped_matmul
+    and fusion_config.fuse_grouped_matmul_convert_ep_to_dp
+    and not fusion_config.fuse_grouped_matmul_convert_ep_to_dp_wg
+)
+if n_fusions_active > 1 and not combined_mode:
```
```python
    "fused_grouped_matmul_convert_ep_to_dp",
    "fused_convert_dp_to_ep_grouped_matmul",
    "wg_fused_grouped_matmul_convert_ep_to_dp",
    "fused_convert_dp_to_ep_grouped_matmul__grouped_matmul_convert_ep_to_dp",
```
The CLI still advertises the combined fusion mode (fused_convert_dp_to_ep_grouped_matmul__grouped_matmul_convert_ep_to_dp), but the MoE pipeline now raises if more than one fusion flag is enabled. Remove this choice (and any corresponding parsing) or make it a supported mode; otherwise benchmark runs can error immediately.
Suggested change:

```diff
-"fused_convert_dp_to_ep_grouped_matmul__grouped_matmul_convert_ep_to_dp",
```
```python
num_sms = cu_count
if gemm_sms is None:
    gemm_sms = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1
```
With the current default, gemm_sms equals cu_count when cu_count is a power-of-two, which makes COMM_SMS = NUM_SMS - GEMM_SMS become 0 in the kernel. That results in no communication programs running (and/or invalid loop strides), so outputs won’t be scattered. Clamp gemm_sms to [1, num_sms - 1] (or choose a safer default split, e.g. num_sms // 2) and error clearly if the user passes an invalid value.
Suggested change:

```diff
-num_sms = cu_count
-if gemm_sms is None:
-    gemm_sms = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1
+num_sms = max(int(cu_count), 1)
+# Split SMs between GEMM and communication paths.
+# Require at least 2 SMs so that we can run both roles.
+if num_sms < 2:
+    raise RuntimeError(
+        f"_wg_fused_exp_matmul_ep_to_dp_kernel requires at least 2 SMs, "
+        f"but found {num_sms} on device {device}."
+    )
+if gemm_sms is None:
+    # Use a safer default than the next power-of-two, which could equal num_sms.
+    # Start from an even split and clamp to [1, num_sms - 1].
+    default_gemm_sms = max(1, num_sms // 2)
+    gemm_sms = min(default_gemm_sms, num_sms - 1)
+else:
+    # Validate user-provided gemm_sms.
+    if not isinstance(gemm_sms, int):
+        raise TypeError(
+            f"gemm_sms must be an int in [1, num_sms - 1], got {type(gemm_sms)}."
+        )
+    if gemm_sms < 1 or gemm_sms >= num_sms:
+        raise ValueError(
+            f"gemm_sms must be in [1, num_sms - 1] (num_sms={num_sms}), "
+            f"but got gemm_sms={gemm_sms}."
+        )
```
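The pitfall is easy to reproduce in isolation: for a power-of-two CU count, `2 ** int(math.log2(cu_count))` returns `cu_count` itself, so `COMM_SMS = NUM_SMS - GEMM_SMS` collapses to 0. A standalone check, independent of the kernel (the CU counts below are illustrative):

```python
import math

def default_gemm_sms(cu_count: int) -> int:
    # The default from the diff above: largest power of two <= cu_count.
    return 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1

# On a power-of-two device (e.g. 64 CUs) the default consumes every SM,
# leaving no programs for the communication role.
for cu_count in (64, 80, 104):
    comm_sms = cu_count - default_gemm_sms(cu_count)
    print(cu_count, default_gemm_sms(cu_count), comm_sms)
```

For `cu_count=64` this prints `64 64 0`, i.e. zero communication SMs, which is why clamping to `[1, num_sms - 1]` (or an even split) is needed.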
```python
tl.debug_barrier()
lock_idx = en_pair * max_m_tiles + pid_m
tl.store(locks_ptr + lock_idx, 1, cache_modifier=".wt")
```
The kernel uses a plain global tl.store to publish the lock, while other programs spin on tl.load(... volatile=True); this does not guarantee a proper release/acquire ordering between the y_buf write and lock visibility across CUs. Consider using an atomic publication for the lock (e.g., tl.atomic_xchg(..., sem="release")) and an acquire read on the consumer side (e.g., atomic read via tl.atomic_add(ptr, 0, sem="acquire")) to avoid reading stale y_buf data.
Suggested change:

```diff
-tl.store(locks_ptr + lock_idx, 1, cache_modifier=".wt")
+tl.atomic_xchg(locks_ptr + lock_idx, 1, sem="release")
```
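The publish/consume protocol being suggested can be sketched as a CPU-side analogue in plain Python: the producer writes the payload first and only then publishes the flag with release semantics, and the consumer acquires the flag before reading. This is a minimal illustration of the ordering contract (here `threading.Event` supplies the release/acquire pairing that `sem="release"`/`sem="acquire"` supply on the GPU); it does not use Triton APIs, and the class name is hypothetical:

```python
import threading

class TilePublisher:
    """CPU analogue of the lock protocol: write the tile, then publish a
    flag with release semantics; consumers acquire before reading the tile."""

    def __init__(self) -> None:
        self._data: list[int] | None = None
        self._ready = threading.Event()  # set() releases, wait() acquires

    def produce(self, tile: list[int]) -> None:
        self._data = tile   # analogous to the y_buf store
        self._ready.set()   # analogous to tl.atomic_xchg(..., sem="release")

    def consume(self) -> list[int]:
        self._ready.wait()  # analogous to an acquire read spinning on the lock
        assert self._data is not None
        return self._data

pub = TilePublisher()
producer = threading.Thread(target=pub.produce, args=([1, 2, 3],))
producer.start()
result = pub.consume()
producer.join()
print(result)  # [1, 2, 3]
```

With a plain (non-atomic, non-fenced) store in place of the release publication, a consumer could observe the flag before the payload write becomes visible, which is exactly the stale-`y_buf` hazard described above.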
```python
def wg_fused_exp_matmul_ep_to_dp(
    x_ep_local: torch.Tensor,
    w_ep_local: torch.Tensor,
    b_ep_local: torch.Tensor | None,
    expt_assignment,
    expt_map_local: torch.Tensor,
    expt_indx_flat: torch.Tensor,
    combine_indx: torch.Tensor,
    shmem,
    ragged_metadata: RaggedTensorMetadata | None = None,
    gemm_sms: int | None = None,
) -> torch.Tensor:
```
expt_map_local is part of the public signature but is not used anywhere in this new module (the kernel relies on expt_assignment.expt_bitmask instead). If it’s intentionally unused for signature compatibility, rename it to _expt_map_local (or add a brief comment) to make that explicit; otherwise, either use it or remove it from the signature.
```python
    from_rank (int): The rank ID from which to read the data.
    heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
    mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None.
    hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint).
```
Since hint is implemented as a tl.constexpr, callers can’t pass an arbitrary runtime value from within a Triton kernel—only compile-time constants are valid. Updating the docstrings for hint (here and the similar ones on store/copy/get/put/atomic_*) to explicitly state it must be compile-time constant would prevent misuse and hard-to-debug compilation errors.
Suggested change:

```diff
-    hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint).
+    hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). When called from within a Triton kernel, this argument must be a compile-time constant (`tl.constexpr`). Defaults to None (no hint).
```
|
@copilot for GPT-OSS, sweep the block sizes, comm-SMs (for WG specialization), etc. to find the optimal setup for larger bpe; I'd like to understand the optimal tiling and configs.
Motivation
... WIP, don't merge.
Submission Checklist