cuDNN: add fused scaled-dot-product (flash) attention forward #3174

Open
CarloLucibello wants to merge 4 commits into JuliaGPU:main from CarloLucibello:cl/cudnn-sdpa
@CarloLucibello CarloLucibello commented Jun 11, 2026

Contributor

Adds a fused scaled-dot-product (flash) attention forward pass to the cuDNN extension, built on cuDNN's modern backend graph API.

What this adds

lib/cudnn/src/backend.jl — a thin Julia wrapper over the cuDNN backend graph API (cudnnBackend*): typed descriptor wrapper, setattr!/getattr, bfinalize!, and helpers for tensors, operation graphs, engine heuristics, execution plans, and variant packs.

lib/cudnn/src/sdpa.jl - cudnnSDPAForward[!] using CUDNN_BACKEND_OPERATION_SDPA_FWD_DESCRIPTOR (the dedicated single-node fused attention op):

cudnnSDPAForward(q, k, v; scale=1/√d)   # → out
cudnnSDPAForward!(out, q, k, v; scale)

Input/output layout is (head_dim, num_heads, seq_len, batch) — matches NNlib's internal multi-head attention layout directly, no permute needed. Plans are cached per (device, dtype, d, sq, skv, h, b).
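To make the layout concrete, here is a minimal dense CPU sketch of what the fused op is expected to compute for each (head, batch) pair. It is not code from this PR: `reference_sdpa` is a hypothetical name, it assumes `k` and `v` share the head dimension of `q`, and it applies no masking.

```julia
# Dense CPU reference for scaled-dot-product attention (illustrative only).
# Layout assumption, taken from the PR text: (head_dim, num_heads, seq_len, batch).
# `reference_sdpa` is a hypothetical helper, not part of the cuDNN wrapper.
function reference_sdpa(q::AbstractArray{T,4}, k::AbstractArray{T,4}, v::AbstractArray{T,4};
                        scale::Real = 1 / sqrt(size(q, 1))) where {T}
    d, h, sq, b = size(q)
    skv = size(k, 3)
    @assert size(k) == (d, h, skv, b)
    @assert size(v) == (d, h, skv, b)     # assumption: v uses the same head_dim as q
    s = T(scale)
    out = similar(q, d, h, sq, b)
    for n in 1:b, j in 1:h
        Q = @view q[:, j, :, n]           # d × sq
        K = @view k[:, j, :, n]           # d × skv
        V = @view v[:, j, :, n]           # d × skv
        S = (K' * Q) .* s                 # skv × sq, one column of scores per query
        S = S .- maximum(S; dims = 1)     # subtract the column max for numerical stability
        P = exp.(S)
        P = P ./ sum(P; dims = 1)         # softmax over the key dimension
        out[:, j, :, n] = V * P           # d × sq
    end
    return out
end

d, h, sq, skv, b = 8, 2, 5, 7, 3
q = randn(Float32, d, h, sq, b)
k = randn(Float32, d, h, skv, b)
v = randn(Float32, d, h, skv, b)
out = reference_sdpa(q, k, v)             # same shape as q
```

Because each head and batch slice is already a contiguous `head_dim × seq_len` matrix in this layout, the loop body needs no permutation, which is the property the PR relies on.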

Supported / not yet supported

Feature | Status
Float16, BFloat16 | supported
Ampere (sm_80) or newer | supported
Forward inference | supported
Causal masking | blocked on cuDNN ≤ 9.20 (requires score-modifier subgraph, cuDNN ≥ 9.21)
Backward | blocked on cuDNN ≤ 9.20 (same limitation)

Float32/Float64 are rejected with a clear error (not supported by the fused engine).

BFloat16 fix

An early version hung on BFloat16. GC finalizers (cudnnBackendDestroyDescriptor) could fire inside a @gcsafe_ccall to cudnnBackendExecute while cuDNN was JIT-compiling a runtime-compiled engine, and then contend for cuDNN's internal state lock. The fix destroys unused descriptors synchronously in getattr_descriptors and try_execution_plan instead of leaving GC-pending finalizers. Both Float16 and BFloat16 are now tested.
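The pattern behind the fix can be sketched with toy handles. Nothing below calls cuDNN: `raw_create`, `raw_destroy`, `Descriptor`, and `fetch_descriptors` are hypothetical stand-ins, and the count of 12 slots is only borrowed from the commit message for illustration.

```julia
# Toy stand-ins for backend descriptors; nothing here touches cuDNN.
const DESTROYED = Int[]                        # records which raw handles were freed
raw_create(id::Int) = id                       # hypothetical: allocate a raw handle
raw_destroy(id::Int) = (push!(DESTROYED, id); nothing)

mutable struct Descriptor
    handle::Int
    function Descriptor(handle::Int)
        d = new(handle)
        # Only descriptors that are handed back to the caller get a finalizer.
        finalizer(x -> raw_destroy(x.handle), d)
        return d
    end
end

# Allocate `maxcount` slots, pretend the library filled the first `n`,
# and free the remaining slots immediately instead of leaving them to the GC.
function fetch_descriptors(maxcount::Int, n::Int)
    raw = [raw_create(i) for i in 1:maxcount]
    for handle in raw[n+1:end]
        raw_destroy(handle)                    # synchronous, before any long-running call
    end
    return [Descriptor(handle) for handle in raw[1:n]]
end

descs = fetch_descriptors(12, 3)
```

With this shape, the only finalizers that can ever run belong to objects the caller still holds, so none can become pending while a long-running call is in flight.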

Tested on

RTX 5090 (sm_120), cuDNN 9.20, Julia 1.12.

@CarloLucibello CarloLucibello marked this pull request as draft June 11, 2026 13:21
CarloLucibello and others added 3 commits June 16, 2026 17:52
Wrap cuDNN's modern fused SDPA via the backend graph API. The legacy
`cudnnMultiHeadAttnForward` is the deprecated attention path; this adds the
flash-attention kernel NVIDIA now recommends.

- backend.jl: a thin typed layer over the cuDNN backend graph API
  (`cudnnBackend*`): descriptor wrapper, setattr!/getattr, and helpers to build
  an operation graph, run engine heuristics, finalize an execution plan, and
  execute a variant pack. No prior high-level wrapper used this API.
- sdpa.jl: `cudnnSDPAForward[!]` driving the dedicated
  `CUDNN_BACKEND_OPERATION_SDPA_FWD_DESCRIPTOR`, with a per-shape execution-plan
  cache. Inputs are 4-D (head_dim, nheads, seq_len, batch) — NNlib's attention
  layout — so no permute is needed to interoperate.

Scope: forward inference only, Float16/BFloat16 (cuDNN's fused engine does not
support Float32/Float64). Verified vs a Float32 reference (relerr ~5e-4).

Not yet supported, both blocked on cuDNN <= 9.20 and documented inline:
- causal masking: the SDPA score-modifier subgraph needs cuDNN >= 9.21 (no
  CUDNN_jll yet); block-mask is block-sparse; the primitive matmul->softmax
  ->matmul graph yields no fused engine from raw backend calls.
- backward: `cudnnSDPABackward` is a documented placeholder; the dedicated
  SDPA_BWD descriptor does not finalize on 9.20 and the supported path needs the
  same 9.21 subgraph mechanism (forward stats output already verified to work).

Tests cover the forward against a dense reference (Float16, several shapes,
custom scale) plus in-place agreement; gated on compute capability >= 8.0.
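As an illustration of the comparison style (not the PR's actual test), a relative-error check against a Float32 reference might look like the sketch below. It runs on the CPU with Float16 arithmetic, so the error it measures is not expected to match the ~5e-4 figure quoted for the GPU kernel; `relerr` and `attend` are hypothetical names.

```julia
using LinearAlgebra: norm
using Random

# Relative error of `x` against a trusted reference `ref`, measured in Float32.
relerr(x, ref) = norm(Float32.(x) .- Float32.(ref)) / norm(Float32.(ref))

# Attention output for a single query vector: softmax-weighted average of v's columns.
function attend(q::AbstractVector, k::AbstractMatrix, v::AbstractMatrix, scale)
    s = (k' * q) .* scale
    p = exp.(s .- maximum(s))
    return v * (p ./ sum(p))
end

Random.seed!(0)
d, skv = 64, 128
q32 = randn(Float32, d)
k32 = randn(Float32, d, skv)
v32 = randn(Float32, d, skv)
scale = 1 / sqrt(Float32(d))

ref = attend(q32, k32, v32, scale)                       # Float32 reference
low = attend(Float16.(q32), Float16.(k32), Float16.(v32), Float16(scale))

err = relerr(low, ref)                                   # small, but not zero
```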

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

- Constrain cudnnSDPAForward[!] inputs to DenseCuArray{T,4}: the cached plan
  bakes in dense column-major strides, so non-contiguous views (or host Arrays,
  whose Ptr/CuPtr distinction the variant pack erases) previously passed all
  size checks and produced silently wrong results or illegal accesses.
- Key the plan cache on the current context: execution plans are finalized
  against a specific device's handle, so a plan built on one GPU must not be
  executed on another.
- Only swallow the CUDNN_STATUS_NOT_SUPPORTED family (3000s) in
  try_execution_plan; BAD_PARAM/INTERNAL_ERROR now propagate instead of being
  misreported as "no supported engine". The terminal error also distinguishes
  "heuristic returned no configs" from "N configs failed to finalize".
- Build plans outside the cache lock (matching the descriptors.jl pattern) so
  concurrent calls don't serialize behind a multi-millisecond plan build; a
  racing duplicate build is benign and resolved by get! on insert.
- Use with_workspace for the execute workspace, freeing it eagerly instead of
  leaving it to GC (house style, cf. convolution.jl/reduce.jl).
- Test that views, host arrays, and Float32 inputs are rejected.
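The "build outside the lock, resolve duplicates with get!" idea from the list above can be sketched in plain Julia. This is an assumption-laden toy, not the PR's code: `Plan`, `PLAN_CACHE`, `build_plan`, and `cached_plan` are hypothetical names, and the key tuple only mimics the (context, dtype, d, sq, skv, h, b) shape described in the PR.

```julia
# Hypothetical plan cache: look up under the lock, build unlocked, insert with get!.
struct Plan
    key::Tuple
end

const PLAN_CACHE = Dict{Tuple,Plan}()
const PLAN_LOCK = ReentrantLock()
const BUILD_COUNT = Threads.Atomic{Int}(0)     # counts builds, for demonstration only

# Stand-in for the expensive part (engine heuristics and plan finalization in the PR).
function build_plan(key::Tuple)
    Threads.atomic_add!(BUILD_COUNT, 1)
    return Plan(key)
end

function cached_plan(key::Tuple)
    # Fast path: check the cache under the lock, but never build while holding it.
    plan = lock(PLAN_LOCK) do
        get(PLAN_CACHE, key, nothing)
    end
    plan === nothing || return plan

    fresh = build_plan(key)                    # slow part, runs without the lock
    # If another task inserted first, get! returns its plan and `fresh` is dropped.
    return lock(PLAN_LOCK) do
        get!(PLAN_CACHE, key, fresh)
    end
end

key = (:ctx0, Float16, 64, 128, 128, 8, 2)     # (context, dtype, d, sq, skv, h, b)
p1 = cached_plan(key)
p2 = cached_plan(key)                          # served from the cache
```

The trade-off is that two tasks racing on the same key may both pay for a build, in exchange for not holding the lock during a build that other shapes would otherwise wait on.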

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

Fix a GC-finalizer/cuDNN-JIT deadlock that caused BFloat16 fused SDPA to
hang indefinitely on runtime-compiled engines (e.g. Blackwell sm_120):

- getattr_descriptors: use raw handles internally, destroy unused slots
  synchronously before any long operation runs, register finalizers only
  for the n returned descriptors. Previously, up to 12 unreachable wrapper
  objects with cudnnBackendDestroyDescriptor finalizers were left pending;
  since all backend calls use @gcsafe_ccall, the GC could fire those
  finalizers during a JIT-compiling cudnnBackendExecute, contending for
  cuDNN's internal state lock -> deadlock.

- try_execution_plan: destroy the descriptor immediately on failure instead
  of leaving a pending finalizer.

- _build_sdpa_plan: keep ALL engine configs alive in SDPAPlan.keepalive
  (not just the first successful one), preventing the unused cfgs from
  becoming unreachable between plan build and execute.

Also enable BFloat16 in the SDPA tests: switch to CUDA.randn for test
data (GPU-side generation avoids the CPU BFloat16.(randn(Float32,...))
broadcast which triggers a slow first-time LLVM compilation on some
machines), and add BFloat16 alongside Float16 in the test loop.

Tested on RTX 5090 (sm_120), cuDNN 9.20: Float16 and BFloat16 both pass
all 5 shape variants × 4 assertions = 40 tests, plus the 3 invalid-input
rejection tests.

Remove scratch investigation files (sdpa_bf16_findings.md, sdpa_bf16_hang.jl,
torch_sdpa_test.py).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@CarloLucibello CarloLucibello marked this pull request as ready for review June 17, 2026 07:54
@CarloLucibello

Contributor Author

@codecov

codecov Bot commented Jun 17, 2026


Codecov Report

❌ Patch coverage is 0% with 177 lines in your changes missing coverage. Please review.
✅ Project coverage is 17.16%. Comparing base (f36a98c) to head (815104f).

Files with missing lines | Patch % | Lines
lib/cudnn/src/backend.jl | 0.00% | 90 Missing ⚠️
lib/cudnn/src/sdpa.jl | 0.00% | 62 Missing ⚠️
lib/cudnn/test/sdpa.jl | 0.00% | 25 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3174      +/-   ##
==========================================
+ Coverage   16.88%   17.16%   +0.27%     
==========================================
  Files         124      127       +3     
  Lines        9886    10063     +177     
==========================================
+ Hits         1669     1727      +58     
- Misses       8217     8336     +119     

@maleadt

maleadt commented Jun 17, 2026

Member

> runners are executing tests with missing cuDNN runtime

That's because you're only running the Julia tests, not the CUDA ones. Those are skipped for draft PRs, and your PR was marked as such at the time of the last commit.

@github-actions github-actions Bot left a comment

Contributor

CUDA.jl Benchmarks

Benchmark suite Current: 815104f Previous: f36a98c Ratio
array/accumulate/Float32/1d 99698 ns 99535 ns 1.00
array/accumulate/Float32/dims=1 75529 ns 75442 ns 1.00
array/accumulate/Float32/dims=1L 1599802 ns 1600645 ns 1.00
array/accumulate/Float32/dims=2 141102 ns 141158 ns 1.00
array/accumulate/Float32/dims=2L 661670 ns 661466 ns 1.00
array/accumulate/Int64/1d 119131 ns 118706 ns 1.00
array/accumulate/Int64/dims=1 79739 ns 79054 ns 1.01
array/accumulate/Int64/dims=1L 1717805 ns 1717638 ns 1.00
array/accumulate/Int64/dims=2 152843 ns 152541 ns 1.00
array/accumulate/Int64/dims=2L 988478 ns 988135 ns 1.00
array/broadcast 18477 ns 18241 ns 1.01
array/construct 1187.6 ns 1181 ns 1.01
array/copy 17049 ns 16815 ns 1.01
array/copyto!/cpu_to_gpu 208690 ns 206789 ns 1.01
array/copyto!/gpu_to_cpu 242106 ns 240944 ns 1.00
array/copyto!/gpu_to_gpu 10466 ns 10378 ns 1.01
array/iteration/findall/bool 135061 ns 133949 ns 1.01
array/iteration/findall/int 148747 ns 146873 ns 1.01
array/iteration/findfirst/bool 70244 ns 69740 ns 1.01
array/iteration/findfirst/int 71877 ns 70808 ns 1.02
array/iteration/findmin/1d 67797 ns 63267 ns 1.07
array/iteration/findmin/2d 101331 ns 100683 ns 1.01
array/iteration/logical 194581 ns 191777 ns 1.01
array/iteration/scalar 65098 ns 65423 ns 1.00
array/permutedims/2d 49617 ns 49157 ns 1.01
array/permutedims/3d 50794 ns 50803 ns 1.00
array/permutedims/4d 50483 ns 50427 ns 1.00
array/random/rand/Float32 12227 ns 11313 ns 1.08
array/random/rand/Int64 24230 ns 21724 ns 1.12
array/random/rand!/Float32 9531.666666666666 ns 7865.333333333333 ns 1.21
array/random/rand!/Int64 20843 ns 18046 ns 1.15
array/random/randn/Float32 35168 ns 34325 ns 1.02
array/random/randn!/Float32 26287 ns 23663 ns 1.11
array/reductions/mapreduce/Float32/1d 33712 ns 33244 ns 1.01
array/reductions/mapreduce/Float32/dims=1 38289 ns 37924 ns 1.01
array/reductions/mapreduce/Float32/dims=1L 51186 ns 50709 ns 1.01
array/reductions/mapreduce/Float32/dims=2 55537 ns 55197 ns 1.01
array/reductions/mapreduce/Float32/dims=2L 67665 ns 67275 ns 1.01
array/reductions/mapreduce/Int64/1d 40373 ns 39856 ns 1.01
array/reductions/mapreduce/Int64/dims=1 41047 ns 41151 ns 1.00
array/reductions/mapreduce/Int64/dims=1L 88833 ns 88798 ns 1.00
array/reductions/mapreduce/Int64/dims=2 57919 ns 57992 ns 1.00
array/reductions/mapreduce/Int64/dims=2L 83854 ns 83511 ns 1.00
array/reductions/reduce/Float32/1d 33711 ns 33205 ns 1.02
array/reductions/reduce/Float32/dims=1 38404 ns 38128 ns 1.01
array/reductions/reduce/Float32/dims=1L 50972 ns 50734 ns 1.00
array/reductions/reduce/Float32/dims=2 55602 ns 55518 ns 1.00
array/reductions/reduce/Float32/dims=2L 69417 ns 68506 ns 1.01
array/reductions/reduce/Int64/1d 40371 ns 39264 ns 1.03
array/reductions/reduce/Int64/dims=1 40737 ns 40684 ns 1.00
array/reductions/reduce/Int64/dims=1L 88802 ns 88810 ns 1.00
array/reductions/reduce/Int64/dims=2 57747 ns 57487 ns 1.00
array/reductions/reduce/Int64/dims=2L 83802 ns 83461 ns 1.00
array/reverse/1d 16726 ns 16785 ns 1.00
array/reverse/1dL 69546 ns 69518 ns 1.00
array/reverse/1dL_inplace 67224 ns 67215 ns 1.00
array/reverse/1d_inplace 8568.666666666666 ns 8316.666666666666 ns 1.03
array/reverse/2d 20177 ns 19784 ns 1.02
array/reverse/2dL 73646 ns 73042 ns 1.01
array/reverse/2dL_inplace 67227 ns 66906 ns 1.00
array/reverse/2d_inplace 9732 ns 9756 ns 1.00
array/sorting/1d 2659829 ns 2659326 ns 1.00
array/sorting/2d 1040382 ns 1040275 ns 1.00
array/sorting/by 3195420 ns 3194653 ns 1.00
cuda/synchronization/context/auto 1143.1 ns 1126 ns 1.02
cuda/synchronization/context/blocking 925.4857142857143 ns 913.1818181818181 ns 1.01
cuda/synchronization/context/nonblocking 6019.666666666667 ns 6025.8 ns 1.00
cuda/synchronization/stream/auto 1024.8 ns 1012.7 ns 1.01
cuda/synchronization/stream/blocking 838.0410958904109 ns 812.4494382022472 ns 1.03
cuda/synchronization/stream/nonblocking 5878.8 ns 5808.4 ns 1.01
integration/byval/reference 147712 ns 147595 ns 1.00
integration/byval/slices=1 149922 ns 149686 ns 1.00
integration/byval/slices=2 292820 ns 292616 ns 1.00
integration/byval/slices=3 435752 ns 435427 ns 1.00
integration/cudadevrt 104699 ns 104547 ns 1.00
integration/volumerhs 9302519 ns 9302299 ns 1.00
kernel/indexing 12722 ns 12659 ns 1.00
kernel/indexing_checked 13525 ns 13481 ns 1.00
kernel/launch 2075.777777777778 ns 2098.6666666666665 ns 0.99
kernel/occupancy 696.472602739726 ns 728.3309352517986 ns 0.96
kernel/rand 13853 ns 16517 ns 0.84
latency/import 3925156396 ns 3918708441 ns 1.00
latency/precompile 4669975980 ns 4679357936 ns 1.00
latency/ttfp 5375645004 ns 4575800023 ns 1.17

This comment was automatically generated by workflow using github-action-benchmark.

Comment thread lib/cudnn/src/backend.jl
# maxcount-n slots: if a finalizer fires cudnnBackendDestroyDescriptor while another thread
# is inside a cudnnBackendExecute that holds cuDNN's JIT lock, we get a deadlock on bf16
# and other runtime-compiled engines. Instead we use raw handles, destroy the unused ones
# synchronously right here, and only register finalizers for the n handles we return.

Contributor Author

@maleadt is this manual memory management ok?

@CarloLucibello

Contributor Author

this is ready for review/merge
