perf/accuracy: Flash Attention, torch-native SO(3), cosine schedule, DDIM, analytical g(t), acos fix #454
mooreneural wants to merge 2 commits into
…IM, analytical g(t)

Attention (Attention_module.py):
- Replace hand-rolled einsum attention with F.scaled_dot_product_attention in Attention, AttentionWithBias, and MSAColAttention. Uses Flash Attention automatically when available on CUDA (20-40% speedup; memory linear rather than quadratic in sequence length, since the full attention matrix is never materialized).
- AttentionWithBias passes the pairwise bias as attn_mask so it is folded into the fused kernel rather than materializing a separate attention matrix.

SO3 diffusion (igso3.py, diffusion.py, inference/utils.py):
- Add hat_batch(), Log_torch(), Exp_torch() -- on-device rotation ops using the Rodrigues formula. Eliminates all scipy CPU round-trips during inference.
- Replace scipy_R calls in reverse_sample_vectorized() and diffuse_frames() with the new torch-native equivalents (stay on GPU, no .cpu()/.numpy() transfers).
- Remove redundant scipy rotation normalization in get_next_frames(); rotation matrices from rigid_from_3_points are already orthogonal.

Noise schedule (diffusion.py):
- Add cosine schedule (Nichol & Dhariwal, 2021). Enabled via schedule_type="cosine"; b0/bT are ignored for this mode.
- Analytical g(t) for linear schedule: eliminates a per-step autograd call. Formula: g(t) = sqrt(2 * sigma(t) * (min_b + t*(max_b - min_b))).

IGSO3 cache (diffusion.py):
- Add module-level _igso3_cache dict. Avoids repeated disk deserialization when multiple Diffuser objects are created in the same process (batch inference).

DDIM sampling (inference/utils.py):
- Add get_mu_xt_x0_ddim() implementing the deterministic DDIM update rule.
- Wire ddim=True flag through Denoise.__init__() -> get_next_pose() -> get_next_ca(). Setting ddim=True produces deterministic, lower-variance trajectories and enables fewer-step inference at equivalent quality.

Numerical stability (kinematics.py):
- Clamp input to acos in get_ang() to [-1, 1] to prevent NaN from float rounding at exactly +/-1.
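The analytical g(t) formula in the commit message can be sanity-checked without autograd. The sketch below is plain Python, not the PR's code. It assumes the linear schedule has the form sigma(t) = min_sigma + t*min_b + 0.5*t^2*(max_b - min_b), so that d(sigma)/dt = min_b + t*(max_b - min_b) and g(t) = sqrt(d(sigma^2)/dt). The constants and function names are illustrative only.

```python
import math

# Illustrative constants (assumed for this sketch, not read from the repo config).
MIN_SIGMA, MIN_B, MAX_B = 0.02, 1.5, 2.5


def sigma(t):
    # Assumed linear-schedule form: d(sigma)/dt = MIN_B + t * (MAX_B - MIN_B).
    return MIN_SIGMA + t * MIN_B + 0.5 * t * t * (MAX_B - MIN_B)


def g_analytical(t):
    # Closed form from the commit message: sqrt(2 * sigma * d(sigma)/dt).
    return math.sqrt(2.0 * sigma(t) * (MIN_B + t * (MAX_B - MIN_B)))


def g_finite_difference(t, h=1e-6):
    # Reference value: central difference of sigma(t)^2, then square root.
    d_sigma_sq = (sigma(t + h) ** 2 - sigma(t - h) ** 2) / (2.0 * h)
    return math.sqrt(d_sigma_sq)


# Largest disagreement between the two over a few time points.
max_gap = max(abs(g_analytical(t) - g_finite_difference(t)) for t in (0.1, 0.5, 1.0))
```

If the repository's sigma(t) differs from the form assumed here, the closed form for g(t) would need to be re-derived accordingly.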
Thanks for running the CI check. This failure is not introduced by my PR.

The test environment resolves dgl==2.1.0 + torchdata==0.11.0, which are incompatible. DGL 2.1.0 internally imports torchdata.datapipes, but that submodule was removed from torchdata starting in version 0.7.0.

The crash originates in util_module.py → import dgl, which is completely unrelated to the five files changed in this PR (Attention_module.py, diffusion.py, igso3.py, inference/utils.py, kinematics.py). The identical failure would occur on the unmodified upstream main branch under these package versions.

To fix the CI environment (outside the scope of this PR):

Option A: Pin torchdata to the last release that still ships datapipes:
    pip install "torchdata<0.7"

Option B: Upgrade DGL to a version that no longer imports datapipes:
    pip install "dgl>=2.4"

Happy to help investigate further if useful.
The original Log_torch used theta/(2*sin(theta)) * skew throughout [0, pi]. Near theta=pi, the float32 R matrix loses trace precision (sin(theta) -> 0), causing the computed theta from acos(trace) to diverge from the theta encoded in the skew elements -- producing up to 10x rotation-matrix error in the worst case.

Fix: for cos(theta) < 0, estimate theta via pi - asin(||skew||/2) instead. The skew magnitude 2*sin(theta) remains accurate in float32 even near pi, avoiding the trace instability entirely. Fall back to R+I decomposition (R+I = 2*outer(n,n)) only for the exact-pi case where skew -> 0. All arithmetic is done in float64 on-device; result is cast back to input dtype.

Round-trip error R -> Log_torch -> Exp_torch -> R:
- Before fix: max|dR| = 9.83 near theta=pi (catastrophic)
- After fix: max|dR| = 2.25e-04 full range, mean = 1.58e-07

Also adds scripts/benchmark_pr454.py for measuring PR RosettaCommons#454 improvements.
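The angle estimate described in this commit can be illustrated with a small, unbatched sketch. It is plain Python in double precision, not the PR's Log_torch/Exp_torch. The names exp_so3 and log_so3 are hypothetical, the float32 behaviour itself is not reproduced, and the exact-pi fallback via R+I is deliberately left out.

```python
import math


def exp_so3(w):
    """Rotation vector -> 3x3 rotation matrix via the Rodrigues formula."""
    theta = math.sqrt(sum(a * a for a in w))
    if theta < 1e-12:
        return [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    x, y, z = (a / theta for a in w)
    c, s = math.cos(theta), math.sin(theta)
    k = 1.0 - c
    # R = cos(theta) * I + sin(theta) * hat(n) + (1 - cos(theta)) * n n^T
    return [
        [c + k * x * x, k * x * y - s * z, k * x * z + s * y],
        [k * y * x + s * z, c + k * y * y, k * y * z - s * x],
        [k * z * x - s * y, k * z * y + s * x, c + k * z * z],
    ]


def log_so3(R):
    """Rotation matrix -> rotation vector (exact-pi case not handled here)."""
    # Skew part of R: this vector equals 2 * sin(theta) * n.
    v = (R[2][1] - R[1][2], R[0][2] - R[2][0], R[1][0] - R[0][1])
    norm_v = math.sqrt(sum(a * a for a in v))
    if norm_v < 1e-12:
        # theta is ~0 here; theta == pi would need the R + I decomposition.
        return (0.0, 0.0, 0.0)
    cos_theta = max(-1.0, min(1.0, (R[0][0] + R[1][1] + R[2][2] - 1.0) / 2.0))
    if cos_theta >= 0.0:
        theta = math.acos(cos_theta)
    else:
        # Near pi, take the angle from the skew magnitude instead of the trace.
        theta = math.pi - math.asin(min(1.0, norm_v / 2.0))
    return tuple(theta * a / norm_v for a in v)


# Round trip for a rotation close to pi about the z axis.
w = (0.0, 0.0, 3.14)
w_back = log_so3(exp_so3(w))
```

A batched torch version would apply the same branch with torch.where over a tensor of matrices; that detail is left to the PR's actual implementation.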
Benchmark results (RTX 5080, PyTorch 2.11.0+cu128)

Flash Attention ( […]

Same pattern for SO(3) torch-native ops: isolated timings show scipy faster for typical residue counts (GPU kernel launch overhead dominates for N=50–500). The value here is eliminating the […]

Numerical fix (commit 2c21e73): patched a float32 precision issue in […]

Round-trip error […]
Scientific and performance improvements across the diffusion pipeline. No breaking changes: all new features are opt-in via existing config flags or new parameters with safe defaults.
Changes

Flash Attention (Attention_module.py)
- Replaces the hand-rolled einsum attention in Attention, AttentionWithBias, and MSAColAttention with F.scaled_dot_product_attention
- AttentionWithBias passes its pairwise bias as attn_mask, fusing it into the kernel rather than applying a separate add

Torch-native SO(3) ops (igso3.py, diffusion.py, inference/utils.py)
- Adds hat_batch(), Log_torch(), Exp_torch(): Rodrigues-formula rotation ops that stay on-device
- Eliminates the scipy_R / .cpu().numpy() roundtrips that fired at every denoising step in reverse_sample_vectorized(), diffuse_frames(), and get_next_frames()

Cosine noise schedule (diffusion.py)
- Adds schedule_type="cosine" (Nichol & Dhariwal, 2021 - Improved DDPM)
- Enabled via diffuser.schedule_type=cosine; b0/bT are ignored for this mode

Analytical g(t) (diffusion.py)
- Replaces the per-step torch.autograd.grad call with the closed-form derivative: g(t) = sqrt(2·σ(t)·(min_b + t·(max_b − min_b)))

IGSO3 module-level cache (diffusion.py)
- Adds an _igso3_cache dict at module level
- Avoids repeated disk deserialization when multiple Diffuser instances are created in the same process (e.g., batch inference scripts)

DDIM deterministic sampling (inference/utils.py)
- Adds get_mu_xt_x0_ddim() implementing the DDIM update rule (Song et al., 2021)
- Wires the ddim=True flag through Denoise.__init__() → get_next_pose() → get_next_ca()
- Defaults to False: no change to existing behavior

acos NaN clamp (kinematics.py)
- torch.acos(vw) → torch.acos(torch.clamp(vw, -1.0, 1.0))

Notes for reviewers
- DDIM applies at the Denoise class level only: rotation (SO3) denoising still uses the stochastic IGSO3 reverse SDE; a full DDIM-on-SO3 implementation would be a follow-on
- The original Log() and Exp() functions in igso3.py are kept for backward compatibility (they're used in offline precomputation, not inference)
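As a reading aid for the cosine schedule and DDIM items above, here is a minimal scalar sketch of the two ideas together. It is not the PR's get_mu_xt_x0_ddim(). It assumes the discrete cosine alpha-bar schedule from Nichol & Dhariwal (offset s = 0.008) and the eta = 0 DDIM update, and it uses the true x0 in place of a network prediction. The names cosine_alpha_bar and ddim_step are hypothetical.

```python
import math


def cosine_alpha_bar(T, s=0.008):
    """Cumulative signal level alpha_bar[0..T] under the cosine schedule."""
    f = [math.cos((t / T + s) / (1.0 + s) * math.pi / 2.0) ** 2 for t in range(T + 1)]
    return [v / f[0] for v in f]


def ddim_step(x_t, x0_pred, ab_t, ab_prev):
    """Deterministic DDIM update (eta = 0) for one scalar coordinate."""
    # Noise implied by the current sample and the predicted clean value.
    eps = (x_t - math.sqrt(ab_t) * x0_pred) / math.sqrt(1.0 - ab_t)
    # Re-noise the prediction to the earlier noise level; no fresh randomness.
    return math.sqrt(ab_prev) * x0_pred + math.sqrt(1.0 - ab_prev) * eps


T, stride = 50, 5  # 10 sampling steps instead of 50
alpha_bar = cosine_alpha_bar(T)

x0_true, noise = 1.5, -0.7  # stand-ins for a coordinate and its Gaussian noise
x = math.sqrt(alpha_bar[T]) * x0_true + math.sqrt(1.0 - alpha_bar[T]) * noise

for t in range(T, 0, -stride):
    # An oracle x0 replaces the model's prediction in this sketch.
    x = ddim_step(x, x0_true, alpha_bar[t], alpha_bar[t - stride])
```

Because the update draws no new noise, the same starting sample always gives the same trajectory, and the step count can be reduced by striding over the schedule as in the loop above.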