Update model.py
Browse files
model.py
CHANGED
@@ -196,24 +196,30 @@ class rotary(nn.Module):
|
|
196 |
t = x.float().to(self.inv_freq.device)
|
197 |
if f0 is not None:
|
198 |
f0_mean = f0.mean()
|
199 |
-
f0_theta =
|
200 |
inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
|
201 |
else:
|
202 |
inv_freq = self.inv_freq
|
203 |
freqs = torch.einsum('i,j->ij', t, inv_freq)
|
204 |
freqs = freqs.float()
|
205 |
if self.variable_radius:
|
206 |
-
|
207 |
if f0 is not None:
|
208 |
f0 = f0[0]
|
209 |
seq_len = x
|
210 |
-
f0 = torch.tensor(f0, device=x.device if isinstance(x, torch.Tensor) else device)
|
211 |
f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
|
212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
freqs = torch.polar(radius, freqs)
|
214 |
else:
|
215 |
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
216 |
-
|
|
|
217 |
|
218 |
if "rotary" in self.debug:
|
219 |
if f0 is not None:
|
@@ -258,7 +264,7 @@ class rotary(nn.Module):
|
|
258 |
x1 = x1 * freqs
|
259 |
x1 = torch.view_as_real(x1).flatten(-2)
|
260 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
261 |
-
|
262 |
class SliceAttention(nn.Module):
|
263 |
def __init__(self, dims, heads, dropout=0.0):
|
264 |
super().__init__()
|
|
|
196 |
t = x.float().to(self.inv_freq.device)
|
197 |
if f0 is not None:
|
198 |
f0_mean = f0.mean()
|
199 |
+
f0_theta = (f0_mean**2) * self.pitch_scale
|
200 |
inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
|
201 |
else:
|
202 |
inv_freq = self.inv_freq
|
203 |
freqs = torch.einsum('i,j->ij', t, inv_freq)
|
204 |
freqs = freqs.float()
|
205 |
if self.variable_radius:
|
|
|
206 |
if f0 is not None:
|
207 |
f0 = f0[0]
|
208 |
seq_len = x
|
209 |
+
f0 = torch.tensor(f0, device=x.device if isinstance(x, torch.Tensor) else self.device)
|
210 |
f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
|
211 |
+
max_f0 = torch.max(f0)
|
212 |
+
|
213 |
+
if max_f0 > 0:
|
214 |
+
radius = f0 / max_f0
|
215 |
+
else:
|
216 |
+
radius = torch.ones_like(f0)
|
217 |
+
|
218 |
freqs = torch.polar(radius, freqs)
|
219 |
else:
|
220 |
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
221 |
+
freqs = freqs.unsqueeze(0)
|
222 |
+
print(f"radius, {radius}")
|
223 |
|
224 |
if "rotary" in self.debug:
|
225 |
if f0 is not None:
|
|
|
264 |
x1 = x1 * freqs
|
265 |
x1 = torch.view_as_real(x1).flatten(-2)
|
266 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
267 |
+
|
268 |
class SliceAttention(nn.Module):
|
269 |
def __init__(self, dims, heads, dropout=0.0):
|
270 |
super().__init__()
|