Update model.py
Browse files
model.py
CHANGED
@@ -197,6 +197,7 @@ class rotary(nn.Module):
|
|
197 |
if f0 is not None:
|
198 |
f0_mean = f0.mean()
|
199 |
f0_theta = (f0_mean**2) * self.pitch_scale
|
|
|
200 |
inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
|
201 |
else:
|
202 |
inv_freq = self.inv_freq
|
@@ -206,20 +207,13 @@ class rotary(nn.Module):
|
|
206 |
if f0 is not None:
|
207 |
f0 = f0[0]
|
208 |
seq_len = x
|
209 |
-
f0 = torch.tensor(f0, device=
|
210 |
f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
|
211 |
-
|
212 |
-
|
213 |
-
if max_f0 > 0:
|
214 |
-
radius = f0 / max_f0
|
215 |
-
else:
|
216 |
-
radius = torch.ones_like(f0)
|
217 |
-
|
218 |
freqs = torch.polar(radius, freqs)
|
219 |
else:
|
220 |
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
221 |
-
|
222 |
-
# print(f"radius, {radius}")
|
223 |
|
224 |
if "rotary" in self.debug:
|
225 |
if f0 is not None:
|
|
|
197 |
if f0 is not None:
|
198 |
f0_mean = f0.mean()
|
199 |
f0_theta = (f0_mean**2) * self.pitch_scale
|
200 |
+
#f0_theta = f0_mean * (f0_mean / self.theta) * self.theta * self.pitch_scale
|
201 |
inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
|
202 |
else:
|
203 |
inv_freq = self.inv_freq
|
|
|
207 |
if f0 is not None:
|
208 |
f0 = f0[0]
|
209 |
seq_len = x
|
210 |
+
f0 = torch.tensor(f0, device=device if isinstance(x, torch.Tensor) else device)
|
211 |
f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
|
212 |
+
radius = 1.0 / (f0 + 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
freqs = torch.polar(radius, freqs)
|
214 |
else:
|
215 |
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
216 |
+
freqs = freqs.unsqueeze(0)
|
|
|
217 |
|
218 |
if "rotary" in self.debug:
|
219 |
if f0 is not None:
|