Update skyreels_v2_infer/modules/transformer.py
skyreels_v2_infer/modules/transformer.py (CHANGED)
@@ -13,7 +13,7 @@ from torch.nn.attention.flex_attention import BlockMask
 from torch.nn.attention.flex_attention import create_block_mask
 from torch.nn.attention.flex_attention import flex_attention
 
-from .attention import flash_attention
+from .attention import flash_attention, attention
 
 
 flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune")

@@ -160,7 +160,7 @@ class WanSelfAttention(nn.Module):
         if not self._flag_ar_attention:
             q = rope_apply(q, grid_sizes, freqs)
             k = rope_apply(k, grid_sizes, freqs)
-            x =
+            x = attention(q=q, k=k, v=v, window_size=self.window_size)
         else:
             q = rope_apply(q, grid_sizes, freqs)
             k = rope_apply(k, grid_sizes, freqs)

@@ -199,7 +199,7 @@ class WanT2VCrossAttention(WanSelfAttention):
         v = self.v(context).view(b, -1, n, d)
 
         # compute attention
-        x =
+        x = attention(q, k, v)
 
         # output
         x = x.flatten(2)
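Both hunks route what was a direct call through the generic attention helper now imported from .attention alongside flash_attention. That helper's body is not shown in this diff; a minimal sketch of such a dispatcher, assuming the common pattern of using flash-attn when installed and falling back to PyTorch's scaled_dot_product_attention otherwise (the FLASH_ATTN_AVAILABLE flag and the fallback body are illustrative assumptions, not the repo's confirmed code):

# Illustrative sketch only: the diff imports `attention` from
# skyreels_v2_infer/modules/attention.py but does not show its body.
# `FLASH_ATTN_AVAILABLE` and the SDPA fallback below are assumptions.
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False


def attention(q, k, v, window_size=(-1, -1)):
    """Attention over [batch, seq, heads, head_dim] tensors.

    Uses flash-attn when available on GPU; otherwise falls back to
    torch.nn.functional.scaled_dot_product_attention. The fallback
    ignores `window_size`, i.e. it always runs full attention.
    """
    if FLASH_ATTN_AVAILABLE and q.is_cuda:
        # flash_attn_func takes [B, L, H, D] directly and supports
        # sliding windows via window_size=(left, right); (-1, -1) = full.
        return flash_attn_func(q, k, v, window_size=window_size)
    # SDPA expects [B, H, L, D], so move the head dim forward and back.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    )
    return out.transpose(1, 2)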
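The view(b, -1, n, d) in the cross-attention hunk fixes the tensor layout as [batch, seq, heads, head_dim], so a quick smoke test of the sketch above could look like this (all shapes are made up for illustration):

q = torch.randn(1, 128, 12, 64)  # [batch, seq, heads, head_dim]
k = torch.randn(1, 77, 12, 64)   # cross-attention: kv seq length may differ
v = torch.randn(1, 77, 12, 64)
x = attention(q, k, v)           # -> [1, 128, 12, 64]
x = x.flatten(2)                 # -> [1, 128, 768], matching `x = x.flatten(2)` above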