Update model.py
Browse files
model.py
CHANGED
@@ -34,6 +34,21 @@ dtype = torch.float32
|
|
34 |
warnings.filterwarnings("ignore")
|
35 |
logging.basicConfig(level=logging.ERROR)
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
extractor = None
|
38 |
tokenizer = None
|
39 |
optimizer = None
|
@@ -308,14 +323,33 @@ def align_f0(f0, target_length, method='nearest', device=device, dtype=dtype):
|
|
308 |
result = result.squeeze(0)
|
309 |
return result.to(dtype)
|
310 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
311 |
class rotary(nn.Module):
|
312 |
def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [],
|
313 |
use_pbias=False, spec_shape=None):
|
314 |
super().__init__()
|
315 |
|
316 |
self.use_pbias = use_pbias
|
317 |
-
use_2d_axial = False
|
318 |
-
self.spec_shape = spec_shape
|
319 |
self.last_f0_theta = None
|
320 |
self.debug = debug
|
321 |
self._counter = 0
|
@@ -324,34 +358,23 @@ class rotary(nn.Module):
|
|
324 |
self.head_dim = dims // head
|
325 |
self.max_ctx = max_ctx
|
326 |
self.radii = radii
|
327 |
-
f0_factor = 0.5
|
328 |
self.learned_adaptation: bool = False
|
329 |
radius = 1
|
330 |
dim = self.head_dim
|
331 |
self.dim = dim
|
332 |
|
333 |
-
|
334 |
-
self.f0_scale = nn.Parameter(torch.tensor(f0_factor, device=device, dtype=dtype), requires_grad=True)
|
335 |
-
else:
|
336 |
-
self.register_buffer('f0_scale', torch.tensor(f0_factor))
|
337 |
-
|
338 |
self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
# self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
|
343 |
-
# freq_data = 1.0 / (theta ** (torch.arange(start=0, end=dim, step=2).float() / dim))
|
344 |
-
# self.inv_freq = nn.Parameter(freq_data, requires_grad=True)
|
345 |
-
|
346 |
-
freqb = 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
347 |
-
self.inv_freq = nn.Parameter(torch.tensor(freqb, device=device, dtype=dtype), requires_grad=True)
|
348 |
|
349 |
-
def update_base(self,
|
350 |
-
|
351 |
-
|
352 |
-
inv_freq = 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
353 |
self.inv_freq.data.copy_(inv_freq)
|
354 |
-
self.theta.data.copy_(
|
355 |
|
356 |
def get_pitch_bias(self, f0):
|
357 |
if f0 is None:
|
@@ -401,14 +424,15 @@ class rotary(nn.Module):
|
|
401 |
t = torch.arange(ctx, device=device, dtype=dtype)
|
402 |
|
403 |
if f0 is not None:
|
|
|
404 |
f0_mean = f0.mean()
|
405 |
theta = f0_mean + 1e-8
|
406 |
-
freqs =
|
|
|
407 |
if "rotary1" in self.debug:
|
408 |
print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
|
409 |
else:
|
410 |
freqs = self.inv_freq
|
411 |
-
|
412 |
freqs = t[:, None] * freqs[None, :]
|
413 |
if self.radii:
|
414 |
if f0 is not None:
|
|
|
34 |
warnings.filterwarnings("ignore")
|
35 |
logging.basicConfig(level=logging.ERROR)
|
36 |
|
37 |
+
# Developer-facing traceback formatting.
#
# Both packages are debugging conveniences, not requirements of the model, so
# each import is guarded: when a package is missing the module still imports
# and Python's default traceback output is used instead.
# NOTE(review): both libraries appear to hook exception display; whichever is
# set up last may take precedence -- confirm that enabling both is intended.
try:
    from rich.traceback import install
except ImportError:
    # rich is not installed; skip rich tracebacks.
    pass
else:
    install(show_locals=True)

try:
    import pretty_errors
except ImportError:
    # pretty_errors is not installed; skip its configuration.
    pass
else:
    pretty_errors.configure(
        separator_character='*',
        filename_display=pretty_errors.FILENAME_EXTENDED,
        line_number_first=True,
        display_link=True,
        lines_before=5,
        lines_after=2,
        line_color=pretty_errors.RED + '> ' + pretty_errors.default_config.line_color,
        code_color=' ' + pretty_errors.default_config.line_color,
    )
|
51 |
+
|
52 |
extractor = None
|
53 |
tokenizer = None
|
54 |
optimizer = None
|
|
|
323 |
result = result.squeeze(0)
|
324 |
return result.to(dtype)
|
325 |
|
326 |
+
# def update_base(self, f0):
|
327 |
+
# f0 = f0.to(device, dtype)
|
328 |
+
# f0_mean = f0.mean() + 1e-8
|
329 |
+
|
330 |
+
# # Standard RoPE calculation (keep this)
|
331 |
+
# theta_freqs = 1.0 / (f0_mean ** (torch.arange(0, self.dim, 2, device=device, dtype=dtype)[:(self.dim // 2)].float() / self.dim))
|
332 |
+
|
333 |
+
# # Direct f0-adapted mel scale (new part)
|
334 |
+
# center_freq = f0_mean
|
335 |
+
# min_freq = center_freq * 0.25 # Lower bound
|
336 |
+
# max_freq = center_freq * 4.0 # Upper bound
|
337 |
+
|
338 |
+
# # Direct mel calculation centered on f0
|
339 |
+
# mel_min = 2595 * torch.log10(1 + min_freq/700)
|
340 |
+
# mel_max = 2595 * torch.log10(1 + max_freq/700)
|
341 |
+
# mel_freqs = 700 * (torch.pow(10, torch.linspace(mel_min, mel_max, self.dim//2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
342 |
+
|
343 |
+
# # Use a weighted combination
|
344 |
+
# self.inv_freq.data.copy_(0.5 * theta_freqs + 0.5 * mel_freqs)
|
345 |
+
# self.theta.data.copy_(f0_mean)
|
346 |
+
|
347 |
class rotary(nn.Module):
|
348 |
def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [],
|
349 |
use_pbias=False, spec_shape=None):
|
350 |
super().__init__()
|
351 |
|
352 |
self.use_pbias = use_pbias
|
|
|
|
|
353 |
self.last_f0_theta = None
|
354 |
self.debug = debug
|
355 |
self._counter = 0
|
|
|
358 |
self.head_dim = dims // head
|
359 |
self.max_ctx = max_ctx
|
360 |
self.radii = radii
|
|
|
361 |
self.learned_adaptation: bool = False
|
362 |
radius = 1
|
363 |
dim = self.head_dim
|
364 |
self.dim = dim
|
365 |
|
366 |
+
theta = torch.tensor(theta, device=device, dtype=dtype)
|
|
|
|
|
|
|
|
|
367 |
self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
|
368 |
+
self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
|
369 |
+
inv_freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
370 |
+
self.inv_freq = nn.Parameter(torch.tensor(inv_freq, device=device, dtype=dtype), requires_grad=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
371 |
|
372 |
+
def update_base(self, f0):
    """Overwrite the stored rotary frequencies and base from a pitch track.

    The mean of ``f0`` (plus a small epsilon) becomes the new ``theta``.
    A mel-spaced frequency table with ``self.dim // 2`` entries is scaled
    by ``theta / 220.0`` and copied into ``self.inv_freq``.

    Args:
        f0: pitch tensor; its leading dimension is squeezed before the mean
            is taken.  # assumes shape (1, T) or (T,) -- TODO confirm

    Returns:
        None. ``self.inv_freq`` and ``self.theta`` are modified in place.
    """
    pitch = f0.squeeze(0).to(device, dtype)
    # Epsilon keeps the base strictly positive when the pitch track is all zeros.
    theta = pitch.mean() + 1e-8

    # Upper end of the mel axis: the mel value of 8000 (presumably Hz).
    mel_top = 2595 * torch.log10(torch.tensor(1 + 8000/700))
    mel_axis = torch.linspace(0, mel_top, self.dim // 2, device=device, dtype=dtype)

    # Map mel points back to linear frequency, then scale by the mean pitch
    # relative to 220.0 (presumably a reference pitch -- verify).
    pitch_ratio = theta / 220.0
    inv_freq = pitch_ratio * 700 * (torch.pow(10, mel_axis / 2595) - 1) / 1000

    # Written through .data, so autograd does not record these updates.
    self.inv_freq.data.copy_(inv_freq)
    self.theta.data.copy_(theta)
|
378 |
|
379 |
def get_pitch_bias(self, f0):
|
380 |
if f0 is None:
|
|
|
424 |
t = torch.arange(ctx, device=device, dtype=dtype)
|
425 |
|
426 |
if f0 is not None:
|
427 |
+
freqs = self.inv_freq
|
428 |
f0_mean = f0.mean()
|
429 |
theta = f0_mean + 1e-8
|
430 |
+
freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
431 |
+
|
432 |
if "rotary1" in self.debug:
|
433 |
print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
|
434 |
else:
|
435 |
freqs = self.inv_freq
|
|
|
436 |
freqs = t[:, None] * freqs[None, :]
|
437 |
if self.radii:
|
438 |
if f0 is not None:
|