Sin2pi committed on
Commit
1f5c16d
·
verified ·
1 Parent(s): 22b781d

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +42 -47
model.py CHANGED
@@ -12,6 +12,7 @@ import torch.nn.functional as F
12
  import torch.nn.init as init
13
  from torch import nn, Tensor
14
  import numpy as np
 
15
  import matplotlib.pyplot as plt
16
  from typing import Optional, Dict, Union, List, Tuple, Any
17
  from functools import partial
@@ -249,24 +250,20 @@ def sinusoids(length, channels, max_timescale=10000):
249
  return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
250
 
251
  class rotary(nn.Module):
252
- def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [], use_pbias=False, spec_shape=None):
253
  super(rotary, self).__init__()
254
 
 
255
  self.dims = dims
256
  self.head = head
257
  self.head_dim = dims // head
258
- self.dim = self.head_dim
259
- self.max_ctx = max_ctx
260
- self.theta = theta
261
  self.radii = radii
262
- self.pitch_scale = 0.1
263
- self.use_pbias = use_pbias
264
- self.spec_shape = spec_shape
265
  self.debug = debug
266
  self.counter = 0
267
  self.last_theta = None
268
 
269
- # self.f0_proj = nn.Linear(1, self.head_dim // 2) if radii else None
270
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
271
  freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
272
  self.freqs = nn.Parameter(torch.tensor(freqs, device=device, dtype=dtype), requires_grad=True)
@@ -287,7 +284,8 @@ class rotary(nn.Module):
287
  self.theta.data.copy_(theta)
288
 
289
  def get_pitch_bias(self, f0):
290
- f0 = self.return_f0()
 
291
  f0_flat = f0.squeeze().float()
292
  f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
293
  f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
@@ -317,8 +315,6 @@ class rotary(nn.Module):
317
  return f0[idx]
318
 
319
  def align_f0(self, ctx, f0):
320
- # f0 = self.return_f0()
321
- # f0 = self.f0proj(f0)
322
  if f0.dim() == 3:
323
  batch, length, dims = f0.shape
324
  if length == ctx:
@@ -345,24 +341,23 @@ class rotary(nn.Module):
345
  return f0[idx, :]
346
 
347
  def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
348
- f0 = default(enc.get("pitch"), enc.get("f0")) if enc is not None else None
349
- if f0 is not None and f0.dim() == 2:
350
- if f0.shape[0] == 1:
351
- f0 = f0.squeeze(0)
352
- else:
353
- f0 = f0.view(-1)
354
-
355
- if "rot1" in self.debug and self.counter % 100 == 0:
356
- print(f"Rotary forward: {x if x is not None else None}, f0: {f0.shape if f0 is not None else None}")
357
-
358
  if isinstance(x, int):
359
  ctx = x
 
 
360
  elif isinstance(x, torch.Tensor) and x.ndim == 3:
361
  batch, ctx, dims = x.shape
362
  else:
363
  batch, head, ctx, head_dim = x.shape
364
  t = torch.arange(ctx, device=device, dtype=dtype)
365
 
 
 
 
 
 
 
 
366
  if f0 is not None:
367
  f0_mean = f0.mean()
368
  theta = f0_mean + self.theta
@@ -388,6 +383,9 @@ class rotary(nn.Module):
388
  radius = torch.ones_like(freqs)
389
  freqs = torch.polar(radius, freqs)
390
 
 
 
 
391
  if "rot3" in self.debug and self.counter % 100 == 0:
392
  print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
393
 
@@ -859,13 +857,11 @@ class AudioEncoder(nn.Module):
859
  ),
860
  "envelope": nn.ModuleList(
861
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
862
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
863
- for _ in range(layer)] if "envelope" in features else None
864
  ),
865
  "phase": nn.ModuleList(
866
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
867
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
868
- for _ in range(layer)] if "phase" in features else None
869
  )
870
  })
871
 
@@ -899,7 +895,7 @@ class AudioEncoder(nn.Module):
899
 
900
  class TextDecoder(nn.Module):
901
  def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
902
- debug: List[str], features: List[str], sequential=False):
903
  super(TextDecoder, self).__init__()
904
 
905
  self.ctx = ctx
@@ -909,7 +905,6 @@ class TextDecoder(nn.Module):
909
  self.debug = debug
910
  self.counter = 0
911
  self.dropout = 0.01
912
- self.sequential = sequential
913
  self.features = features
914
 
915
  self.token = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
@@ -931,7 +926,7 @@ class TextDecoder(nn.Module):
931
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
932
  self.register_buffer("mask", mask, persistent=False)
933
 
934
- def forward(self, x, enc, order=None, layer='decoder') -> Tensor:
935
  enc = dict_to(enc, device, dtype)
936
  x = x.to(device)
937
  bln = self.blend
@@ -943,17 +938,25 @@ class TextDecoder(nn.Module):
943
  x = self.token(x) + self.positional[:x.shape[1]]
944
  x = F.dropout(x, p=self.dropout, training=self.training)
945
 
 
 
 
 
946
  for block in self.block:
947
  x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
948
 
949
  for f in order:
950
  if f in enc:
 
951
  xa = enc[f]
952
  for block in self.blocks[f]:
953
  out = block(x=x, xa=xa, mask=None, enc=enc, layer=layer)
954
 
955
- a = torch.sigmoid(bln[f])
956
- x = a * out + (1 - a) * x
 
 
 
957
 
958
  if "decoder" in self.debug and self.counter % 100 == 0:
959
  print(f"Step {self.counter}: Decoder output shape: {x.shape}, enc keys: {list(enc.keys())}, order: {order}")
@@ -1019,19 +1022,16 @@ class Echo(nn.Module):
1019
  return self.decoder(input_ids, encoder_output)
1020
 
1021
  def forward(self,
1022
- decoder_input_ids=None,
1023
  labels=None,
1024
  waveform: Optional[torch.Tensor]=None,
1025
  input_ids=None,
1026
  spectrogram: torch.Tensor=None,
1027
  pitch: Optional[torch.Tensor]=None,
1028
  f0: Optional[torch.Tensor]=None,
1029
- f0d: Optional[torch.Tensor]=None,
1030
  envelope: Optional[torch.Tensor]=None,
1031
  phase: Optional[torch.Tensor]=None,
1032
  ) -> Dict[str, torch.Tensor]:
1033
 
1034
- decoder_input_ids = input_ids
1035
  encoder_inputs = {}
1036
  if spectrogram is not None:
1037
  encoder_inputs["spectrogram"] = spectrogram
@@ -1046,11 +1046,6 @@ class Echo(nn.Module):
1046
  if f0 is not None:
1047
  encoder_inputs["f0"] = f0
1048
 
1049
-
1050
- # if f0 is not None:
1051
- # f0 = f0.squeeze(0)
1052
- # self.update_base(f0)
1053
-
1054
  encoder_outputs = self.encoder(encoder_inputs)
1055
  logits = self.decoder(input_ids, encoder_outputs)
1056
 
@@ -1063,10 +1058,6 @@ class Echo(nn.Module):
1063
  return {
1064
  "logits": logits,
1065
  "loss": loss,
1066
- # "labels": labels,
1067
- # "input_ids": input_ids,
1068
- # "decoder_input_ids": decoder_input_ids,
1069
- # "encoder_output": encoder_outputs,
1070
  }
1071
 
1072
  @property
@@ -1617,9 +1608,9 @@ def get_training_args(
1617
  per_device_train_batch_size=1,
1618
  per_device_eval_batch_size=1,
1619
  gradient_accumulation_steps=1,
1620
- eval_accumulation_steps=1,
1621
  eval_strategy="steps",
1622
- save_strategy="steps",
1623
  max_steps=max_steps,
1624
  save_steps=save_steps,
1625
  eval_steps=eval_steps,
@@ -1703,7 +1694,7 @@ def main():
1703
  training_args = sanity(sanity_check)
1704
  dataset_config = {
1705
  "spectrogram": True,
1706
- "waveforms": False,
1707
  "pitch": False,
1708
  "downsamples": False,
1709
  "frequency": False,
@@ -1752,6 +1743,10 @@ def main():
1752
  if __name__ == "__main__":
1753
  main()
1754
 
1755
-
1756
-
 
 
 
 
1757
 
 
12
  import torch.nn.init as init
13
  from torch import nn, Tensor
14
  import numpy as np
15
+ from einops import rearrange
16
  import matplotlib.pyplot as plt
17
  from typing import Optional, Dict, Union, List, Tuple, Any
18
  from functools import partial
 
250
  return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
251
 
252
  class rotary(nn.Module):
253
+ def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [], use_pbias=False):
254
  super(rotary, self).__init__()
255
 
256
+ self.use_pbias = use_pbias
257
  self.dims = dims
258
  self.head = head
259
  self.head_dim = dims // head
 
 
 
260
  self.radii = radii
261
+ self.dim = self.head_dim
 
 
262
  self.debug = debug
263
  self.counter = 0
264
  self.last_theta = None
265
 
266
+ self.f0_proj = nn.Linear(1, self.head_dim // 2) if radii else None
267
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
268
  freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
269
  self.freqs = nn.Parameter(torch.tensor(freqs, device=device, dtype=dtype), requires_grad=True)
 
284
  self.theta.data.copy_(theta)
285
 
286
  def get_pitch_bias(self, f0):
287
+ if f0 is None:
288
+ return None
289
  f0_flat = f0.squeeze().float()
290
  f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
291
  f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
 
315
  return f0[idx]
316
 
317
  def align_f0(self, ctx, f0):
 
 
318
  if f0.dim() == 3:
319
  batch, length, dims = f0.shape
320
  if length == ctx:
 
341
  return f0[idx, :]
342
 
343
  def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
 
 
 
 
 
 
 
 
 
 
344
  if isinstance(x, int):
345
  ctx = x
346
+ elif isinstance(x, torch.Tensor) and x.ndim == 2:
347
+ batch, ctx = x.shape
348
  elif isinstance(x, torch.Tensor) and x.ndim == 3:
349
  batch, ctx, dims = x.shape
350
  else:
351
  batch, head, ctx, head_dim = x.shape
352
  t = torch.arange(ctx, device=device, dtype=dtype)
353
 
354
+ f0 = default(enc.get("pitch"), enc.get("f0")) if enc is not None else None
355
+ if f0 is not None and f0.dim() == 2:
356
+ if f0.shape[0] == 1:
357
+ f0 = f0.squeeze(0)
358
+ else:
359
+ f0 = f0.view(-1)
360
+
361
  if f0 is not None:
362
  f0_mean = f0.mean()
363
  theta = f0_mean + self.theta
 
383
  radius = torch.ones_like(freqs)
384
  freqs = torch.polar(radius, freqs)
385
 
386
+ if "rot1" in self.debug and self.counter % 100 == 0:
387
+ print(f"Rotary forward: {x if x is not None else None}, f0: {f0.shape if f0 is not None else None}")
388
+
389
  if "rot3" in self.debug and self.counter % 100 == 0:
390
  print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
391
 
 
857
  ),
858
  "envelope": nn.ModuleList(
859
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
860
+ [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "envelope" in features else None
 
861
  ),
862
  "phase": nn.ModuleList(
863
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
864
+ [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "phase" in features else None
 
865
  )
866
  })
867
 
 
895
 
896
  class TextDecoder(nn.Module):
897
  def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
898
+ debug: List[str], features: List[str]):
899
  super(TextDecoder, self).__init__()
