Sin2pi commited on
Commit
9578ddb
·
verified ·
1 Parent(s): b4f3b2e

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +21 -1
model.py CHANGED
@@ -411,6 +411,23 @@ class rotary(nn.Module):
411
  idx = (idx * frames).long().clamp(0, length - 1)
412
  return f0[idx, :]
413
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
415
  f0 = enc.get("f0", None) if enc is not None else None
416
 
@@ -427,13 +444,16 @@ class rotary(nn.Module):
427
  freqs = self.inv_freq
428
  f0_mean = f0.mean()
429
  theta = f0_mean + 1e-8
430
- freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
431
 
432
  if "rotary1" in self.debug:
433
  print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
434
  else:
435
  freqs = self.inv_freq
436
  freqs = t[:, None] * freqs[None, :]
 
 
 
437
  if self.radii:
438
  if f0 is not None:
439
  radius = self.align_f0(f0, ctx)
 
411
  idx = (idx * frames).long().clamp(0, length - 1)
412
  return f0[idx, :]
413
 
414
+ # def orthogonal(self, dims, i, j, theta):
415
+ # R = torch.eye(dims).to(theta.device)
416
+ # R[i, i] = torch.cos(theta)
417
+ # R[i, j] = -torch.sin(theta)
418
+ # R[j, i] = torch.sin(theta)
419
+ # R[j, j] = torch.cos(theta)
420
+ # R = torch.eye(dims).to(theta.device) - 2 * torch.outer(R, R) / torch.dot(R, R)
421
+ # return R
422
+
423
+ # def orthogonal_regularization_term(self):
424
+ # loss = torch.tensor(0.0, device=self.r_matrix.device)
425
+ # if self.r_matrix.requires_grad:
426
+ # product = torch.matmul(self.r_matrix, self.r_matrix.t())
427
+ # identity = torch.eye(self.r_matrix.size(0)).to(self.r_matrix.device)
428
+ # loss = ((product - identity) ** 2).sum()
429
+ # return self.orthogonal_reg_weight * loss
430
+
431
  def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
432
  f0 = enc.get("f0", None) if enc is not None else None
433
 
 
444
  freqs = self.inv_freq
445
  f0_mean = f0.mean()
446
  theta = f0_mean + 1e-8
447
+ freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
448
 
449
  if "rotary1" in self.debug:
450
  print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
451
  else:
452
  freqs = self.inv_freq
453
  freqs = t[:, None] * freqs[None, :]
454
+
455
+ # sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(end=seq_len, device=x.device), self.inv_freq.to(device=x.device))
456
+
457
  if self.radii:
458
  if f0 is not None:
459
  radius = self.align_f0(f0, ctx)