Update model.py
Browse files
model.py
CHANGED
@@ -337,14 +337,14 @@ class rotary(nn.Module):
|
|
337 |
|
338 |
self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
|
339 |
|
340 |
-
freqs = 1. / (theta ** (torch.arange(0, dim, 2, device=device, dtype=dtype)[:(dim // 2)].float() / dims))
|
341 |
-
self.freqs = nn.Parameter(torch.tensor(freqs, device=device, dtype=dtype), requires_grad=True)
|
342 |
-
self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
|
343 |
-
freq_data = 1.0 / (theta ** (torch.arange(start=0, end=dim, step=2).float() / dim))
|
344 |
-
self.inv_freq = nn.Parameter(freq_data, requires_grad=True)
|
345 |
|
346 |
freqb = 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
347 |
-
self.
|
348 |
|
349 |
def update_base(self, pitch):
|
350 |
theta = pitch.squeeze(0).to(device, dtype)
|
@@ -403,11 +403,11 @@ class rotary(nn.Module):
|
|
403 |
if f0 is not None:
|
404 |
f0_mean = f0.mean()
|
405 |
theta = f0_mean + 1e-8
|
406 |
-
freqs = self.
|
407 |
if "rotary1" in self.debug:
|
408 |
print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
|
409 |
else:
|
410 |
-
freqs = self.
|
411 |
|
412 |
freqs = t[:, None] * freqs[None, :]
|
413 |
if self.radii:
|
|
|
337 |
|
338 |
self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
|
339 |
|
340 |
+
# freqs = 1. / (theta ** (torch.arange(0, dim, 2, device=device, dtype=dtype)[:(dim // 2)].float() / dims))
|
341 |
+
# self.freqs = nn.Parameter(torch.tensor(freqs, device=device, dtype=dtype), requires_grad=True)
|
342 |
+
# self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
|
343 |
+
# freq_data = 1.0 / (theta ** (torch.arange(start=0, end=dim, step=2).float() / dim))
|
344 |
+
# self.inv_freq = nn.Parameter(freq_data, requires_grad=True)
|
345 |
|
346 |
freqb = 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
347 |
+
self.inv_freq = nn.Parameter(torch.tensor(freqb, device=device, dtype=dtype), requires_grad=True)
|
348 |
|
349 |
def update_base(self, pitch):
|
350 |
theta = pitch.squeeze(0).to(device, dtype)
|
|
|
403 |
if f0 is not None:
|
404 |
f0_mean = f0.mean()
|
405 |
theta = f0_mean + 1e-8
|
406 |
+
freqs = self.inv_freq
|
407 |
if "rotary1" in self.debug:
|
408 |
print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
|
409 |
else:
|
410 |
+
freqs = self.inv_freq
|
411 |
|
412 |
freqs = t[:, None] * freqs[None, :]
|
413 |
if self.radii:
|