Sin2pi commited on
Commit
d2f4343
·
verified ·
1 Parent(s): 08be513

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +8 -8
model.py CHANGED
@@ -337,14 +337,14 @@ class rotary(nn.Module):
337
 
338
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
339
 
340
- freqs = 1. / (theta ** (torch.arange(0, dim, 2, device=device, dtype=dtype)[:(dim // 2)].float() / dims))
341
- self.freqs = nn.Parameter(torch.tensor(freqs, device=device, dtype=dtype), requires_grad=True)
342
- self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
343
- freq_data = 1.0 / (theta ** (torch.arange(start=0, end=dim, step=2).float() / dim))
344
- self.inv_freq = nn.Parameter(freq_data, requires_grad=True)
345
 
346
  freqb = 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
347
- self.freqb = nn.Parameter(torch.tensor(freqb, device=device, dtype=dtype), requires_grad=True)
348
 
349
  def update_base(self, pitch):
350
  theta = pitch.squeeze(0).to(device, dtype)
@@ -403,11 +403,11 @@ class rotary(nn.Module):
403
  if f0 is not None:
404
  f0_mean = f0.mean()
405
  theta = f0_mean + 1e-8
406
- freqs = self.freqb
407
  if "rotary1" in self.debug:
408
  print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
409
  else:
410
- freqs = self.freqb
411
 
412
  freqs = t[:, None] * freqs[None, :]
413
  if self.radii:
 
337
 
338
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
339
 
340
+ # freqs = 1. / (theta ** (torch.arange(0, dim, 2, device=device, dtype=dtype)[:(dim // 2)].float() / dims))
341
+ # self.freqs = nn.Parameter(torch.tensor(freqs, device=device, dtype=dtype), requires_grad=True)
342
+ # self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
343
+ # freq_data = 1.0 / (theta ** (torch.arange(start=0, end=dim, step=2).float() / dim))
344
+ # self.inv_freq = nn.Parameter(freq_data, requires_grad=True)
345
 
346
  freqb = 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
347
+ self.inv_freq = nn.Parameter(torch.tensor(freqb, device=device, dtype=dtype), requires_grad=True)
348
 
349
  def update_base(self, pitch):
350
  theta = pitch.squeeze(0).to(device, dtype)
 
403
  if f0 is not None:
404
  f0_mean = f0.mean()
405
  theta = f0_mean + 1e-8
406
+ freqs = self.inv_freq
407
  if "rotary1" in self.debug:
408
  print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
409
  else:
410
+ freqs = self.inv_freq
411
 
412
  freqs = t[:, None] * freqs[None, :]
413
  if self.radii: