Conversation
Code Review
This pull request refactors the attention mechanism across multiple models to standardize the use of Flash Attention v2 and v3, primarily by removing the explicit attention_type parameter and simplifying tensor device placement. While the changes aim to clean up the codebase, several critical issues were identified: potential AttributeError crashes in flash_attn.py when sequence lengths are passed as integers, a device mismatch error in the Bagel model's inference script due to the removal of necessary device transfers, and incorrect import aliasing for Flash Attention v2 functions.
```python
if max_seqlen_q.is_cpu:
    max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
if max_seqlen_kv.is_cpu:
    max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
if len(q.shape) == 4:
    q = q.reshape(-1, q.shape[-2], q.shape[-1])
    k = k.reshape(-1, k.shape[-2], k.shape[-1])
    v = v.reshape(-1, v.shape[-2], v.shape[-1])
```
The is_cpu check and .to() call on max_seqlen_q and max_seqlen_kv will cause an AttributeError if these arguments are passed as integers, which is the case in several models (e.g., Wan, Ulysses). Furthermore, Flash Attention kernels expect these values as host-side integers, so moving them to the GPU device is unnecessary and potentially incorrect.
Suggested change:

```diff
-if max_seqlen_q.is_cpu:
-    max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
-if max_seqlen_kv.is_cpu:
-    max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
 if len(q.shape) == 4:
     q = q.reshape(-1, q.shape[-2], q.shape[-1])
     k = k.reshape(-1, k.shape[-2], k.shape[-1])
     v = v.reshape(-1, v.shape[-2], v.shape[-1])
```
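To make the fix concrete, here is a minimal, hypothetical helper sketching how a wrapper could normalize `max_seqlen_q` / `max_seqlen_kv` defensively, since callers (e.g. Wan, Ulysses) may pass either a plain int or a 0-dim CPU tensor, and the Flash Attention kernels want a host-side Python int either way. The helper name `to_host_int` is an illustration, not part of the PR:

```python
# Hypothetical helper (not in the PR): normalize a max_seqlen argument to a
# host-side Python int, whether it arrives as an int or as a 0-dim tensor.
# Calling .is_cpu / .to() unconditionally would raise AttributeError on ints.
def to_host_int(max_seqlen):
    if isinstance(max_seqlen, int):
        # Plain ints pass through unchanged.
        return max_seqlen
    # Tensor-like inputs: .item() always yields a host-side Python scalar,
    # regardless of which device the tensor lives on.
    return int(max_seqlen.item())
```

A wrapper that calls this at its entry point can then hand plain ints to `flash_attn_varlen_func` without any device transfers.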
```python
if max_seqlen_q.is_cpu:
    max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
if max_seqlen_kv.is_cpu:
    max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
if len(q.shape) == 4:
    q = q.reshape(-1, q.shape[-2], q.shape[-1])
    k = k.reshape(-1, k.shape[-2], k.shape[-1])
    v = v.reshape(-1, v.shape[-2], v.shape[-1])
```
Similar to the issue in FlashAttn2Weight, the is_cpu check on max_seqlen_q and max_seqlen_kv will crash if they are integers. These parameters should remain as host-side integers for the Flash Attention kernel launch.
Suggested change:

```diff
-if max_seqlen_q.is_cpu:
-    max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
-if max_seqlen_kv.is_cpu:
-    max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
 if len(q.shape) == 4:
     q = q.reshape(-1, q.shape[-2], q.shape[-1])
     k = k.reshape(-1, k.shape[-2], k.shape[-1])
     v = v.reshape(-1, v.shape[-2], v.shape[-1])
```
```python
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
```
Removing the .to(AI_DEVICE) call here will cause a device mismatch error. Unlike other models that use the attention wrappers in common/ops/attn/flash_attn.py (which handle device placement), this file calls the raw flash_attn_varlen_func directly. The raw function requires cu_seqlens to be on the same device as the input tensors.
Suggested change:

```diff
-cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
-cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
+cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)).to(AI_DEVICE)
+cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)).to(AI_DEVICE)
```
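For clarity, here is what the padded cumsum actually computes, sketched in pure Python. `flash_attn_varlen_func` expects `cu_seqlens` to be a prefix sum of per-sequence lengths with a leading zero; in the model this is produced by `torch.nn.functional.pad(torch.cumsum(lens, dim=0), (1, 0))`, and (unlike the wrappers in `common/ops/attn/flash_attn.py`) the raw kernel will not relocate it, hence the `.to(AI_DEVICE)`:

```python
# Pure-Python sketch of the cu_seqlens construction, for illustration only:
# a running total of sequence lengths with a leading zero, so that sequence i
# occupies rows [out[i], out[i+1]) of the packed q/k/v tensors.
def cu_seqlens(lens):
    out = [0]
    for n in lens:
        out.append(out[-1] + n)  # prefix sum of sequence lengths
    return out
```

For example, three sequences of lengths 3, 5, and 2 yield boundaries `[0, 3, 8, 10]`, so the varlen kernel can slice each sequence out of the packed batch.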
```python
from flash_attn import flash_attn_func_v2
from flash_attn.flash_attn_interface import flash_attn_varlen_func_v2
```
The import statements for flash_attn_func_v2 and flash_attn_varlen_func_v2 appear to be missing the as keyword. The standard flash_attn package does not export these names; they should likely be aliased from the standard function names (flash_attn_func and flash_attn_varlen_func) to maintain consistency with the v3 and v4 imports below.
Suggested change:

```diff
-from flash_attn import flash_attn_func_v2
-from flash_attn.flash_attn_interface import flash_attn_varlen_func_v2
+from flash_attn import flash_attn_func as flash_attn_func_v2
+from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v2
```
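One way the aliased imports could be made robust, sketched here as a suggestion rather than part of the PR: guard them with a try/except so that environments without flash-attn v2 installed fall back to `None` sentinels, which callers can check before dispatching. Whether the surrounding file already has such a guard is an assumption:

```python
# Sketch: alias the standard flash_attn v2 entry points to the *_v2 names used
# elsewhere in the file, mirroring the v3/v4 import style. The package exports
# flash_attn_func and flash_attn_varlen_func, not *_v2 variants, so the alias
# must come from `as`.
try:
    from flash_attn import flash_attn_func as flash_attn_func_v2
    from flash_attn.flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_varlen_func_v2,
    )
except ImportError:
    # flash-attn v2 not installed; callers must check for None before use.
    flash_attn_func_v2 = None
    flash_attn_varlen_func_v2 = None
```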