cuDNN: add fused scaled-dot-product (flash) attention forward #3174
CarloLucibello wants to merge 4 commits into
36ba1ff to 65ed835
Wrap cuDNN's modern fused SDPA via the backend graph API. The legacy `cudnnMultiHeadAttnForward` is the deprecated attention path; this adds the flash-attention kernel NVIDIA now recommends.

- backend.jl: a thin typed layer over the cuDNN backend graph API (`cudnnBackend*`): descriptor wrapper, setattr!/getattr, and helpers to build an operation graph, run engine heuristics, finalize an execution plan, and execute a variant pack. No prior high-level wrapper used this API.
- sdpa.jl: `cudnnSDPAForward[!]` driving the dedicated `CUDNN_BACKEND_OPERATION_SDPA_FWD_DESCRIPTOR`, with a per-shape execution-plan cache.

Inputs are 4-D (head_dim, nheads, seq_len, batch), which is NNlib's attention layout, so no permute is needed to interoperate.

Scope: forward inference only, Float16/BFloat16 (cuDNN's fused engine does not support Float32/Float64). Verified vs a Float32 reference (relerr ~5e-4).

Not yet supported, both blocked on cuDNN <= 9.20 and documented inline:

- causal masking: the SDPA score-modifier subgraph needs cuDNN >= 9.21 (no CUDNN_jll yet); block-mask is block-sparse; the primitive matmul->softmax->matmul graph yields no fused engine from raw backend calls.
- backward: `cudnnSDPABackward` is a documented placeholder; the dedicated SDPA_BWD descriptor does not finalize on 9.20 and the supported path needs the same 9.21 subgraph mechanism (forward stats output already verified to work).

Tests cover the forward against a dense reference (Float16, several shapes, custom scale) plus in-place agreement; gated on compute capability >= 8.0.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
- Constrain cudnnSDPAForward[!] inputs to DenseCuArray{T,4}: the cached plan
bakes in dense column-major strides, so non-contiguous views (or host Arrays,
whose Ptr/CuPtr distinction the variant pack erases) previously passed all
size checks and produced silently wrong results or illegal accesses.
- Key the plan cache on the current context: execution plans are finalized
against a specific device's handle, so a plan built on one GPU must not be
executed on another.
- Only swallow the CUDNN_STATUS_NOT_SUPPORTED family (3000s) in
try_execution_plan; BAD_PARAM/INTERNAL_ERROR now propagate instead of being
misreported as "no supported engine". The terminal error also distinguishes
"heuristic returned no configs" from "N configs failed to finalize".
- Build plans outside the cache lock (matching the descriptors.jl pattern) so
concurrent calls don't serialize behind a multi-millisecond plan build; a
racing duplicate build is benign and resolved by get! on insert.
- Use with_workspace for the execute workspace, freeing it eagerly instead of
leaving it to GC (house style, cf. convolution.jl/reduce.jl).
- Test that views, host arrays, and Float32 inputs are rejected.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
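
The cache behaviour described in this commit (key includes the context, plans are built outside the lock, a racing duplicate is resolved by `get!`) can be sketched in plain Julia. This is only an illustration of the pattern under stated assumptions: `Plan`, `build_plan`, `cached_plan`, `PLAN_CACHE`, and the key layout are hypothetical names, not the ones used in sdpa.jl, and a symbol stands in for the CUDA context.

```julia
# Hypothetical stand-ins for illustration only; not the names used in sdpa.jl.
struct Plan
    key::Tuple
end

const PLAN_CACHE = Dict{Tuple,Plan}()
const PLAN_LOCK = ReentrantLock()
const BUILD_COUNT = Threads.Atomic{Int}(0)

# Stand-in for the expensive "run heuristics + finalize execution plan" step.
function build_plan(key::Tuple)
    Threads.atomic_add!(BUILD_COUNT, 1)
    return Plan(key)
end

function cached_plan(ctx, T::Type, dims::NTuple{5,Int})
    # The context is part of the key, so a plan built for one device
    # is never handed out for another.
    key = (ctx, T, dims...)

    # Fast path: the lock is held only for the dictionary lookup.
    plan = lock(PLAN_LOCK) do
        get(PLAN_CACHE, key, nothing)
    end
    plan === nothing || return plan

    # Slow path: build without holding the lock, so callers with other
    # shapes are not serialized behind this build.
    fresh = build_plan(key)

    # If another task inserted the same key in the meantime, get! returns
    # that entry and the duplicate build is simply dropped.
    return lock(PLAN_LOCK) do
        get!(PLAN_CACHE, key, fresh)
    end
end
```

Calling `cached_plan(:ctx0, Float16, (64, 128, 128, 8, 2))` twice builds once; the same shape under `:ctx1` triggers a second build because the key differs.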
Fix a GC-finalizer/cuDNN-JIT deadlock that caused BFloat16 fused SDPA to hang indefinitely on runtime-compiled engines (e.g. Blackwell sm_120):

- getattr_descriptors: use raw handles internally, destroy unused slots synchronously before any long operation runs, register finalizers only for the n returned descriptors. Previously, up to 12 unreachable wrapper objects with cudnnBackendDestroyDescriptor finalizers were left pending; since all backend calls use @gcsafe_ccall, the GC could fire those finalizers during a JIT-compiling cudnnBackendExecute, contending for cuDNN's internal state lock -> deadlock.
- try_execution_plan: destroy the descriptor immediately on failure instead of leaving a pending finalizer.
- _build_sdpa_plan: keep ALL engine configs alive in SDPAPlan.keepalive (not just the first successful one), preventing the unused cfgs from becoming unreachable between plan build and execute.

Also enable BFloat16 in the SDPA tests: switch to CUDA.randn for test data (GPU-side generation avoids the CPU BFloat16.(randn(Float32,...)) broadcast which triggers a slow first-time LLVM compilation on some machines), and add BFloat16 alongside Float16 in the test loop.

Tested on RTX 5090 (sm_120), cuDNN 9.20: Float16 and BFloat16 both pass all 5 shape variants × 4 assertions = 40 tests, plus the 3 invalid-input rejection tests.

Remove scratch investigation files (sdpa_bf16_findings.md, sdpa_bf16_hang.jl, torch_sdpa_test.py).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
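
The handle-management change in `getattr_descriptors` can be illustrated with a small mock. This is a sketch of the idea only, not the code in backend.jl: integers stand in for raw descriptor handles, `destroy_raw` is a hypothetical stand-in for `cudnnBackendDestroyDescriptor`, and `Descriptor` and `take_descriptors` are made-up names.

```julia
# Hypothetical mock: integers play the role of raw descriptor handles and
# `destroy_raw` plays the role of cudnnBackendDestroyDescriptor.
const DESTROYED = Int[]
destroy_raw(h::Int) = (push!(DESTROYED, h); nothing)

mutable struct Descriptor
    handle::Int
    function Descriptor(h::Int)
        d = new(h)
        # Only handles that are actually returned to the caller get a finalizer.
        finalizer(x -> destroy_raw(x.handle), d)
        return d
    end
end

# `raw` holds all pre-allocated slots; only the first `n` were filled.
# The unused slots are destroyed right away, so no finalizer for them can
# fire later while a long-running call is in flight.
function take_descriptors(raw::Vector{Int}, n::Int)
    for h in raw[n+1:end]
        destroy_raw(h)   # synchronous cleanup, nothing left pending for the GC
    end
    return [Descriptor(h) for h in raw[1:n]]
end
```

With 12 slots and `n = 3`, handles 4 through 12 are destroyed before the function returns, and only the three returned wrappers carry finalizers.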
65ed835 to b75073c
runners are executing tests with missing cuDNN runtime
Codecov Report

❌ Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #3174      +/-   ##
==========================================
+ Coverage   16.88%   17.16%   +0.27%
==========================================
  Files         124      127       +3
  Lines        9886    10063     +177
==========================================
+ Hits         1669     1727      +58
- Misses       8217     8336     +119
That's because you're only running the Julia tests, not the CUDA ones. Those are skipped for draft PRs, and your PR was marked as such at the time of the last commit.
CUDA.jl Benchmarks
| Benchmark suite | Current: 815104f | Previous: f36a98c | Ratio |
|---|---|---|---|
| array/accumulate/Float32/1d | 99698 ns | 99535 ns | 1.00 |
| array/accumulate/Float32/dims=1 | 75529 ns | 75442 ns | 1.00 |
| array/accumulate/Float32/dims=1L | 1599802 ns | 1600645 ns | 1.00 |
| array/accumulate/Float32/dims=2 | 141102 ns | 141158 ns | 1.00 |
| array/accumulate/Float32/dims=2L | 661670 ns | 661466 ns | 1.00 |
| array/accumulate/Int64/1d | 119131 ns | 118706 ns | 1.00 |
| array/accumulate/Int64/dims=1 | 79739 ns | 79054 ns | 1.01 |
| array/accumulate/Int64/dims=1L | 1717805 ns | 1717638 ns | 1.00 |
| array/accumulate/Int64/dims=2 | 152843 ns | 152541 ns | 1.00 |
| array/accumulate/Int64/dims=2L | 988478 ns | 988135 ns | 1.00 |
| array/broadcast | 18477 ns | 18241 ns | 1.01 |
| array/construct | 1187.6 ns | 1181 ns | 1.01 |
| array/copy | 17049 ns | 16815 ns | 1.01 |
| array/copyto!/cpu_to_gpu | 208690 ns | 206789 ns | 1.01 |
| array/copyto!/gpu_to_cpu | 242106 ns | 240944 ns | 1.00 |
| array/copyto!/gpu_to_gpu | 10466 ns | 10378 ns | 1.01 |
| array/iteration/findall/bool | 135061 ns | 133949 ns | 1.01 |
| array/iteration/findall/int | 148747 ns | 146873 ns | 1.01 |
| array/iteration/findfirst/bool | 70244 ns | 69740 ns | 1.01 |
| array/iteration/findfirst/int | 71877 ns | 70808 ns | 1.02 |
| array/iteration/findmin/1d | 67797 ns | 63267 ns | 1.07 |
| array/iteration/findmin/2d | 101331 ns | 100683 ns | 1.01 |
| array/iteration/logical | 194581 ns | 191777 ns | 1.01 |
| array/iteration/scalar | 65098 ns | 65423 ns | 1.00 |
| array/permutedims/2d | 49617 ns | 49157 ns | 1.01 |
| array/permutedims/3d | 50794 ns | 50803 ns | 1.00 |
| array/permutedims/4d | 50483 ns | 50427 ns | 1.00 |
| array/random/rand/Float32 | 12227 ns | 11313 ns | 1.08 |
| array/random/rand/Int64 | 24230 ns | 21724 ns | 1.12 |
| array/random/rand!/Float32 | 9531.666666666666 ns | 7865.333333333333 ns | 1.21 |
| array/random/rand!/Int64 | 20843 ns | 18046 ns | 1.15 |
| array/random/randn/Float32 | 35168 ns | 34325 ns | 1.02 |
| array/random/randn!/Float32 | 26287 ns | 23663 ns | 1.11 |
| array/reductions/mapreduce/Float32/1d | 33712 ns | 33244 ns | 1.01 |
| array/reductions/mapreduce/Float32/dims=1 | 38289 ns | 37924 ns | 1.01 |
| array/reductions/mapreduce/Float32/dims=1L | 51186 ns | 50709 ns | 1.01 |
| array/reductions/mapreduce/Float32/dims=2 | 55537 ns | 55197 ns | 1.01 |
| array/reductions/mapreduce/Float32/dims=2L | 67665 ns | 67275 ns | 1.01 |
| array/reductions/mapreduce/Int64/1d | 40373 ns | 39856 ns | 1.01 |
| array/reductions/mapreduce/Int64/dims=1 | 41047 ns | 41151 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=1L | 88833 ns | 88798 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=2 | 57919 ns | 57992 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=2L | 83854 ns | 83511 ns | 1.00 |
| array/reductions/reduce/Float32/1d | 33711 ns | 33205 ns | 1.02 |
| array/reductions/reduce/Float32/dims=1 | 38404 ns | 38128 ns | 1.01 |
| array/reductions/reduce/Float32/dims=1L | 50972 ns | 50734 ns | 1.00 |
| array/reductions/reduce/Float32/dims=2 | 55602 ns | 55518 ns | 1.00 |
| array/reductions/reduce/Float32/dims=2L | 69417 ns | 68506 ns | 1.01 |
| array/reductions/reduce/Int64/1d | 40371 ns | 39264 ns | 1.03 |
| array/reductions/reduce/Int64/dims=1 | 40737 ns | 40684 ns | 1.00 |
| array/reductions/reduce/Int64/dims=1L | 88802 ns | 88810 ns | 1.00 |
| array/reductions/reduce/Int64/dims=2 | 57747 ns | 57487 ns | 1.00 |
| array/reductions/reduce/Int64/dims=2L | 83802 ns | 83461 ns | 1.00 |
| array/reverse/1d | 16726 ns | 16785 ns | 1.00 |
| array/reverse/1dL | 69546 ns | 69518 ns | 1.00 |
| array/reverse/1dL_inplace | 67224 ns | 67215 ns | 1.00 |
| array/reverse/1d_inplace | 8568.666666666666 ns | 8316.666666666666 ns | 1.03 |
| array/reverse/2d | 20177 ns | 19784 ns | 1.02 |
| array/reverse/2dL | 73646 ns | 73042 ns | 1.01 |
| array/reverse/2dL_inplace | 67227 ns | 66906 ns | 1.00 |
| array/reverse/2d_inplace | 9732 ns | 9756 ns | 1.00 |
| array/sorting/1d | 2659829 ns | 2659326 ns | 1.00 |
| array/sorting/2d | 1040382 ns | 1040275 ns | 1.00 |
| array/sorting/by | 3195420 ns | 3194653 ns | 1.00 |
| cuda/synchronization/context/auto | 1143.1 ns | 1126 ns | 1.02 |
| cuda/synchronization/context/blocking | 925.4857142857143 ns | 913.1818181818181 ns | 1.01 |
| cuda/synchronization/context/nonblocking | 6019.666666666667 ns | 6025.8 ns | 1.00 |
| cuda/synchronization/stream/auto | 1024.8 ns | 1012.7 ns | 1.01 |
| cuda/synchronization/stream/blocking | 838.0410958904109 ns | 812.4494382022472 ns | 1.03 |
| cuda/synchronization/stream/nonblocking | 5878.8 ns | 5808.4 ns | 1.01 |
| integration/byval/reference | 147712 ns | 147595 ns | 1.00 |
| integration/byval/slices=1 | 149922 ns | 149686 ns | 1.00 |
| integration/byval/slices=2 | 292820 ns | 292616 ns | 1.00 |
| integration/byval/slices=3 | 435752 ns | 435427 ns | 1.00 |
| integration/cudadevrt | 104699 ns | 104547 ns | 1.00 |
| integration/volumerhs | 9302519 ns | 9302299 ns | 1.00 |
| kernel/indexing | 12722 ns | 12659 ns | 1.00 |
| kernel/indexing_checked | 13525 ns | 13481 ns | 1.00 |
| kernel/launch | 2075.777777777778 ns | 2098.6666666666665 ns | 0.99 |
| kernel/occupancy | 696.472602739726 ns | 728.3309352517986 ns | 0.96 |
| kernel/rand | 13853 ns | 16517 ns | 0.84 |
| latency/import | 3925156396 ns | 3918708441 ns | 1.00 |
| latency/precompile | 4669975980 ns | 4679357936 ns | 1.00 |
| latency/ttfp | 5375645004 ns | 4575800023 ns | 1.17 |

This comment was automatically generated by workflow using github-action-benchmark.
    # maxcount-n slots: if a finalizer fires cudnnBackendDestroyDescriptor while another thread
    # is inside a cudnnBackendExecute that holds cuDNN's JIT lock, we get a deadlock on bf16
    # and other runtime-compiled engines. Instead we use raw handles, destroy the unused ones
    # synchronously right here, and only register finalizers for the n handles we return.
@maleadt is this manual memory management ok?
this is ready for review/merge
Adds a fused scaled-dot-product (flash) attention forward pass to the cuDNN extension, built on cuDNN's modern backend graph API.
What this adds
- `lib/cudnn/src/backend.jl`: a thin Julia wrapper over the cuDNN backend graph API (`cudnnBackend*`): typed descriptor wrapper, `setattr!`/`getattr`, `bfinalize!`, and helpers for tensors, operation graphs, engine heuristics, execution plans, and variant packs.
- `lib/cudnn/src/sdpa.jl`: `cudnnSDPAForward[!]` using `CUDNN_BACKEND_OPERATION_SDPA_FWD_DESCRIPTOR` (the dedicated single-node fused attention op).

Input/output layout is `(head_dim, num_heads, seq_len, batch)`, which matches NNlib's internal multi-head attention layout directly, so no permute is needed. Plans are cached per `(device, dtype, d, sq, skv, h, b)`.

Supported / not yet supported

Float32/Float64 are rejected with a clear error (not supported by the fused engine).
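
For readers unfamiliar with the `(head_dim, num_heads, seq_len, batch)` layout, the following CPU-only sketch shows what a dense Float32 reference for this operation can look like. It is an illustration, not the PR's test code: `dense_sdpa` and `relerr` are hypothetical names, and the default scale of `1/sqrt(head_dim)` is the conventional choice rather than something stated in the PR.

```julia
using LinearAlgebra

# Dense scaled-dot-product attention on arrays laid out as
# (head_dim, num_heads, seq_len, batch). Plain CPU Julia, accumulating in Float32.
# `dense_sdpa` is a hypothetical helper name used only for this sketch.
function dense_sdpa(q::AbstractArray{T,4}, k::AbstractArray{T,4}, v::AbstractArray{T,4};
                    scale::Real = inv(sqrt(size(q, 1)))) where {T}
    d, h, sq, b = size(q)
    skv = size(k, 3)
    @assert size(k) == (d, h, skv, b) && size(v) == (d, h, skv, b)
    out = zeros(Float32, d, h, sq, b)
    for n in 1:b, j in 1:h
        Q = Float32.(q[:, j, :, n])        # d × sq
        K = Float32.(k[:, j, :, n])        # d × skv
        V = Float32.(v[:, j, :, n])        # d × skv
        S = (K' * Q) .* Float32(scale)     # skv × sq, one column per query
        S .-= maximum(S; dims = 1)         # subtract the column max for a stable softmax
        P = exp.(S)
        P ./= sum(P; dims = 1)             # softmax over the key dimension
        out[:, j, :, n] = V * P            # d × sq
    end
    return out
end

# Relative L2 error between a result `a` and a reference `b`.
relerr(a, b) = sqrt(sum(abs2, a .- b) / sum(abs2, b))
```

A fused Float16 or BFloat16 result copied back to the host could then be compared against this reference with `relerr(Float32.(Array(out_gpu)), dense_sdpa(q_cpu, k_cpu, v_cpu))`, where `out_gpu`, `q_cpu`, `k_cpu`, and `v_cpu` are placeholders for the caller's arrays.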
BFloat16 fix
An early version had a hang on BFloat16 caused by GC finalizers (`cudnnBackendDestroyDescriptor`) firing inside a `@gcsafe_ccall` to `cudnnBackendExecute` during cuDNN's JIT compilation of runtime-compiled engines, contending for cuDNN's internal state lock. Fixed by destroying unused descriptors synchronously in `getattr_descriptors` and `try_execution_plan` rather than leaving GC-pending finalizers. Both Float16 and BFloat16 are now tested.

Tested on

RTX 5090 (sm_120), cuDNN 9.20, Julia 1.12.