Sin2pi committed on
Commit
9af949e
·
verified ·
1 Parent(s): d2f4343

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +48 -24
model.py CHANGED
@@ -34,6 +34,21 @@ dtype = torch.float32
34
  warnings.filterwarnings("ignore")
35
  logging.basicConfig(level=logging.ERROR)
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  extractor = None
38
  tokenizer = None
39
  optimizer = None
@@ -308,14 +323,33 @@ def align_f0(f0, target_length, method='nearest', device=device, dtype=dtype):
308
  result = result.squeeze(0)
309
  return result.to(dtype)
310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  class rotary(nn.Module):
312
  def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [],
313
  use_pbias=False, spec_shape=None):
314
  super().__init__()
315
 
316
  self.use_pbias = use_pbias
317
- use_2d_axial = False
318
- self.spec_shape = spec_shape
319
  self.last_f0_theta = None
320
  self.debug = debug
321
  self._counter = 0
@@ -324,34 +358,23 @@ class rotary(nn.Module):
324
  self.head_dim = dims // head
325
  self.max_ctx = max_ctx
326
  self.radii = radii
327
- f0_factor = 0.5
328
  self.learned_adaptation: bool = False
329
  radius = 1
330
  dim = self.head_dim
331
  self.dim = dim
332
 
333
- if self.learned_adaptation:
334
- self.f0_scale = nn.Parameter(torch.tensor(f0_factor, device=device, dtype=dtype), requires_grad=True)
335
- else:
336
- self.register_buffer('f0_scale', torch.tensor(f0_factor))
337
-
338
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
339
-
340
- # freqs = 1. / (theta ** (torch.arange(0, dim, 2, device=device, dtype=dtype)[:(dim // 2)].float() / dims))
341
- # self.freqs = nn.Parameter(torch.tensor(freqs, device=device, dtype=dtype), requires_grad=True)
342
- # self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
343
- # freq_data = 1.0 / (theta ** (torch.arange(start=0, end=dim, step=2).float() / dim))
344
- # self.inv_freq = nn.Parameter(freq_data, requires_grad=True)
345
-
346
- freqb = 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
347
- self.inv_freq = nn.Parameter(torch.tensor(freqb, device=device, dtype=dtype), requires_grad=True)
348
 
349
- def update_base(self, pitch):
350
- theta = pitch.squeeze(0).to(device, dtype)
351
- f0_mean = theta.mean() + 1e-8
352
- inv_freq = 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
353
  self.inv_freq.data.copy_(inv_freq)
354
- self.theta.data.copy_(f0_mean)
355
 
356
  def get_pitch_bias(self, f0):
357
  if f0 is None:
@@ -401,14 +424,15 @@ class rotary(nn.Module):
401
  t = torch.arange(ctx, device=device, dtype=dtype)
402
 
403
  if f0 is not None:
 
404
  f0_mean = f0.mean()
405
  theta = f0_mean + 1e-8
406
- freqs = self.inv_freq
 
407
  if "rotary1" in self.debug:
408
  print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
409
  else:
410
  freqs = self.inv_freq
411
-
412
  freqs = t[:, None] * freqs[None, :]
413
  if self.radii:
414
  if f0 is not None:
 
34
  warnings.filterwarnings("ignore")
35
  logging.basicConfig(level=logging.ERROR)
36
 
37
+ from rich.traceback import install
38
+ install(show_locals=True)
39
+
40
+ import pretty_errors
41
+ pretty_errors.configure(
42
+ separator_character = '*',
43
+ filename_display = pretty_errors.FILENAME_EXTENDED,
44
+ line_number_first = True,
45
+ display_link = True,
46
+ lines_before = 5,
47
+ lines_after = 2,
48
+ line_color = pretty_errors.RED + '> ' + pretty_errors.default_config.line_color,
49
+ code_color = ' ' + pretty_errors.default_config.line_color,
50
+ )
51
+
52
  extractor = None
53
  tokenizer = None
54
  optimizer = None
 
323
  result = result.squeeze(0)
324
  return result.to(dtype)
325
 
326
+ # def update_base(self, f0):
327
+ # f0 = f0.to(device, dtype)
328
+ # f0_mean = f0.mean() + 1e-8
329
+
330
+ # # Standard RoPE calculation (keep this)
331
+ # theta_freqs = 1.0 / (f0_mean ** (torch.arange(0, self.dim, 2, device=device, dtype=dtype)[:(self.dim // 2)].float() / self.dim))
332
+
333
+ # # Direct f0-adapted mel scale (new part)
334
+ # center_freq = f0_mean
335
+ # min_freq = center_freq * 0.25 # Lower bound
336
+ # max_freq = center_freq * 4.0 # Upper bound
337
+
338
+ # # Direct mel calculation centered on f0
339
+ # mel_min = 2595 * torch.log10(1 + min_freq/700)
340
+ # mel_max = 2595 * torch.log10(1 + max_freq/700)
341
+ # mel_freqs = 700 * (torch.pow(10, torch.linspace(mel_min, mel_max, self.dim//2, device=device, dtype=dtype) / 2595) - 1) / 1000
342
+
343
+ # # Use a weighted combination
344
+ # self.inv_freq.data.copy_(0.5 * theta_freqs + 0.5 * mel_freqs)
345
+ # self.theta.data.copy_(f0_mean)
346
+
347
  class rotary(nn.Module):
348
  def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [],
349
  use_pbias=False, spec_shape=None):
350
  super().__init__()
351
 
352
  self.use_pbias = use_pbias
 
 
353
  self.last_f0_theta = None
354
  self.debug = debug
355
  self._counter = 0
 
358
  self.head_dim = dims // head
359
  self.max_ctx = max_ctx
360
  self.radii = radii
 
361
  self.learned_adaptation: bool = False
362
  radius = 1
363
  dim = self.head_dim
364
  self.dim = dim
365
 
366
+ theta = torch.tensor(theta, device=device, dtype=dtype)
 
 
 
 
367
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
368
+ self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
369
+ inv_freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
370
+ self.inv_freq = nn.Parameter(torch.tensor(inv_freq, device=device, dtype=dtype), requires_grad=True)
 
 
 
 
 
 
371
 
372
+ def update_base(self, f0):
373
+ f0 = f0.squeeze(0).to(device, dtype)
374
+ theta = f0.mean() + 1e-8
375
+ inv_freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
376
  self.inv_freq.data.copy_(inv_freq)
377
+ self.theta.data.copy_(theta)
378
 
379
  def get_pitch_bias(self, f0):
380
  if f0 is None:
 
424
  t = torch.arange(ctx, device=device, dtype=dtype)
425
 
426
  if f0 is not None:
427
+ freqs = self.inv_freq
428
  f0_mean = f0.mean()
429
  theta = f0_mean + 1e-8
430
+ freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
431
+
432
  if "rotary1" in self.debug:
433
  print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
434
  else:
435
  freqs = self.inv_freq
 
436
  freqs = t[:, None] * freqs[None, :]
437
  if self.radii:
438
  if f0 is not None: