Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions stg.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ class STGFlag:
class PatchAttention(contextlib.AbstractContextManager):
def __init__(self, attn_idx: Optional[Union[int, List[int]]] = None):
self.current_idx = -1
self._guide_offset = 0

if isinstance(attn_idx, int):
self.attn_idx = [attn_idx]
Expand Down Expand Up @@ -151,19 +152,42 @@ def __exit__(self, exc_type, exc_value, traceback):
self.original_attention = None
self.original_attention_masked = None

def stg_attention(self, q, k, v, heads, *args, **kwargs):
self.current_idx += 1
if self.current_idx in self.attn_idx:
return v
def _stg_call(self, original, q, k, v, heads, args, kwargs):
# comfy's guide-mask self-attention (_attention_with_guide_mask in
# comfy/ldm/lightricks/model.py) splits one self-attention into several
# optimized_attention calls over contiguous *query slices*, each against
# the full key/value. Those sub-calls are the only ones that pass
# low_precision_attention=False, which lets us recognise them: a plain
# "return v" would be the wrong length (full sequence vs. the query
# slice) and would also miscount the STG attention index (one logical
# self-attention would consume several indices, shifting audio_attn_idx).
# We collapse the split into a single logical attention and, when
# skipping, return the matching slice of v.
guide_split = kwargs.get("low_precision_attention") is False and q.shape[1] < v.shape[1]
continuation = guide_split and self._guide_offset > 0

if not continuation:
self.current_idx += 1
skip = self.current_idx in self.attn_idx

if not guide_split:
return v if skip else original(q, k, v, heads, *args, **kwargs)

off = self._guide_offset
q_len = q.shape[1]
if skip:
out = v[:, off:off + q_len]
else:
return self.original_attention(q, k, v, heads, *args, **kwargs)
out = original(q, k, v, heads, *args, **kwargs)
off += q_len
self._guide_offset = 0 if off >= v.shape[1] else off
return out

def stg_attention(self, q, k, v, heads, *args, **kwargs):
return self._stg_call(self.original_attention, q, k, v, heads, args, kwargs)

def stg_attention_masked(self, q, k, v, heads, *args, **kwargs):
self.current_idx += 1
if self.current_idx in self.attn_idx:
return v
else:
return self.original_attention_masked(q, k, v, heads, *args, **kwargs)
return self._stg_call(self.original_attention_masked, q, k, v, heads, args, kwargs)


class STGBlockWrapper:
Expand Down