Update model.py
Browse files
model.py
CHANGED
@@ -342,6 +342,9 @@ class rotary(nn.Module):
|
|
342 |
self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
|
343 |
freq_data = 1.0 / (theta ** (torch.arange(start=0, end=dim, step=2).float() / dim))
|
344 |
self.inv_freq = nn.Parameter(freq_data, requires_grad=True)
|
|
|
|
|
|
|
345 |
|
346 |
def update_base(self, pitch):
|
347 |
theta = pitch.squeeze(0).to(device, dtype)
|
@@ -400,11 +403,12 @@ class rotary(nn.Module):
|
|
400 |
if f0 is not None:
|
401 |
f0_mean = f0.mean()
|
402 |
theta = f0_mean + 1e-8
|
403 |
-
freqs =
|
404 |
if "rotary1" in self.debug:
|
405 |
print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
|
406 |
else:
|
407 |
-
freqs = self.
|
|
|
408 |
freqs = t[:, None] * freqs[None, :]
|
409 |
if self.radii:
|
410 |
if f0 is not None:
|
|
|
342 |
self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
|
343 |
freq_data = 1.0 / (theta ** (torch.arange(start=0, end=dim, step=2).float() / dim))
|
344 |
self.inv_freq = nn.Parameter(freq_data, requires_grad=True)
|
345 |
+
|
346 |
+
freqb = 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
347 |
+
self.freqb = nn.Parameter(torch.tensor(freqb, device=device, dtype=dtype), requires_grad=True)
|
348 |
|
349 |
def update_base(self, pitch):
|
350 |
theta = pitch.squeeze(0).to(device, dtype)
|
|
|
403 |
if f0 is not None:
|
404 |
f0_mean = f0.mean()
|
405 |
theta = f0_mean + 1e-8
|
406 |
+
freqs = self.freqb
|
407 |
if "rotary1" in self.debug:
|
408 |
print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
|
409 |
else:
|
410 |
+
freqs = self.freqb
|
411 |
+
|
412 |
freqs = t[:, None] * freqs[None, :]
|
413 |
if self.radii:
|
414 |
if f0 is not None:
|