Sin2pi committed on
Commit
cc3d368
·
verified ·
1 Parent(s): 8b434a5

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +13 -38
model.py CHANGED
@@ -287,11 +287,19 @@ class rotary(nn.Module):
287
  def get_bias(self, f0, ctx):
288
  if f0 is None:
289
  return None
290
- f0 = self.align_f0a(f0, ctx)
291
- f0_flat = f0.squeeze().float()
292
- f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
293
- f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
294
- f0_norm.unsqueeze(1)))
 
 
 
 
 
 
 
 
295
  return f0_sim.unsqueeze(0).unsqueeze(0)
296
 
297
  def f0proj(self, f0):
@@ -313,38 +321,6 @@ class rotary(nn.Module):
313
  frames = length / ctx
314
  idx = torch.arange(ctx, device=f0.device)
315
  return f0[idx]
316
-
317
- def align_f0a(self, f0, ctx):
318
- if f0.dim() == 3:
319
- batch, length, dims = f0.shape
320
- if length == ctx:
321
- f0 = f0
322
- else:
323
- frames = length / ctx
324
- idx = torch.arange(ctx, device=f0.device)
325
- idx = (idx * frames).long().clamp(0, length - 1)
326
- f0 = f0[:, idx, :]
327
- f0 = f0.mean(dim=(0, -1))
328
- return f0
329
- if f0.dim() == 2:
330
- length, dims = f0.shape
331
- if length == ctx:
332
- f0 = f0
333
- else:
334
- frames = length / ctx
335
- idx = torch.arange(ctx, device=f0.device)
336
- idx = (idx * frames).long().clamp(0, length - 1)
337
- f0 = f0[idx, :]
338
- f0 = f0.mean(dim=-1)
339
- return f0
340
- if f0.dim() == 1:
341
- length = f0.shape[0]
342
- if length == ctx:
343
- return f0
344
- frames = length / ctx
345
- idx = torch.arange(ctx, device=f0.device)
346
- idx = (idx * frames).long().clamp(0, length - 1)
347
- return f0[idx]
348
 
349
  def align_f0(self, ctx, f0):
350
  f0 = self.f0proj(f0)
@@ -552,7 +528,6 @@ class MultiheadA(nn.Module):
552
  f0 = enc.get("f0", None) if enc is not None else None
553
  pbias = self.rope.get_bias(f0, q2)
554
  if pbias is not None:
555
- # print(f"pbias shape: {pbias.shape}, qk shape: {qk.shape}")
556
  qk = qk + pbias
557
  token_ids = k[:, :, :, 0]
558
  zscale = torch.ones_like(token_ids)
 
287
def get_bias(self, f0, ctx):
    """Build an f0-based additive attention bias.

    Resamples a 1-D pitch (f0) track to ``ctx`` positions, z-normalizes it,
    and returns a pairwise Gaussian similarity matrix exp(-(fi - fj)^2) of
    shape ``(1, 1, ctx, ctx)`` suitable for adding to attention logits.

    Args:
        f0:  pitch tensor, or ``None`` to disable the bias.
             NOTE(review): only ``f0.dim() == 1`` is resampled; higher-rank
             input passes through unresampled — confirm callers only send 1-D.
        ctx: target sequence length (number of query/key positions).

    Returns:
        Tensor of shape ``(1, 1, ctx, ctx)``, or ``None`` when ``f0`` is None.
    """
    if f0 is None:
        return None
    if f0.dim() == 1:
        length = f0.shape[0]
        # BUG FIX: the previous version returned the raw 1-D f0 when
        # length == ctx, skipping the similarity computation entirely and
        # handing the caller a (ctx,) tensor where (1, 1, ctx, ctx) is
        # expected.  Only the resampling step should be skipped.
        if length != ctx:
            # Nearest-neighbor resample of the f0 track to ctx frames.
            frames = length / ctx
            idx = torch.arange(ctx, device=f0.device)
            idx = (idx * frames).long().clamp(0, length - 1)
            f0 = f0[idx]
    # Z-normalize; epsilon guards against a constant (zero-std) track.
    f0_norm = (f0 - f0.mean()) / (f0.std() + 1e-8)
    # Pairwise squared differences via broadcasting (equivalent to the old
    # cdist-based formulation but without the extra unsqueeze/cdist calls).
    diff = f0_norm[:, None] - f0_norm[None, :]
    f0_sim = torch.exp(-diff.pow(2))
    return f0_sim.unsqueeze(0).unsqueeze(0)
304
 
305
  def f0proj(self, f0):
 
321
  frames = length / ctx
322
  idx = torch.arange(ctx, device=f0.device)
323
  return f0[idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
  def align_f0(self, ctx, f0):
326
  f0 = self.f0proj(f0)
 
528
  f0 = enc.get("f0", None) if enc is not None else None
529
  pbias = self.rope.get_bias(f0, q2)
530
  if pbias is not None:
 
531
  qk = qk + pbias
532
  token_ids = k[:, :, :, 0]
533
  zscale = torch.ones_like(token_ids)