[release/3.4][Operator Mechanism] Support mm out_dtype for BF16 CUDA#79282
[release/3.4][Operator Mechanism] Support mm out_dtype for BF16 CUDA#79282risemeup1111 wants to merge 1 commit into
Conversation
…#79252) * [Operator Mechanism] Support mm out_dtype for BF16 CUDA Add a narrow CUDA BF16 x BF16 -> FP32 path for paddle.mm(out_dtype=paddle.float32), including schema, infermeta, stride dispatch, fused cuBLAS GEMM, and focused tests. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * [Operator Mechanism] Fix mm out_dtype review issues Use the canonical matmul path for static mm out_dtype handling and keep legacy compatibility attrs limited to supported types. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * [Operator Mechanism] Fix matmul out_dtype CI regressions Keep matmul compatible with unknown symbolic dimensions and legacy matmul_v2 to PIR translation when out_dtype is unset. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * [Operator Mechanism] Harden matmul out_dtype static path Preserve the legacy static mm path when out_dtype is unset and avoid rejecting unknown symbolic matmul dimensions during InferMeta. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * [Operator Mechanism] Fix mm out_dtype static BF16 test Allow the explicit static out_dtype path to pass BF16 variables through Python validation and feed BF16 static test data using the existing uint16 encoding helper. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * [Operator Mechanism] Fix matmul out_dtype PIR compat Add missing default/propagated out_dtype handling for legacy matmul translation, PIR serialization compatibility, and handwritten PIR/DRR matmul rewrites. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * [Operator Mechanism] Harden matmul out_dtype PIR fusions Avoid fusing explicit matmul out_dtype paths in PIR rewrite passes, document BF16 GEMM lda/ldb narrowing safety, and add a legacy matmul_v2 translator regression. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> * [Operator Mechanism] Fix matmul out_dtype static compat Route static mm out_dtype through matmul_v2 so it reaches the phi matmul kernel, preserve user-provided out tensors, and let legacy matmul_v2 fusion pass compatibility accept only missing/default out_dtype while rejecting explicit output dtype paths. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> * [Operator Mechanism] Prune mm out_dtype to dynamic path Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> * [Operator Mechanism] Skip mm out_dtype tests on ROCm Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
PaddlePaddle-bot
left a comment
There was a problem hiding this comment.
🤖 Paddle-CI-Agent | pr_review |
2026-06-09 15:10:11
📋 Review 摘要
PR 概述:为 paddle.mm 增加 CUDA BF16 输入、FP32 输出的 out_dtype 动态图路径。
变更范围:Phi InferMeta/YAML、GPU matmul kernel、cuBLAS BF16 GEMM 封装、Python paddle.mm API 和单测。
影响面 Tag:Operator Mechanism User Experience
问题
| 级别 | 文件 | 概述 |
|---|---|---|
| 🔴 Bug | paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml:296 |
新 mm_out_dtype 公开路径没有 backward,会破坏 paddle.mm 在训练动态图中的梯度语义 |
📝 PR 规范检查
符合规范:标题包含算子机制 Tag,描述保留必填 section,并且 release cherry-pick 描述包含 devPR 链接。
总体评价
前向 CUDA BF16->FP32 路径、shape/dtype 限制和非连续输入处理都有对应实现和基础测试。但该分支接入了公开 paddle.mm,需要先补齐反向或显式 fail closed,否则用户在训练中使用新参数会得到不可用的动态图。
| kernel : | ||
| func : mm_out_dtype | ||
| data_type : x | ||
| traits : paddle::dialect::ForwardOnlyTrait |
There was a problem hiding this comment.
🔴 Bug mm_out_dtype 没有注册反向,却被 paddle.mm(..., out_dtype=...) 作为公开动态图路径直接调用。
同文件的普通 matmul 绑定了 matmul_grad,现有 test_mm_out.py 也验证了 paddle.mm(..., out=...) 可以反向;但这里仅声明 ForwardOnlyTrait,Python 侧没有在 stop_gradient=False 时拒绝,所以训练代码使用 BF16 mm(out_dtype=float32) 会在 backward 阶段丢梯度或报缺少 grad op。
建议修复方式:为该 op 补 backward : mm_out_dtype_grad(或复用/封装 matmul_grad,处理 out_grad 为 FP32、输入为 BF16 的梯度 dtype/cast 语义),并增加 stop_gradient=False 的 CUDA BF16 反向测试;如果只能支持推理,则在 Python 分支显式拒绝需要梯度的输入并在文档中标注 forward-only。
|
/re-run all-failed |
|
CI triage 结论:不建议对 #79282 继续 rerun。
该自动 cherry-pick 已由 #79285 手动替代并合入;当前 #79282 还与 |
PR Category
Operator Mechanism
PR Types
New features
Description
Temporary add a narrow CUDA BF16 x BF16 -> FP32 path for paddle.mm(out_dtype=paddle.float32), including schema, infermeta, stride dispatch, fused cuBLAS GEMM, and focused tests.
pcard-91067
是否引起精度变化
否
Cherry-pick of #79252 (authored by @A-nnonymous) to
release/3.4.devPR:#79252