diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index e9e6436c1af4..0b098ee668de 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -179,6 +179,8 @@ class AttentionBackendName(str, Enum):
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
     _FLASH_3_VARLEN_HUB = "_flash_3_varlen_hub"
+    FLASH_SPLIT = "flash_split"
+    FLASH_HUB_SPLIT = "flash_hub_split"

     # `aiter`
     AITER = "aiter"
@@ -192,6 +194,7 @@ class AttentionBackendName(str, Enum):
     _NATIVE_MATH = "_native_math"
     _NATIVE_NPU = "_native_npu"
     _NATIVE_XLA = "_native_xla"
+    NATIVE_SPLIT = "native_split"

     # `sageattention`
     SAGE = "sage"
@@ -433,7 +436,7 @@ def _check_shape(


 def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
-    if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
+    if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_SPLIT, AttentionBackendName.FLASH_VARLEN]:
         if not _CAN_USE_FLASH_ATTN:
             raise RuntimeError(
                 f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
@@ -447,6 +450,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None

     elif backend in [
         AttentionBackendName.FLASH_HUB,
+        AttentionBackendName.FLASH_HUB_SPLIT,
         AttentionBackendName.FLASH_VARLEN_HUB,
         AttentionBackendName._FLASH_3_HUB,
         AttentionBackendName._FLASH_3_VARLEN_HUB,
@@ -514,7 +518,7 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
     cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
     cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
     cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
-    max_seqlen_q = seqlens_q.max().item()
+    max_seqlen_q = seqlens_q.max().item()  # TODO: `.item()` is inefficient and breaks torch.compile graphs; use the `seq_lens` parameter instead (see the split attention backends)
     max_seqlen_k = seqlens_k.max().item()

     return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
@@ -1938,6 +1942,28 @@ def _flash_attention(
     return (out, lse) if return_lse else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_SPLIT,
+    constraints=[_check_device, _check_shape],
+    supports_context_parallel=True,
+)
+def _flash_split_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    seq_lens: Optional[list[int]] = None,  # attn_mask is ignored if seq_lens is passed
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    return __split_attention(
+        lambda q, k, v, mask: _flash_attention(q, k, v, mask, dropout_p, is_causal, scale, return_lse, _parallel_config),
+        query, key, value, attn_mask, seq_lens,
+    )
+

 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_HUB,
@@ -1975,6 +2001,29 @@ def _flash_attention_hub(
     return (out, lse) if return_lse else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_HUB_SPLIT,
+    constraints=[_check_device, _check_shape],
+    supports_context_parallel=True,
+)
+def _flash_hub_split_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    seq_lens: Optional[list[int]] = None,  # attn_mask is ignored if seq_lens is passed
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    return __split_attention(
+        lambda q, k, v, mask: _flash_attention_hub(q, k, v, mask, dropout_p, is_causal, scale, return_lse, _parallel_config),
+        query, key, value, attn_mask, seq_lens,
+    )
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN_HUB,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
@@ -2475,6 +2524,58 @@ def _native_attention(
     return out


+def __split_attention(
+    attn_fn,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    seq_lens: Optional[list[int]] = None,  # attn_mask is ignored if seq_lens is passed; the two must agree
+):
+    batch_size, batch_seq_len = query.shape[:2]
+    if seq_lens is None:
+        return attn_fn(query, key, value, attn_mask)
+    if all(sample_seq_len == batch_seq_len for sample_seq_len in seq_lens):
+        return attn_fn(query, key, value, None)
+    if any(sample_seq_len > batch_seq_len for sample_seq_len in seq_lens):
+        raise ValueError("Attention sequence lengths cannot be longer than the maximum sequence length")
+    if len(seq_lens) != batch_size:
+        raise ValueError("Attention sequence lengths must match the batch size")
+
+    result = []
+    for index, sample_seq_len in enumerate(seq_lens):
+        sliced_query = query[index, :sample_seq_len, :, :].unsqueeze(0)
+        sliced_key = key[index, :sample_seq_len, :, :].unsqueeze(0)
+        sliced_value = value[index, :sample_seq_len, :, :].unsqueeze(0)
+        sliced_result = attn_fn(sliced_query, sliced_key, sliced_value, None)
+
+        padding = torch.zeros((1, batch_seq_len - sample_seq_len) + sliced_result.shape[2:], device=sliced_result.device, dtype=sliced_result.dtype)
+        padded_result = torch.cat([sliced_result, padding], dim=1)
+        result.append(padded_result)
+    return torch.cat(result, dim=0)
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.NATIVE_SPLIT,
+    constraints=[_check_device, _check_shape],
+    supports_context_parallel=True,
+)
+def _native_split_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    seq_lens: Optional[list[int]] = None,  # attn_mask is ignored if seq_lens is passed
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    return __split_attention(
+        lambda q, k, v, mask: _native_attention(q, k, v, mask, dropout_p, is_causal, scale, enable_gqa, return_lse, _parallel_config),
+        query, key, value, attn_mask, seq_lens,
+    )

 @_AttentionBackendRegistry.register(
     AttentionBackendName._NATIVE_CUDNN,
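Note (illustration only, not part of the patch): the three split backends above all funnel into `__split_attention`, which slices each sample down to its valid length, runs dense attention on just those tokens, and zero-pads the output back to the padded batch length. The standalone sketch below mimics that behaviour with plain `torch.nn.functional.scaled_dot_product_attention`, assuming the dispatcher's `(batch, seq_len, heads, head_dim)` layout; `split_attention_reference` is a name made up for this example.

import torch
import torch.nn.functional as F

def split_attention_reference(query, key, value, seq_lens):
    # query/key/value: (batch, padded_seq, heads, head_dim); seq_lens: one valid length per sample
    batch_size, batch_seq_len = query.shape[:2]
    outputs = []
    for i, n in enumerate(seq_lens):
        # drop the padding and move heads first for SDPA: (1, heads, n, head_dim)
        q = query[i, :n].transpose(0, 1).unsqueeze(0)
        k = key[i, :n].transpose(0, 1).unsqueeze(0)
        v = value[i, :n].transpose(0, 1).unsqueeze(0)
        out = F.scaled_dot_product_attention(q, k, v)   # dense attention over the valid tokens only
        out = out.squeeze(0).transpose(0, 1)            # back to (n, heads, head_dim)
        pad = out.new_zeros((batch_seq_len - n,) + out.shape[1:])
        outputs.append(torch.cat([out, pad], dim=0))    # zero-pad back to the batch length
    return torch.stack(outputs, dim=0)                  # (batch, padded_seq, heads, head_dim)

q = k = v = torch.randn(2, 8, 4, 16)
reference = split_attention_reference(q, k, v, seq_lens=[8, 5])  # sample 1 only uses 5 of 8 positions

Compared with passing a boolean mask, this trades one fused batched kernel for per-sample calls, but it avoids materialising an attention mask and keeps `max_seqlen`-style `.item()` syncs out of the hot path.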
diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index 8b46d163c3ea..93aab839c5e0 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -150,7 +150,7 @@ def compute_text_seq_len_from_mask(
     """
     batch_size, text_seq_len = encoder_hidden_states.shape[:2]
     if encoder_hidden_states_mask is None:
-        return text_seq_len, None, None
+        return text_seq_len, [text_seq_len] * batch_size, None

     if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len):
         raise ValueError(
@@ -165,7 +165,7 @@ def compute_text_seq_len_from_mask(
     active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
     has_active = encoder_hidden_states_mask.any(dim=1)
     per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
-    return text_seq_len, per_sample_len, encoder_hidden_states_mask
+    return text_seq_len, per_sample_len.tolist(), encoder_hidden_states_mask


 class QwenTimestepProjEmbeddings(nn.Module):
@@ -491,13 +491,12 @@ def __call__(
         encoder_hidden_states: torch.FloatTensor = None,  # Text stream
         encoder_hidden_states_mask: torch.FloatTensor = None,
         attention_mask: torch.FloatTensor | None = None,
+        seq_lens: list[int] | None = None,
         image_rotary_emb: torch.Tensor | None = None,
     ) -> torch.FloatTensor:
         if encoder_hidden_states is None:
             raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

-        seq_txt = encoder_hidden_states.shape[1]
-
         # Compute QKV for image stream (sample projections)
         img_query = attn.to_q(hidden_states)
         img_key = attn.to_k(hidden_states)
         img_value = attn.to_v(hidden_states)
@@ -537,9 +536,9 @@ def __call__(

         # Concatenate for joint attention
-        # Order: [text, image]
-        joint_query = torch.cat([txt_query, img_query], dim=1)
-        joint_key = torch.cat([txt_key, img_key], dim=1)
-        joint_value = torch.cat([txt_value, img_value], dim=1)
+        # Order: [image, text]
+        joint_query = torch.cat([img_query, txt_query], dim=1)
+        joint_key = torch.cat([img_key, txt_key], dim=1)
+        joint_value = torch.cat([img_value, txt_value], dim=1)

         joint_hidden_states = dispatch_attention_fn(
             joint_query,
@@ -550,6 +549,9 @@ def __call__(
             is_causal=False,
             backend=self._attention_backend,
             parallel_config=self._parallel_config,
+            attention_kwargs={
+                "seq_lens": seq_lens,
+            },
         )

         # Reshape back
@@ -557,8 +559,9 @@ def __call__(
         joint_hidden_states = joint_hidden_states.to(joint_query.dtype)

         # Split attention outputs back
-        txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
-        img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part
+        image_seq_len = hidden_states.shape[1]
+        img_attn_output = joint_hidden_states[:, :image_seq_len, :]  # Image part
+        txt_attn_output = joint_hidden_states[:, image_seq_len:, :]  # Text part

         # Apply output projections
         img_attn_output = attn.to_out[0](img_attn_output.contiguous())
@@ -920,7 +923,7 @@ def forward(
         encoder_hidden_states = self.txt_in(encoder_hidden_states)

         # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
-        text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
+        text_seq_len, text_seq_len_per_sample, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
             encoder_hidden_states, encoder_hidden_states_mask
         )
@@ -938,11 +941,12 @@ def forward(
         # Construct joint attention mask once to avoid reconstructing in every block
         # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility
         block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
+        batch_size, image_seq_len = hidden_states.shape[:2]
+        block_attention_kwargs["seq_lens"] = [sample_text_seq_len + image_seq_len for sample_text_seq_len in text_seq_len_per_sample]
         if encoder_hidden_states_mask is not None:
-            # Build joint mask: [text_mask, all_ones_for_image]
-            batch_size, image_seq_len = hidden_states.shape[:2]
+            # Build joint mask: [all_ones_for_image, text_mask]
             image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
-            joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
+            joint_attention_mask = torch.cat([image_mask, encoder_hidden_states_mask], dim=1)
             block_attention_kwargs["attention_mask"] = joint_attention_mask

         for index_block, block in enumerate(self.transformer_blocks):
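Note (illustration only, not part of the patch): with the joint sequence reordered to [image, text], each sample's valid tokens form a contiguous prefix of the padded sequence, so a single per-sample length is all the split backends need and the padded tail can simply be sliced off. Assuming a padded text length of 512 with, say, 23 and 77 valid text tokens and 4096 image tokens, the values threaded through `attention_kwargs` would look roughly like this (numbers are illustrative only):

text_seq_len_per_sample = [23, 77]   # as returned by compute_text_seq_len_from_mask(...)
image_seq_len = 4096                 # hidden_states.shape[1]
seq_lens = [t + image_seq_len for t in text_seq_len_per_sample]
print(seq_lens)                      # [4119, 4173] -> valid length of each joint [image, text] sequence

Each sample then attends over positions [0, seq_lens[i]) and the remainder is zero-filled, which selects the same tokens as the joint boolean mask [image_mask, text_mask]. The backend itself is picked the usual way, e.g. transformer.set_attention_backend("native_split"), if your diffusers build exposes that helper.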
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
index 1715aa4d4250..0c2f0b54746e 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
@@ -468,6 +468,7 @@ def __call__(
         callback_on_step_end: Callable[[int, int], None] | None = None,
         callback_on_step_end_tensor_inputs: list[str] = ["latents"],
         max_sequence_length: int = 512,
+        batch_negative: bool = False,  # TODO: remove; only for testing
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -598,23 +599,35 @@ def __call__(
         )
         do_true_cfg = true_cfg_scale > 1 and has_neg_prompt

-        prompt_embeds, prompt_embeds_mask = self.encode_prompt(
-            prompt=prompt,
-            prompt_embeds=prompt_embeds,
-            prompt_embeds_mask=prompt_embeds_mask,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            max_sequence_length=max_sequence_length,
-        )
-        if do_true_cfg:
-            negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
-                prompt=negative_prompt,
-                prompt_embeds=negative_prompt_embeds,
-                prompt_embeds_mask=negative_prompt_embeds_mask,
+        if do_true_cfg and batch_negative:
+            combined_prompt_embeds, combined_prompt_embeds_mask = self.encode_prompt(
+                prompt=[prompt, negative_prompt],
+                # prompt_embeds=prompt_embeds,
+                # prompt_embeds_mask=prompt_embeds_mask,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+            )
+            dtype = combined_prompt_embeds.dtype
+        else:
+            prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+                prompt=prompt,
+                prompt_embeds=prompt_embeds,
+                prompt_embeds_mask=prompt_embeds_mask,
                 device=device,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
             )
+            dtype = prompt_embeds.dtype
+            if do_true_cfg:
+                negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+                    prompt=negative_prompt,
+                    prompt_embeds=negative_prompt_embeds,
+                    prompt_embeds_mask=negative_prompt_embeds_mask,
+                    device=device,
+                    num_images_per_prompt=num_images_per_prompt,
+                    max_sequence_length=max_sequence_length,
+                )

         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
@@ -623,7 +636,7 @@ def __call__(
             num_channels_latents,
             height,
             width,
-            prompt_embeds.dtype,
+            dtype,
             device,
             generator,
             latents,
@@ -677,31 +690,50 @@ def __call__(
                 self._current_timestep = t
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
-                with self.transformer.cache_context("cond"):
+                if do_true_cfg and batch_negative:
                     noise_pred = self.transformer(
-                        hidden_states=latents,
-                        timestep=timestep / 1000,
-                        guidance=guidance,
-                        encoder_hidden_states_mask=prompt_embeds_mask,
-                        encoder_hidden_states=prompt_embeds,
+                        hidden_states=torch.cat([latents] * 2, dim=0),
+                        timestep=torch.cat([timestep] * 2, dim=0) / 1000,
+                        guidance=torch.cat([guidance] * 2, dim=0) if guidance is not None else None,
+                        encoder_hidden_states_mask=combined_prompt_embeds_mask,
+                        encoder_hidden_states=combined_prompt_embeds,
                         img_shapes=img_shapes,
                         attention_kwargs=self.attention_kwargs,
                         return_dict=False,
                     )[0]
+                    noise_pred, neg_noise_pred = torch.chunk(noise_pred, 2, dim=0)
+
+                    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

-                if do_true_cfg:
-                    with self.transformer.cache_context("uncond"):
-                        neg_noise_pred = self.transformer(
+                    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+                    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+                    noise_pred = comb_pred * (cond_norm / noise_norm)
+                else:
+                    with self.transformer.cache_context("cond"):
+                        noise_pred = self.transformer(
                             hidden_states=latents,
                             timestep=timestep / 1000,
                             guidance=guidance,
-                            encoder_hidden_states_mask=negative_prompt_embeds_mask,
-                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_hidden_states_mask=prompt_embeds_mask,
+                            encoder_hidden_states=prompt_embeds,
                             img_shapes=img_shapes,
                             attention_kwargs=self.attention_kwargs,
                             return_dict=False,
                         )[0]
-                    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+                    if do_true_cfg:
+                        with self.transformer.cache_context("uncond"):
+                            neg_noise_pred = self.transformer(
+                                hidden_states=latents,
+                                timestep=timestep / 1000,
+                                guidance=guidance,
+                                encoder_hidden_states_mask=negative_prompt_embeds_mask,
+                                encoder_hidden_states=negative_prompt_embeds,
+                                img_shapes=img_shapes,
+                                attention_kwargs=self.attention_kwargs,
+                                return_dict=False,
+                            )[0]
+                        comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                 cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
                 noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)