Update model.py
Browse files
model.py
CHANGED
@@ -296,10 +296,10 @@ class rotary(nn.Module):
|
|
296 |
idx = (idx * frames).long().clamp(0, length - 1)
|
297 |
f0 = f0[idx]
|
298 |
f0_norm = (f0 - f0.mean()) / (f0.std() + 1e-8)
|
299 |
-
|
300 |
-
|
301 |
-
diff = f0_norm[:, None] - f0_norm[None, :]
|
302 |
-
f0_sim = torch.exp(-diff.pow(2))
|
303 |
return f0_sim.unsqueeze(0).unsqueeze(0)
|
304 |
|
305 |
def f0proj(self, f0):
|
|
|
296 |
idx = (idx * frames).long().clamp(0, length - 1)
|
297 |
f0 = f0[idx]
|
298 |
f0_norm = (f0 - f0.mean()) / (f0.std() + 1e-8)
|
299 |
+
f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
|
300 |
+
f0_norm.unsqueeze(1)))
|
301 |
+
# diff = f0_norm[:, None] - f0_norm[None, :]
|
302 |
+
# f0_sim = torch.exp(-diff.pow(2))
|
303 |
return f0_sim.unsqueeze(0).unsqueeze(0)
|
304 |
|
305 |
def f0proj(self, f0):
|