Sin2pi commited on
Commit
8edf5f3
·
verified ·
1 Parent(s): db65863

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -2
model.py CHANGED
@@ -271,7 +271,7 @@ class rotary(nn.Module):
271
  theta = torch.tensor(theta, device=device, dtype=dtype)
272
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
273
  self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
274
- inv_freq = (theta / 140.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
275
  self.inv_freq = nn.Parameter(torch.tensor(inv_freq, device=device, dtype=dtype), requires_grad=True)
276
 
277
  def return_f0(self, f0=None):
@@ -285,7 +285,7 @@ class rotary(nn.Module):
285
  def update_base(self, f0):
286
  f0 = self.return_f0()
287
  theta = f0.mean() + 1e-8
288
- inv_freq = (theta / 140.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
289
  self.inv_freq.data.copy_(inv_freq)
290
  self.theta.data.copy_(theta)
291
 
 
271
  theta = torch.tensor(theta, device=device, dtype=dtype)
272
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
273
  self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
274
+ inv_freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
275
  self.inv_freq = nn.Parameter(torch.tensor(inv_freq, device=device, dtype=dtype), requires_grad=True)
276
 
277
  def return_f0(self, f0=None):
 
285
  def update_base(self, f0):
286
  f0 = self.return_f0()
287
  theta = f0.mean() + 1e-8
288
+ inv_freq = (theta / 200.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
289
  self.inv_freq.data.copy_(inv_freq)
290
  self.theta.data.copy_(theta)
291