Sin2pi committed on
Commit
8e82b27
·
verified ·
1 Parent(s): eabe2b7

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +43 -12
model.py CHANGED
@@ -284,9 +284,10 @@ class rotary(nn.Module):
284
  self.freqs.data.copy_(freqs)
285
  self.theta.data.copy_(theta)
286
 
287
- def get_pitch_bias(self, f0):
288
  if f0 is None:
289
  return None
 
290
  f0_flat = f0.squeeze().float()
291
  f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
292
  f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
@@ -313,6 +314,38 @@ class rotary(nn.Module):
313
  idx = torch.arange(ctx, device=f0.device)
314
  return f0[idx]
315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  def align_f0(self, ctx, f0):
317
  f0 = self.f0proj(f0)
318
  if f0.dim() == 3:
@@ -359,7 +392,7 @@ class rotary(nn.Module):
359
  else:
360
  f0 = f0.view(-1)
361
 
362
- if f0 is not None:
363
  f0_mean = f0.mean()
364
  theta = f0_mean + self.theta
365
  else:
@@ -371,7 +404,7 @@ class rotary(nn.Module):
371
  print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
372
 
373
  freqs = t[:, None] * freqs[None, :]
374
- if self.radii and f0 is not None:
375
  radius = f0.to(device, dtype)
376
  L = radius.shape[0]
377
  if L != ctx:
@@ -463,7 +496,8 @@ class MultiheadA(nn.Module):
463
  dims=dims,
464
  head=head,
465
  debug=debug,
466
- radii=False,
 
467
  )
468
  else:
469
  self.rope = None
@@ -515,10 +549,11 @@ class MultiheadA(nn.Module):
515
 
516
  qk = (q * scale) @ (k * scale).transpose(-1, -2)
517
  if self.rope.use_pbias:
518
- f0 = enc.get("f0", None) if enc is not None else None
519
- pbias = self.rope.use_pbias(f0)
520
  if pbias is not None:
521
- qk = qk + pbias[:,:,:q.shape[2],:q.shape[2]]
 
522
  token_ids = k[:, :, :, 0]
523
  zscale = torch.ones_like(token_ids)
524
  fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
@@ -959,13 +994,9 @@ class TextDecoder(nn.Module):
959
  mask = self.mask[:x.shape[1], :x.shape[1]]
960
  x = self.token(x) + self.positional[:x.shape[1]]
961
  x = F.dropout(x, p=self.dropout, training=self.training)
962
-
963
- # ctx = x.shape[1]
964
- # freqs = self.rotary(ctx)
965
- # x = self.rotary.apply_rotary(x, freqs)
966
 
967
  for block in self.block:
968
- x = block(x, xa=None, mask=mask, enc=None, layer=layer)
969
 
970
  for f in order:
971
  if f in enc:
 
284
  self.freqs.data.copy_(freqs)
285
  self.theta.data.copy_(theta)
286
 
287
+ def get_bias(self, f0, ctx):
288
  if f0 is None:
289
  return None
290
+ f0 = self.align_f0a(f0, ctx)
291
  f0_flat = f0.squeeze().float()
292
  f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
293
  f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
 
314
  idx = torch.arange(ctx, device=f0.device)
315
  return f0[idx]
316
 
317
def align_f0a(self, f0, ctx):
    """Resample ``f0`` to ``ctx`` steps along its length axis and reduce it to 1-D.

    The length axis is axis 1 for a 3-D input and axis 0 for a 1-D or 2-D
    input. When that axis already has ``ctx`` entries no resampling is done;
    otherwise entry ``i`` of the output is taken from source position
    ``floor(i * length / ctx)``, clamped to the last valid index.

    After resampling, every axis other than the length axis is averaged away:
    a 3-D input is averaged over axes 0 and -1, a 2-D input over axis -1, and
    a 1-D input is returned as is (the very same tensor object when no
    resampling was needed).

    Args:
        f0: Tensor of rank 1, 2 or 3.
        ctx: Number of steps wanted along the length axis.

    Returns:
        A 1-D tensor with ``ctx`` entries, or ``None`` when ``f0`` has any
        other rank.
    """
    rank = f0.dim()
    if rank not in (1, 2, 3):
        # Other ranks were never handled; keep yielding None, but explicitly
        # rather than by falling off the end of the function.
        return None

    length_axis = 1 if rank == 3 else 0
    length = f0.shape[length_axis]

    if length != ctx:
        # Nearest-lower-index resampling: one source position per output step.
        step = length / ctx
        idx = torch.arange(ctx, device=f0.device)
        idx = (idx * step).long().clamp(0, length - 1)
        f0 = f0.index_select(length_axis, idx)

    if rank == 3:
        return f0.mean(dim=(0, -1))
    if rank == 2:
        return f0.mean(dim=-1)
    return f0
348
+
349
  def align_f0(self, ctx, f0):
350
  f0 = self.f0proj(f0)
351
  if f0.dim() == 3:
 
392
  else:
393
  f0 = f0.view(-1)
394
 
395
+ if f0 is not None and layer == "encoder": #rethink this
396
  f0_mean = f0.mean()
397
  theta = f0_mean + self.theta
398
  else:
 
404
  print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
405
 
406
  freqs = t[:, None] * freqs[None, :]
407
+ if self.radii and f0 is not None and layer == "encoder": #this too
408
  radius = f0.to(device, dtype)
409
  L = radius.shape[0]
410
  if L != ctx:
 
496
  dims=dims,
497
  head=head,
498
  debug=debug,
499
+ radii=True if "radii" in debug else False,
500
+ use_pbias=True if "pbias" in debug else False,
501
  )
502
  else:
503
  self.rope = None
 
549
 
550
  qk = (q * scale) @ (k * scale).transpose(-1, -2)
551
  if self.rope.use_pbias:
552
+ f0 = enc.get("f0", None) if enc is not None else None
553
+ pbias = self.rope.get_bias(f0, q2)
554
  if pbias is not None:
555
+ # print(f"pbias shape: {pbias.shape}, qk shape: {qk.shape}")
556
+ qk = qk + pbias
557
  token_ids = k[:, :, :, 0]
558
  zscale = torch.ones_like(token_ids)
559
  fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
 
994
  mask = self.mask[:x.shape[1], :x.shape[1]]
995
  x = self.token(x) + self.positional[:x.shape[1]]
996
  x = F.dropout(x, p=self.dropout, training=self.training)
 
 
 
 
997
 
998
  for block in self.block:
999
+ x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
1000
 
1001
  for f in order:
1002
  if f in enc: