Sin2pi commited on
Commit
3542ded
·
verified ·
1 Parent(s): 30d9bde

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +6 -2
model.py CHANGED
@@ -342,6 +342,9 @@ class rotary(nn.Module):
342
  self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
343
  freq_data = 1.0 / (theta ** (torch.arange(start=0, end=dim, step=2).float() / dim))
344
  self.inv_freq = nn.Parameter(freq_data, requires_grad=True)
 
 
 
345
 
346
  def update_base(self, pitch):
347
  theta = pitch.squeeze(0).to(device, dtype)
@@ -400,11 +403,12 @@ class rotary(nn.Module):
400
  if f0 is not None:
401
  f0_mean = f0.mean()
402
  theta = f0_mean + 1e-8
403
- freqs = 1.0 / (theta ** (torch.arange(0, self.dim, 2, device=device, dtype=dtype)[:(self.dim // 2)].float() / self.dim))
404
  if "rotary1" in self.debug:
405
  print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
406
  else:
407
- freqs = self.freqs
 
408
  freqs = t[:, None] * freqs[None, :]
409
  if self.radii:
410
  if f0 is not None:
 
342
  self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
343
  freq_data = 1.0 / (theta ** (torch.arange(start=0, end=dim, step=2).float() / dim))
344
  self.inv_freq = nn.Parameter(freq_data, requires_grad=True)
345
+
346
+ freqb = 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
347
+ self.freqb = nn.Parameter(torch.tensor(freqb, device=device, dtype=dtype), requires_grad=True)
348
 
349
  def update_base(self, pitch):
350
  theta = pitch.squeeze(0).to(device, dtype)
 
403
  if f0 is not None:
404
  f0_mean = f0.mean()
405
  theta = f0_mean + 1e-8
406
+ freqs = self.freqb
407
  if "rotary1" in self.debug:
408
  print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
409
  else:
410
+ freqs = self.freqb
411
+
412
  freqs = t[:, None] * freqs[None, :]
413
  if self.radii:
414
  if f0 is not None: