Update model.py
model.py (CHANGED)
@@ -284,9 +284,10 @@ class rotary(nn.Module):
         self.freqs.data.copy_(freqs)
         self.theta.data.copy_(theta)
 
-    def
+    def get_bias(self, f0, ctx):
         if f0 is None:
             return None
+        f0 = self.align_f0a(f0, ctx)
         f0_flat = f0.squeeze().float()
         f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
         f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
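For context: get_bias turns the aligned f0 contour into a pairwise attention bias, so positions with similar pitch bias attention toward each other. A minimal sketch of that computation, assuming a 1-D contour already aligned to ctx (the function name here and anything past the visible lines are assumptions):

import torch

def pitch_similarity_bias(f0: torch.Tensor) -> torch.Tensor:
    # f0: (ctx,) pitch track already aligned to the sequence length
    f0_norm = (f0 - f0.mean()) / (f0.std() + 1e-8)                  # z-score the contour
    dist = torch.cdist(f0_norm.unsqueeze(1), f0_norm.unsqueeze(1))  # (ctx, ctx) pairwise |f0_i - f0_j|
    return torch.exp(-dist)                                         # 1 on the diagonal, decaying off it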
@@ -313,6 +314,38 @@ class rotary(nn.Module):
         idx = torch.arange(ctx, device=f0.device)
         return f0[idx]
 
+    def align_f0a(self, f0, ctx):
+        if f0.dim() == 3:
+            batch, length, dims = f0.shape
+            if length == ctx:
+                f0 = f0
+            else:
+                frames = length / ctx
+                idx = torch.arange(ctx, device=f0.device)
+                idx = (idx * frames).long().clamp(0, length - 1)
+                f0 = f0[:, idx, :]
+            f0 = f0.mean(dim=(0, -1))
+            return f0
+        if f0.dim() == 2:
+            length, dims = f0.shape
+            if length == ctx:
+                f0 = f0
+            else:
+                frames = length / ctx
+                idx = torch.arange(ctx, device=f0.device)
+                idx = (idx * frames).long().clamp(0, length - 1)
+                f0 = f0[idx, :]
+            f0 = f0.mean(dim=-1)
+            return f0
+        if f0.dim() == 1:
+            length = f0.shape[0]
+            if length == ctx:
+                return f0
+            frames = length / ctx
+            idx = torch.arange(ctx, device=f0.device)
+            idx = (idx * frames).long().clamp(0, length - 1)
+            return f0[idx]
+
     def align_f0(self, ctx, f0):
         f0 = self.f0proj(f0)
         if f0.dim() == 3:
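All three branches of align_f0a apply the same nearest-index resampling and then average away the non-time axes. A compact equivalent for reference (resample_nearest is my name, not the commit's):

import torch

def resample_nearest(f0: torch.Tensor, ctx: int) -> torch.Tensor:
    # Pick ctx frames along the time axis by nearest index, then reduce
    # any batch/feature dims so the result is a (ctx,) contour.
    length = f0.shape[0] if f0.dim() == 1 else f0.shape[-2]
    idx = (torch.arange(ctx, device=f0.device) * (length / ctx)).long()
    idx = idx.clamp(0, length - 1)
    if f0.dim() == 3:                           # (batch, length, dims)
        return f0[:, idx, :].mean(dim=(0, -1))
    if f0.dim() == 2:                           # (length, dims)
        return f0[idx, :].mean(dim=-1)
    return f0[idx]                              # (length,)

resample_nearest(torch.rand(2, 3000, 4), ctx=1500).shape  # torch.Size([1500])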
@@ -359,7 +392,7 @@ class rotary(nn.Module):
         else:
             f0 = f0.view(-1)
 
-        if f0 is not None:
+        if f0 is not None and layer == "encoder": #rethink this
             f0_mean = f0.mean()
             theta = f0_mean + self.theta
         else:
@@ -371,7 +404,7 @@ class rotary(nn.Module):
             print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
 
         freqs = t[:, None] * freqs[None, :]
-        if self.radii and f0 is not None:
+        if self.radii and f0 is not None and layer == "encoder": #this too
             radius = f0.to(device, dtype)
             L = radius.shape[0]
             if L != ctx:
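Both gates above now key on layer == "encoder", so decoder calls skip the pitch-conditioned theta and the radius modulation. A toy check of the theta branch (the 10000.0 base is an assumed stand-in for self.theta):

import torch

theta_base = torch.tensor(10000.0)        # assumed stand-in for self.theta
f0 = torch.tensor([180.0, 220.0, 260.0])  # toy pitch track in Hz

for layer in ("encoder", "decoder"):
    theta = f0.mean() + theta_base if layer == "encoder" else theta_base
    print(layer, theta.item())            # encoder 10220.0, decoder 10000.0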
@@ -463,7 +496,8 @@ class MultiheadA(nn.Module):
                 dims=dims,
                 head=head,
                 debug=debug,
-                radii=False,
+                radii=True if "radii" in debug else False,
+                use_pbias=True if "pbias" in debug else False,
             )
         else:
             self.rope = None
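Both flags now come from the debug collection instead of being hard-coded off; radii="radii" in debug would say the same thing more directly. A hypothetical direct call (only dims, head, and debug are visible in the diff; everything else is assumed):

debug = ["radii", "pbias"]
rope = rotary(dims=512, head=8, debug=debug,
              radii="radii" in debug,        # equivalent to True if "radii" in debug else False
              use_pbias="pbias" in debug)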
@@ -515,10 +549,11 @@ class MultiheadA(nn.Module):
 
         qk = (q * scale) @ (k * scale).transpose(-1, -2)
         if self.rope.use_pbias:
-            f0 = enc.get("f0", None) if enc is not None else None
-            pbias = self.rope.
+            f0 = enc.get("f0", None) if enc is not None else None
+            pbias = self.rope.get_bias(f0, q2)
         if pbias is not None:
-
+            # print(f"pbias shape: {pbias.shape}, qk shape: {qk.shape}")
+            qk = qk + pbias
         token_ids = k[:, :, :, 0]
         zscale = torch.ones_like(token_ids)
         fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
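get_bias returns a (ctx, ctx) matrix, so the qk = qk + pbias add broadcasts over the batch and head dims of the logits. A toy shape check (the 4-D qk layout is inferred from the surrounding indexing, not stated in the diff):

import torch

batch, head, ctx = 2, 4, 6
qk = torch.randn(batch, head, ctx, ctx)  # raw attention logits
pbias = torch.rand(ctx, ctx)             # pairwise pitch-similarity bias
print((qk + pbias).shape)                # torch.Size([2, 4, 6, 6])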
@@ -959,13 +994,9 @@ class TextDecoder(nn.Module):
         mask = self.mask[:x.shape[1], :x.shape[1]]
         x = self.token(x) + self.positional[:x.shape[1]]
         x = F.dropout(x, p=self.dropout, training=self.training)
-
-        # ctx = x.shape[1]
-        # freqs = self.rotary(ctx)
-        # x = self.rotary.apply_rotary(x, freqs)
 
         for block in self.block:
-            x = block(x, xa=None, mask=mask, enc=
+            x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
 
         for f in order:
             if f in enc:
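The decoder loop now threads enc and layer into every block, which is how the encoder-only gating in rotary above can tell which stack it is serving. A stub of the block interface the call implies (the forward signature is assumed from the call site):

import torch
import torch.nn as nn

class Block(nn.Module):
    # Assumed interface implied by the fixed call site.
    def forward(self, x, xa=None, mask=None, enc=None, layer=None):
        return x  # a real block would run attention + MLP here

blocks = nn.ModuleList([Block() for _ in range(2)])
x = torch.randn(1, 4, 8)
for block in blocks:
    x = block(x, xa=None, mask=None, enc={"f0": torch.rand(4)}, layer="decoder")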