Update model.py
Browse files
model.py
CHANGED
@@ -411,6 +411,23 @@ class rotary(nn.Module):
|
|
411 |
idx = (idx * frames).long().clamp(0, length - 1)
|
412 |
return f0[idx, :]
|
413 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
|
415 |
f0 = enc.get("f0", None) if enc is not None else None
|
416 |
|
@@ -427,13 +444,16 @@ class rotary(nn.Module):
|
|
427 |
freqs = self.inv_freq
|
428 |
f0_mean = f0.mean()
|
429 |
theta = f0_mean + 1e-8
|
430 |
-
freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
431 |
|
432 |
if "rotary1" in self.debug:
|
433 |
print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
|
434 |
else:
|
435 |
freqs = self.inv_freq
|
436 |
freqs = t[:, None] * freqs[None, :]
|
|
|
|
|
|
|
437 |
if self.radii:
|
438 |
if f0 is not None:
|
439 |
radius = self.align_f0(f0, ctx)
|
|
|
411 |
idx = (idx * frames).long().clamp(0, length - 1)
|
412 |
return f0[idx, :]
|
413 |
|
414 |
+
# def orthogonal(self, dims, i, j, theta):
|
415 |
+
# R = torch.eye(dims).to(theta.device)
|
416 |
+
# R[i, i] = torch.cos(theta)
|
417 |
+
# R[i, j] = -torch.sin(theta)
|
418 |
+
# R[j, i] = torch.sin(theta)
|
419 |
+
# R[j, j] = torch.cos(theta)
|
420 |
+
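# R = torch.eye(dims).to(theta.device) - 2 * torch.outer(R, R) / torch.dot(R, R)  # NOTE(review): torch.outer/torch.dot require 1-D tensors but R is 2-D here, and this assignment discards the rotation entries set above - fix before re-enabling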
|
421 |
+
# return R
|
422 |
+
|
423 |
+
# def orthogonal_regularization_term(self):
|
424 |
+
# loss = torch.tensor(0.0, device=self.r_matrix.device)
|
425 |
+
# if self.r_matrix.requires_grad:
|
426 |
+
# product = torch.matmul(self.r_matrix, self.r_matrix.t())
|
427 |
+
# identity = torch.eye(self.r_matrix.size(0)).to(self.r_matrix.device)
|
428 |
+
# loss = ((product - identity) ** 2).sum()
|
429 |
+
# return self.orthogonal_reg_weight * loss
|
430 |
+
|
431 |
def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
|
432 |
f0 = enc.get("f0", None) if enc is not None else None
|
433 |
|
|
|
444 |
freqs = self.inv_freq
|
445 |
f0_mean = f0.mean()
|
446 |
theta = f0_mean + 1e-8
|
447 |
+
freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
448 |
|
449 |
if "rotary1" in self.debug:
|
450 |
print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
|
451 |
else:
|
452 |
freqs = self.inv_freq
|
453 |
freqs = t[:, None] * freqs[None, :]
|
454 |
+
|
455 |
+
# sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(end=seq_len, device=x.device), self.inv_freq.to(device=x.device))
|
456 |
+
|
457 |
if self.radii:
|
458 |
if f0 is not None:
|
459 |
radius = self.align_f0(f0, ctx)
|