[CK][CK_TILE] Fix FMHA codegen group mode dispatcher for fwd & bwd #3730

Open
arthurliu1998 wants to merge 2 commits into develop from users/ArthurLiu/ck_fmha_codegen
@arthurliu1998
Motivation

FMHA codegen had incorrect dispatch behavior in group mode. Two root causes:

  1. Wrong field names in dispatch conditions: the generated conditions used batch-mode fields (seqlen_q, seqlen_k) instead of group-mode fields (max_seqlen_q, max_seqlen_k), causing wrong kernel selection at runtime on gfx950.
  2. Missing kernel variants: group mode was filtered out too aggressively from smaller-tile specializations (bwd) and lacked spatial-padding pipeline variants on gfx950 (fwd).

gfx942 does not support the trload pipeline.

Technical Details

fmha_bwd.py:

  • max_seq_q_cond and extra_cond now emit t.max_seqlen_q / t.max_seqlen_k for group mode.
  • Relaxed kernel filtering: group mode no longer skips tiles with max_seq_q != 0.
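
The field-name change above can be illustrated with a minimal sketch. The helper `seqlen_field`, the `<=` comparison, and the reading of `max_seq_q == 0` as "no upper bound" are assumptions made for illustration and are not copied from fmha_bwd.py; only the field names and the `t` variable come from the description.

```python
def seqlen_field(mode: str, axis: str, var: str = "t") -> str:
    """Return the C++ member expression to test for a given mode.

    Hypothetical helper: batch mode reads seqlen_q / seqlen_k, while
    group mode reads max_seqlen_q / max_seqlen_k.
    """
    if mode not in ("batch", "group"):
        raise ValueError(f"unknown mode: {mode!r}")
    if axis not in ("q", "k"):
        raise ValueError(f"unknown axis: {axis!r}")
    prefix = "max_" if mode == "group" else ""
    return f"{var}.{prefix}seqlen_{axis}"


def max_seq_q_cond(mode: str, max_seq_q: int) -> str:
    """Emit a C++ condition string for a tile's sequence-length bound.

    Assumption: max_seq_q == 0 means the tile has no upper bound, and a
    non-zero value is checked with `<=`.
    """
    if max_seq_q == 0:
        return "true"
    return f"{seqlen_field(mode, 'q')} <= {max_seq_q}"


print(max_seq_q_cond("batch", 64))
print(max_seq_q_cond("group", 64))
```

With this shape, a group-mode tile that has a non-zero `max_seq_q` still produces a valid condition, so it no longer needs to be filtered out.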

fmha_fwd.py:

  • get_bm0_cond emits a.max_seqlen_q for group mode tile-size dispatch.
  • Added two qr_async_trload pipeline variants with spatial padding for gfx950 group mode.
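
As a rough sketch of the tile-size dispatch described above, the snippet below emits an if / else-if chain keyed on the mode-appropriate field. The function names `bm0_cond` and `emit_tile_dispatch`, the `<=` comparison, and the tile sizes are hypothetical and only illustrate the idea; the actual get_bm0_cond in fmha_fwd.py may differ.

```python
from typing import Iterable


def bm0_cond(mode: str, bm0: int, var: str = "a") -> str:
    """Return a C++ condition selecting a tile of height bm0.

    Hypothetical: group mode compares max_seqlen_q, batch mode seqlen_q.
    """
    field = "max_seqlen_q" if mode == "group" else "seqlen_q"
    return f"{var}.{field} <= {bm0}"


def emit_tile_dispatch(mode: str, bm0_sizes: Iterable[int]) -> str:
    """Emit an if / else-if chain that tries the smallest tile first."""
    lines = []
    for i, bm0 in enumerate(sorted(bm0_sizes)):
        keyword = "if" if i == 0 else "else if"
        lines.append(
            f"{keyword} ({bm0_cond(mode, bm0)}) {{ /* launch bm0={bm0} kernel */ }}"
        )
    return "\n".join(lines)


# Illustrative tile sizes only; not taken from the PR.
print(emit_tile_dispatch("group", [128, 64]))
```

Emitting `a.seqlen_q` in group mode would make every branch test the wrong field, which matches the wrong-kernel-selection symptom described under Motivation.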

Test Plan

Triggering AITER CI job:

Test Result

Submission Checklist

  • Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
