Changes from all commits (70 commits)
b547fcf
Fix QwenImage txt_seq_lens handling
kashif Nov 23, 2025
72a80c6
formatting
kashif Nov 23, 2025
88cee8b
formatting
kashif Nov 23, 2025
ac5ac24
remove txt_seq_lens and use bool mask
kashif Nov 29, 2025
0477526
Merge branch 'main' into txt_seq_lens
kashif Nov 29, 2025
18efdde
use compute_text_seq_len_from_mask
kashif Nov 30, 2025
6a549d4
add seq_lens to dispatch_attention_fn
kashif Nov 30, 2025
2d424e0
use joint_seq_lens
kashif Nov 30, 2025
30b5f98
remove unused index_block
kashif Nov 30, 2025
588dc04
Merge branch 'main' into txt_seq_lens
kashif Dec 6, 2025
f1c2d99
WIP: Remove seq_lens parameter and use mask-based approach
kashif Dec 6, 2025
ec52417
Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in…
kashif Dec 6, 2025
beeb020
fix formatting
kashif Dec 7, 2025
5c6f8e3
undo sage changes
kashif Dec 7, 2025
5d434f6
xformers support
kashif Dec 7, 2025
71ba603
hub fix
kashif Dec 8, 2025
babf490
Merge branch 'main' into txt_seq_lens
kashif Dec 8, 2025
afad335
fix torch compile issues
kashif Dec 8, 2025
2d5ab16
Merge branch 'main' into txt_seq_lens
sayakpaul Dec 9, 2025
c78a1e9
fix tests
kashif Dec 9, 2025
d6d4b1d
use _prepare_attn_mask_native
kashif Dec 9, 2025
e999b76
proper deprecation notice
kashif Dec 9, 2025
8115f0b
add deprecate to txt_seq_lens
kashif Dec 9, 2025
3b1510c
Update src/diffusers/models/transformers/transformer_qwenimage.py
kashif Dec 10, 2025
3676d8e
Update src/diffusers/models/transformers/transformer_qwenimage.py
kashif Dec 10, 2025
9ed0ffd
Only create the mask if there's actual padding
kashif Dec 10, 2025
abec461
Merge branch 'main' into txt_seq_lens
kashif Dec 10, 2025
e26e7b3
fix order of docstrings
kashif Dec 10, 2025
59e3882
Adds performance benchmarks and optimization details for QwenImage
cdutr Dec 11, 2025
0cb2138
Merge branch 'main' into txt_seq_lens
kashif Dec 12, 2025
60bd454
rope_text_seq_len = text_seq_len
kashif Dec 12, 2025
a5abbb8
rename to max_txt_seq_len
kashif Dec 12, 2025
8415c57
Merge branch 'main' into txt_seq_lens
kashif Dec 15, 2025
afff5b7
Merge branch 'main' into txt_seq_lens
kashif Dec 17, 2025
8dc6c3f
Merge branch 'main' into txt_seq_lens
kashif Dec 17, 2025
22cb03d
removed deprecated args
kashif Dec 17, 2025
125a3a4
undo unrelated change
kashif Dec 17, 2025
b5b6342
Updates QwenImage performance documentation
cdutr Dec 17, 2025
61f5265
Updates deprecation warnings for txt_seq_lens parameter
cdutr Dec 17, 2025
2ef38e2
fix compile
kashif Dec 17, 2025
270c63f
Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in…
kashif Dec 17, 2025
35efa06
formatting
kashif Dec 17, 2025
50c4815
fix compile tests
kashif Dec 17, 2025
c88bc06
Merge branch 'main' into txt_seq_lens
kashif Dec 17, 2025
1433783
rename helper
kashif Dec 17, 2025
8de799c
remove duplicate
kashif Dec 17, 2025
fc93747
smaller values
kashif Dec 18, 2025
8bb47d8
Merge branch 'main' into txt_seq_lens
kashif Dec 19, 2025
b7c288a
removed
kashif Dec 20, 2025
4700b7f
Merge branch 'main' into txt_seq_lens
kashif Dec 20, 2025
2f86879
split attention
dxqb Dec 21, 2025
87bbde4
fix type hints
dxqb Dec 21, 2025
66056f1
fix error if no attn mask is passed
dxqb Dec 21, 2025
0a713d1
Merge branch 'main' into split_attention
dxqb Dec 23, 2025
b9880f6
Merge remote-tracking branch 'origin/main' into pr-12702-base
dxqb Dec 23, 2025
a8bba06
Merge branch 'pr-12702-base' into split_attention
dxqb Dec 23, 2025
5eef3ef
check attention mask
dxqb Dec 26, 2025
e593603
Merge branch 'check_attn_mask' into split_attention
dxqb Dec 26, 2025
23e7a65
Merge branch 'main' into pr-12702-base
dxqb Dec 26, 2025
0584542
Merge branch 'pr-12702-base' into split_attention
dxqb Dec 26, 2025
7651363
more backends
dxqb Dec 26, 2025
cc134a7
bugfix
dxqb Dec 26, 2025
7e456cd
bugfix
dxqb Dec 26, 2025
c90289e
merge
dxqb Feb 13, 2026
f60e9cf
merge
dxqb Feb 13, 2026
5bf6698
merge
dxqb Feb 13, 2026
b38372c
fix: remove obsolete argument
dxqb Feb 13, 2026
b40acf2
add checks
dxqb Feb 13, 2026
3957728
fix type
dxqb Feb 13, 2026
a52a5c9
fix type hint
dxqb Feb 13, 2026
105 changes: 103 additions & 2 deletions src/diffusers/models/attention_dispatch.py
@@ -179,6 +179,8 @@ class AttentionBackendName(str, Enum):
_FLASH_VARLEN_3 = "_flash_varlen_3"
_FLASH_3_HUB = "_flash_3_hub"
_FLASH_3_VARLEN_HUB = "_flash_3_varlen_hub"
FLASH_SPLIT = "flash_split"
FLASH_HUB_SPLIT = "flash_hub_split"

# `aiter`
AITER = "aiter"
@@ -192,6 +194,7 @@ class AttentionBackendName(str, Enum):
_NATIVE_MATH = "_native_math"
_NATIVE_NPU = "_native_npu"
_NATIVE_XLA = "_native_xla"
NATIVE_SPLIT = "native_split"

# `sageattention`
SAGE = "sage"
@@ -433,7 +436,7 @@ def _check_shape(


def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_SPLIT, AttentionBackendName.FLASH_VARLEN]:
if not _CAN_USE_FLASH_ATTN:
raise RuntimeError(
f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
@@ -447,6 +450,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None

elif backend in [
AttentionBackendName.FLASH_HUB,
AttentionBackendName.FLASH_HUB_SPLIT,
AttentionBackendName.FLASH_VARLEN_HUB,
AttentionBackendName._FLASH_3_HUB,
AttentionBackendName._FLASH_3_VARLEN_HUB,
@@ -514,7 +518,7 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
max_seqlen_q = seqlens_q.max().item()
max_seqlen_q = seqlens_q.max().item()  # TODO: item() is inefficient and breaks torch.compile graphs. Use the 'seq_lens' parameter instead (see the split attention backends).
PR author's comment:

I have benchmarked the new varlen attention in torch 2.10:
https://docs.pytorch.org/docs/stable/nn.attention.varlen.html

Performance isn't bad, but in my test case split attention was always faster once the preparation of the varlen tokens was taken into account, even when that preparation was run under torch.compile.
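For context, a minimal sketch (hypothetical helper names, not part of this diff) of why the varlen preparation is costly under torch.compile: it needs `max_seqlen` as a host integer, so it calls `.item()` on a GPU tensor, which forces a device sync and a graph break, whereas the split backends receive `seq_lens` as plain Python ints and never synchronize.

```python
import torch

def prepare_varlen(mask: torch.Tensor):
    # varlen-style preparation: cu_seqlens on device, max_seqlen as a host int
    seqlens = mask.sum(dim=1, dtype=torch.int32)  # (batch,)
    cu_seqlens = torch.zeros(mask.shape[0] + 1, dtype=torch.int32, device=mask.device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    max_seqlen = int(seqlens.max().item())  # GPU sync + torch.compile graph break
    return cu_seqlens, max_seqlen

def prepare_split(seq_lens: list[int]) -> int:
    # split-style preparation: the lengths are already Python ints, no sync needed
    return max(seq_lens)
```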

max_seqlen_k = seqlens_k.max().item()
return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)

@@ -1938,6 +1942,28 @@ def _flash_attention(

return (out, lse) if return_lse else out

@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_SPLIT,
constraints=[_check_device, _check_shape],
supports_context_parallel=True,
)
def _flash_split_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
seq_lens: Optional[list[int]] = None,  # attn_mask is ignored if seq_lens is passed
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
return __split_attention(
lambda q, k, v, mask: _flash_attention(q, k, v, mask, dropout_p, is_causal, scale, return_lse, _parallel_config),
query, key, value, attn_mask, seq_lens,
)


@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_HUB,
@@ -1975,6 +2001,29 @@ def _flash_attention_hub(
return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_HUB_SPLIT,
constraints=[_check_device, _check_shape],
supports_context_parallel=True,
)
def _flash_hub_split_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
seq_lens: Optional[list[int]] = None,  # attn_mask is ignored if seq_lens is passed
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
return __split_attention(
lambda q, k, v, mask: _flash_attention_hub(q, k, v, mask, dropout_p, is_causal, scale, return_lse, _parallel_config),
query, key, value, attn_mask, seq_lens,
)


@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_VARLEN_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
@@ -2475,6 +2524,58 @@ def _native_attention(

return out

def __split_attention(
attn_fn,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
seq_lens: Optional[list[int]] = None,  # attn_mask is ignored if seq_lens is passed; the two must be consistent
):
batch_size, batch_seq_len = query.shape[:2]
if seq_lens is None:
return attn_fn(query, key, value, attn_mask)
if all(sample_seq_len == batch_seq_len for sample_seq_len in seq_lens):
return attn_fn(query, key, value, None)
if any(sample_seq_len > batch_seq_len for sample_seq_len in seq_lens):
raise ValueError("Attention sequence lengths cannot be longer than maximum sequence length")
if len(seq_lens) != batch_size:
raise ValueError("Attention sequence lengths must match the batch size")

result = []
for index, sample_seq_len in enumerate(seq_lens):
sliced_query = query[index, :sample_seq_len, :, :].unsqueeze(0)
sliced_key = key [index, :sample_seq_len, :, :].unsqueeze(0)
sliced_value = value[index, :sample_seq_len, :, :].unsqueeze(0)
sliced_result = attn_fn(sliced_query, sliced_key, sliced_value, None)

padding = torch.zeros((1, batch_seq_len - sample_seq_len) + sliced_result.shape[2:], device=sliced_result.device, dtype=sliced_result.dtype)
padded_result = torch.cat([sliced_result, padding], dim=1)
result.append(padded_result)
return torch.cat(result, dim=0)

@_AttentionBackendRegistry.register(
AttentionBackendName.NATIVE_SPLIT,
constraints=[_check_device, _check_shape],
supports_context_parallel=True,
)
def _native_split_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
seq_lens: Optional[list[int]] = None,  # attn_mask is ignored if seq_lens is passed
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
return __split_attention(
lambda q, k, v, mask: _native_attention(q, k, v, mask, dropout_p, is_causal, scale, enable_gqa, return_lse, _parallel_config),
query, key, value, attn_mask, seq_lens,
)

@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_CUDNN,
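To illustrate the behavior of `__split_attention`, here is a toy sketch. It assumes the module-private helper is in scope and that tensors use the dispatcher's `(batch, seq, heads, head_dim)` layout; the attention callable is a stand-in, not a diffusers API. Each sample is sliced to its own length, attended separately, and zero-padded back to the batch sequence length.

```python
import torch
import torch.nn.functional as F

def toy_attn(q, k, v, mask):
    # q, k, v: (1, seq, heads, head_dim); SDPA expects (1, heads, seq, head_dim)
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=mask
    )
    return out.transpose(1, 2)

batch, seq, heads, dim = 2, 5, 2, 8
q = torch.randn(batch, seq, heads, dim)
k = torch.randn(batch, seq, heads, dim)
v = torch.randn(batch, seq, heads, dim)

out = __split_attention(toy_attn, q, k, v, attn_mask=None, seq_lens=[3, 5])
print(out.shape)               # torch.Size([2, 5, 2, 8])
print(out[0, 3:].abs().max())  # tensor(0.) -- padded tail of the shorter sample
```

In a pipeline, the new backends would presumably be selected the usual way, e.g. `transformer.set_attention_backend("native_split")`, assuming diffusers' `set_attention_backend` helper.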
30 changes: 17 additions & 13 deletions src/diffusers/models/transformers/transformer_qwenimage.py
@@ -150,7 +150,7 @@ def compute_text_seq_len_from_mask(
"""
batch_size, text_seq_len = encoder_hidden_states.shape[:2]
if encoder_hidden_states_mask is None:
return text_seq_len, None, None
return text_seq_len, [text_seq_len] * batch_size, None

if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len):
raise ValueError(
@@ -165,7 +165,7 @@
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
has_active = encoder_hidden_states_mask.any(dim=1)
per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
return text_seq_len, per_sample_len, encoder_hidden_states_mask
return text_seq_len, per_sample_len.tolist(), encoder_hidden_states_mask


class QwenTimestepProjEmbeddings(nn.Module):
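A quick worked example of the per-sample length computation above (toy mask; this mirrors the core logic of `compute_text_seq_len_from_mask` without its validation and fallback branches):

```python
import torch

mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.bool)
text_seq_len = mask.shape[1]
position_ids = torch.arange(text_seq_len, device=mask.device).expand_as(mask)
active_positions = torch.where(mask, position_ids, position_ids.new_zeros(()))
per_sample_len = (active_positions.max(dim=1).values + 1).tolist()
print(text_seq_len, per_sample_len)  # 5 [3, 5]
```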
@@ -491,13 +492,12 @@ def __call__(
encoder_hidden_states: torch.FloatTensor = None, # Text stream
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: torch.FloatTensor | None = None,
seq_lens: list[int] | None = None,
image_rotary_emb: torch.Tensor | None = None,
) -> torch.FloatTensor:
if encoder_hidden_states is None:
raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

seq_txt = encoder_hidden_states.shape[1]

# Compute QKV for image stream (sample projections)
img_query = attn.to_q(hidden_states)
img_key = attn.to_k(hidden_states)
@@ -537,9 +536,9 @@ def __call__(

# Concatenate for joint attention
# Order: [text, image]
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_query = torch.cat([img_query, txt_query], dim=1)
joint_key = torch.cat([img_key, txt_key], dim=1)
joint_value = torch.cat([img_value, txt_value], dim=1)

joint_hidden_states = dispatch_attention_fn(
joint_query,
@@ -550,15 +549,19 @@ def __call__(
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
attention_kwargs={
'seq_lens': seq_lens,
},
)

# Reshape back
joint_hidden_states = joint_hidden_states.flatten(2, 3)
joint_hidden_states = joint_hidden_states.to(joint_query.dtype)

# Split attention outputs back
txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
image_seq_len = hidden_states.shape[1]
img_attn_output = joint_hidden_states[:, :image_seq_len, :] # Image part
txt_attn_output = joint_hidden_states[:, image_seq_len:, :] # Text part

# Apply output projections
img_attn_output = attn.to_out[0](img_attn_output.contiguous())
@@ -920,7 +923,7 @@ def forward(
encoder_hidden_states = self.txt_in(encoder_hidden_states)

# Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
text_seq_len, text_seq_len_per_sample, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
encoder_hidden_states, encoder_hidden_states_mask
)

@@ -938,11 +941,12 @@ def forward(
# Construct joint attention mask once to avoid reconstructing in every block
# This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility
block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
batch_size, image_seq_len = hidden_states.shape[:2]
block_attention_kwargs["seq_lens"] = [text_seq_len + image_seq_len for text_seq_len in text_seq_len_per_sample]
if encoder_hidden_states_mask is not None:
# Build joint mask: [text_mask, all_ones_for_image]
batch_size, image_seq_len = hidden_states.shape[:2]
# Build joint mask: [all_ones_for_image, text_mask]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
joint_attention_mask = torch.cat([image_mask, encoder_hidden_states_mask], dim=1)
block_attention_kwargs["attention_mask"] = joint_attention_mask

for index_block, block in enumerate(self.transformer_blocks):
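As a concrete illustration of the block above (toy sizes, hypothetical values): with the new `[image, text]` ordering, each sample's valid tokens form a contiguous prefix of length `image_seq_len + text_len_i`, which is exactly what the split backends consume via `seq_lens`, while the boolean joint mask is still built for mask-based backends.

```python
import torch

batch_size, image_seq_len = 2, 4
text_seq_len_per_sample = [2, 3]  # as returned by compute_text_seq_len_from_mask
seq_lens = [image_seq_len + t for t in text_seq_len_per_sample]  # [6, 7]

encoder_hidden_states_mask = torch.tensor([[1, 1, 0],
                                           [1, 1, 1]], dtype=torch.bool)
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool)
joint_attention_mask = torch.cat([image_mask, encoder_hidden_states_mask], dim=1)

print(seq_lens)                 # [6, 7]
print(joint_attention_mask[0])  # tensor([True, True, True, True, True, True, False])
```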
84 changes: 58 additions & 26 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
@@ -468,6 +470,7 @@ def __call__(
callback_on_step_end: Callable[[int, int], None] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
max_sequence_length: int = 512,
batch_negative: bool = False,  # TODO: remove, only for testing
PR author's comment:

All changes in this file are only for testing and should be reverted before merge.

):
r"""
Function invoked when calling the pipeline for generation.
@@ -598,23 +599,35 @@
)

do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
)
if do_true_cfg:
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
if do_true_cfg and batch_negative:
combined_prompt_embeds, combined_prompt_embeds_mask = self.encode_prompt(
prompt=[prompt, negative_prompt],
# prompt_embeds=prompt_embeds,
# prompt_embeds_mask=prompt_embeds_mask,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
)
dtype = combined_prompt_embeds.dtype
else:
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
)
dtype = prompt_embeds.dtype
if do_true_cfg:
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
)

# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
@@ -623,7 +636,7 @@ def __call__(
num_channels_latents,
height,
width,
prompt_embeds.dtype,
dtype,
device,
generator,
latents,
@@ -677,31 +690,50 @@ def __call__(
self._current_timestep = t
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
with self.transformer.cache_context("cond"):
if do_true_cfg and batch_negative:
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
hidden_states=torch.cat([latents] * 2, dim=0),
timestep=torch.cat([timestep] * 2, dim=0) / 1000,
guidance=torch.cat([guidance] * 2, dim=0) if guidance is not None else None,
encoder_hidden_states_mask=combined_prompt_embeds_mask,
encoder_hidden_states=combined_prompt_embeds,
img_shapes=img_shapes,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
noise_pred, neg_noise_pred = torch.chunk(noise_pred, 2, dim=0)

comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

if do_true_cfg:
with self.transformer.cache_context("uncond"):
neg_noise_pred = self.transformer(
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm)
else:
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

if do_true_cfg:
with self.transformer.cache_context("uncond"):
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
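For reference, a minimal self-contained sketch (toy tensors, a stand-in for the transformer, hypothetical shapes) of the batched-negative pattern this test code exercises: the conditional and negative inputs are stacked along the batch dimension, run in one forward pass, split back, combined with the true-CFG formula, and rescaled to the conditional prediction's norm, mirroring the lines above.

```python
import torch

def fake_transformer(hidden_states, encoder_hidden_states):
    # stand-in for the QwenImage transformer forward
    return hidden_states + encoder_hidden_states.mean(dim=1, keepdim=True)

true_cfg_scale = 4.0
latents = torch.randn(1, 16, 64)
prompt_embeds = torch.randn(1, 8, 64)
negative_prompt_embeds = torch.randn(1, 8, 64)

# One batched forward instead of two sequential ones
noise = fake_transformer(
    torch.cat([latents] * 2, dim=0),
    torch.cat([prompt_embeds, negative_prompt_embeds], dim=0),
)
noise_pred, neg_noise_pred = noise.chunk(2, dim=0)

# True-CFG combine, then rescale to the conditional prediction's norm
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm)
```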