Conversation
- Remove seq_lens parameter from dispatch_attention_fn
- Update varlen backends to extract seqlens from masks
- Update QwenImage to pass 2D joint_attention_mask
- Fix native backend to handle 2D boolean masks
- Fix sage_varlen seqlens_q to match seqlens_k for self-attention

Note: sage_varlen still producing black images, needs further investigation
…to txt_seq_lens
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Enhances documentation with comprehensive performance insights for the QwenImage pipeline.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Yes. Next steps:
callback_on_step_end: Callable[[int, int], None] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
max_sequence_length: int = 512,
batch_negative: bool = False,  # TODO: remove, only for testing
All changes in this file are only for testing and should be reverted before merge.
Done
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
max_seqlen_q = seqlens_q.max().item()  # TODO: item() is inefficient and breaks torch.compile graphs. Use the 'seq_len' parameter instead (see split attention backend).
I have benchmarked the new varlen attention in torch 2.10:
https://docs.pytorch.org/docs/stable/nn.attention.varlen.html
Performance isn't bad, but in my test case split attention was always faster once you account for the preparation of the varlen tokens, even with the preparation torch.compiled.
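For reference, a minimal sketch of the preparation step being discussed, assuming a 2D boolean padding mask of shape (batch, seq_len) and right-padded tokens; the function name and signature are illustrative, not the backend's actual code:

```python
import torch

def varlen_prep(hidden_states, attention_mask):
    # Per-sample valid lengths from the 2D boolean mask: (batch,)
    seqlens = attention_mask.sum(dim=1, dtype=torch.int32)
    # Cumulative offsets expected by varlen kernels: (batch + 1,)
    cu_seqlens = torch.zeros(seqlens.numel() + 1, dtype=torch.int32, device=seqlens.device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    # Pack only the valid tokens into one ragged (total_tokens, dim) tensor.
    packed = hidden_states[attention_mask]
    # .item() forces a GPU->CPU sync and breaks torch.compile graphs; passing a
    # known sequence length from the caller (as the TODO above suggests) avoids it.
    max_seqlen = int(seqlens.max().item())
    return packed, cu_seqlens, max_seqlen
```

Even when this prep is compiled, the packing and unpacking around every attention call is the overhead that makes split attention come out ahead in my benchmark.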
What does this PR do?
Most recent models use variable-length captions (Qwen, Chroma, Z-Image, ...) and require attention masking if the batch size is > 1 with multiple captions.
torch SDPA only uses its internal flash attention kernel if there is no attention mask. Otherwise it falls back to a different kernel that is significantly slower, especially at high sequence lengths.
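A quick way to see the gap, as a hedged sketch (CUDA device, shapes, and padding length are illustrative, roughly sized like a Qwen-Image sequence):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 24, 4096, 128, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Boolean padding mask: the second sample only attends to its first 1024 tokens.
mask = torch.ones(2, 1, 1, 4096, dtype=torch.bool, device="cuda")
mask[1, ..., 1024:] = False

def bench(fn, iters=20):
    fn()  # warmup
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

print("no mask  :", bench(lambda: F.scaled_dot_product_attention(q, k, v)))
print("bool mask:", bench(lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask)))
```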
This PR implements an attention backend that splits the attention batch into individual samples. Even though attention then has to be called multiple times, it is still faster than masked attention (tested up to batch size 8).
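The idea in minimal form, as a sketch rather than the PR's actual backend code; it assumes right-padding and per-sample lengths already derived from the mask:

```python
import torch
import torch.nn.functional as F

def split_attention(q, k, v, seq_lens):
    # q, k, v: (batch, heads, seq_len, head_dim); seq_lens[i] = valid tokens in sample i.
    out = torch.zeros_like(q)
    for i, n in enumerate(seq_lens):
        # Each call sees a dense, unpadded sequence, so SDPA can pick its flash
        # kernel; positions beyond n stay zero in the output.
        out[i : i + 1, :, :n] = F.scaled_dot_product_attention(
            q[i : i + 1, :, :n], k[i : i + 1, :, :n], v[i : i + 1, :, :n]
        )
    return out
```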
This PR also lays the groundwork for efficiently using "flash varlen" and other varlen attention backends, which are already implemented but not in an efficient way (see code comment).
This PR is based on @kashif's and @cdutr's work in #12702.
Benchmarks
Training benchmarks using OneTrainer; training at higher resolutions benefits the most:

Inference benchmark using the diffusers Qwen example script (but with regional compilation):

Inference benefits when comparing apples to apples, i.e. batch size 2 for CFG. However, the current pipeline already avoids attention masks by calling the transformer twice at batch size 1, so the practical improvement for inference is only slight.
Who can review?
@yiyixuxu and @sayakpaul
CC @kashif and @cdutr for feedback
Contains #12892.