900
 
901
  self.ctx = ctx
 
905
  self.debug = debug
906
  self.counter = 0
907
  self.dropout = 0.01
 
908
  self.features = features
909
 
910
  self.token = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
 
926
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
927
  self.register_buffer("mask", mask, persistent=False)
928
 
929
+ def forward(self, x, enc, order=None, layer='decoder', sequential=False) -> Tensor:
930
  enc = dict_to(enc, device, dtype)
931
  x = x.to(device)
932
  bln = self.blend
 
938
  x = self.token(x) + self.positional[:x.shape[1]]
939
  x = F.dropout(x, p=self.dropout, training=self.training)
940
 
941
+ # ctx = x.shape[1]
942
+ # freqs = self.rotary(ctx)
943
+ # x = self.rotary.apply_rotary(x, freqs)
944
+
945
  for block in self.block:
946
  x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
947
 
948
  for f in order:
949
  if f in enc:
950
+ seq = x
951
  xa = enc[f]
952
  for block in self.blocks[f]:
953
  out = block(x=x, xa=xa, mask=None, enc=enc, layer=layer)
954
 
955
+ if sequential:
956
+ x = seq
957
+ else:
958
+ a = torch.sigmoid(bln[f])
959
+ x = a * out + (1 - a) * x
960
 
961
  if "decoder" in self.debug and self.counter % 100 == 0:
962
  print(f"Step {self.counter}: Decoder output shape: {x.shape}, enc keys: {list(enc.keys())}, order: {order}")
 
1022
  return self.decoder(input_ids, encoder_output)
1023
 
1024
  def forward(self,
 
1025
  labels=None,
1026
  waveform: Optional[torch.Tensor]=None,
1027
  input_ids=None,
1028
  spectrogram: torch.Tensor=None,
1029
  pitch: Optional[torch.Tensor]=None,
1030
  f0: Optional[torch.Tensor]=None,
 
1031
  envelope: Optional[torch.Tensor]=None,
1032
  phase: Optional[torch.Tensor]=None,
1033
  ) -> Dict[str, torch.Tensor]:
1034
 
 
1035
  encoder_inputs = {}
1036
  if spectrogram is not None:
1037
  encoder_inputs["spectrogram"] = spectrogram
 
1046
  if f0 is not None:
1047
  encoder_inputs["f0"] = f0
1048
 
 
 
 
 
 
1049
  encoder_outputs = self.encoder(encoder_inputs)
1050
  logits = self.decoder(input_ids, encoder_outputs)
1051
 
 
1058
  return {
1059
  "logits": logits,
1060
  "loss": loss,
 
 
 
 
1061
  }
1062
 
1063
  @property
 
1608
  per_device_train_batch_size=1,
1609
  per_device_eval_batch_size=1,
1610
  gradient_accumulation_steps=1,
1611
+ eval_accumulation_steps=None,
1612
  eval_strategy="steps",
1613
+ save_strategy="no",
1614
  max_steps=max_steps,
1615
  save_steps=save_steps,
1616
  eval_steps=eval_steps,
 
1694
  training_args = sanity(sanity_check)
1695
  dataset_config = {
1696
  "spectrogram": True,
1697
+ "waveforms": True,
1698
  "pitch": False,
1699
  "downsamples": False,
1700
  "frequency": False,
 
1743
  if __name__ == "__main__":
1744
  main()
1745
 
1746
+ # from tensorboard import program
1747
+ # log_dir = "./output/logs"
1748
+ # tb = program.TensorBoard()
1749
+ # tb.configure(argv=[None, '--logdir', log_dir])
1750
+ # url = tb.launch()
1751
+ # print(f"TensorBoard started at {url}")
1752