
fix fa varlen seqlen#1031

Merged
helloyongyang merged 4 commits into main from yr/fix_fa_varlen_seqlen
Apr 21, 2026

Conversation

@STwangyingrui
Contributor

No description provided.

Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request refactors the attention mechanism across multiple models to standardize the use of Flash Attention v2 and v3, primarily by removing the explicit attention_type parameter and simplifying tensor device placement. While the changes aim to clean up the codebase, several critical issues were identified: potential AttributeError crashes in flash_attn.py when sequence lengths are passed as integers, a device mismatch error in the Bagel model's inference script due to the removal of necessary device transfers, and incorrect import aliasing for Flash Attention v2 functions.

Comment on lines +68 to +75
```python
if max_seqlen_q.is_cpu:
    max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
if max_seqlen_kv.is_cpu:
    max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
if len(q.shape) == 4:
    q = q.reshape(-1, q.shape[-2], q.shape[-1])
    k = k.reshape(-1, k.shape[-2], k.shape[-1])
    v = v.reshape(-1, v.shape[-2], v.shape[-1])
```
Contributor


critical

The is_cpu check and .to() call on max_seqlen_q and max_seqlen_kv will cause an AttributeError if these arguments are passed as integers, which is the case in several models (e.g., Wan, Ulysses). Furthermore, Flash Attention kernels expect these values as host-side integers, so moving them to the GPU device is unnecessary and potentially incorrect.

Suggested change

```diff
-if max_seqlen_q.is_cpu:
-    max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
-if max_seqlen_kv.is_cpu:
-    max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
 if len(q.shape) == 4:
     q = q.reshape(-1, q.shape[-2], q.shape[-1])
     k = k.reshape(-1, k.shape[-2], k.shape[-1])
     v = v.reshape(-1, v.shape[-2], v.shape[-1])
```
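A more defensive alternative to dropping the conversion outright is to normalize the argument to a host-side int regardless of how the caller passed it. The helper below is a sketch, not part of the PR; `to_host_int` is a hypothetical name, and it duck-types on `.item()` so plain ints, 0-dim torch tensors, and numpy scalars all work without importing torch here:

```python
def to_host_int(max_seqlen):
    """Normalize a max-seqlen argument to a host-side Python int.

    Flash Attention varlen kernels take max_seqlen_q / max_seqlen_kv as
    plain integers for the kernel launch, but some call sites pass 0-dim
    tensors instead.
    """
    if hasattr(max_seqlen, "item"):
        # Covers torch 0-dim tensors and numpy scalars. Note .item()
        # synchronizes with the device if the tensor lives on the GPU,
        # so hot paths should prefer passing ints directly.
        return int(max_seqlen.item())
    return int(max_seqlen)
```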

Comment on lines +122 to +129
```python
if max_seqlen_q.is_cpu:
    max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
if max_seqlen_kv.is_cpu:
    max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
if len(q.shape) == 4:
    q = q.reshape(-1, q.shape[-2], q.shape[-1])
    k = k.reshape(-1, k.shape[-2], k.shape[-1])
    v = v.reshape(-1, v.shape[-2], v.shape[-1])
```
Contributor


critical

Similar to the issue in FlashAttn2Weight, the is_cpu check on max_seqlen_q and max_seqlen_kv will crash if they are integers. These parameters should remain as host-side integers for the Flash Attention kernel launch.

Suggested change

```diff
-if max_seqlen_q.is_cpu:
-    max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
-if max_seqlen_kv.is_cpu:
-    max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
 if len(q.shape) == 4:
     q = q.reshape(-1, q.shape[-2], q.shape[-1])
     k = k.reshape(-1, k.shape[-2], k.shape[-1])
     v = v.reshape(-1, v.shape[-2], v.shape[-1])
```

Comment on lines +163 to +164
```python
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
```
Contributor


critical

Removing the .to(AI_DEVICE) call here will cause a device mismatch error. Unlike other models that use the attention wrappers in common/ops/attn/flash_attn.py (which handle device placement), this file calls the raw flash_attn_varlen_func directly. The raw function requires cu_seqlens to be on the same device as the input tensors.

Suggested change

```diff
-cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
-cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
+cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)).to(AI_DEVICE)
+cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)).to(AI_DEVICE)
```
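For readers unfamiliar with the varlen interface: `cu_seqlens` is an exclusive prefix sum of per-sequence lengths with a leading 0, which is exactly what the `F.pad(torch.cumsum(lens, dim=0), (1, 0))` idiom computes. A dependency-free sketch of the same layout (hypothetical `build_cu_seqlens` helper, plain Python for illustration):

```python
def build_cu_seqlens(lens):
    """Build the cumulative-sequence-lengths array used by varlen
    attention kernels: [0, l0, l0 + l1, ...].

    Equivalent to F.pad(torch.cumsum(lens, dim=0), (1, 0)) in the PR;
    entry i marks where sequence i starts in the packed token dimension.
    """
    cu = [0]
    for n in lens:
        cu.append(cu[-1] + n)
    return cu
```

For example, three sequences of lengths 3, 5, and 2 pack into 10 tokens with boundaries `[0, 3, 8, 10]`.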

Comment on lines +8 to +9
```python
from flash_attn import flash_attn_func_v2
from flash_attn.flash_attn_interface import flash_attn_varlen_func_v2
```
Contributor


high

The import statements for flash_attn_func_v2 and flash_attn_varlen_func_v2 appear to be missing the as keyword. The standard flash_attn package does not export these names directly; they should be aliased from the standard function names to stay consistent with the v3 and v4 imports below.

Suggested change

```diff
-from flash_attn import flash_attn_func_v2
-from flash_attn.flash_attn_interface import flash_attn_varlen_func_v2
+from flash_attn import flash_attn_func as flash_attn_func_v2
+from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v2
```
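A related pattern worth noting: when aliasing backend imports behind version-suffixed names, a try/except guard lets the module import cleanly on machines where a given backend is absent, deferring the failure to dispatch time. This is a sketch of that idiom, not code from the PR; the None-sentinel convention is an assumption:

```python
# Alias the standard flash_attn v2 exports to version-suffixed names so
# call sites can dispatch on one naming convention. If the backend is
# not installed, keep None sentinels instead of failing at import time.
try:
    from flash_attn import flash_attn_func as flash_attn_func_v2
    from flash_attn.flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_varlen_func_v2,
    )
except ImportError:
    flash_attn_func_v2 = None
    flash_attn_varlen_func_v2 = None
```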

@helloyongyang helloyongyang merged commit 229d51c into main Apr 21, 2026
2 checks passed
@helloyongyang helloyongyang deleted the yr/fix_fa_varlen_seqlen branch April 21, 2026 06:58
