Sin2pi committed
Commit 00f642e · verified · 1 Parent(s): d31c315

Update modelA.py

Files changed (1)
  1. modelA.py +9 -17
modelA.py CHANGED
@@ -984,6 +984,8 @@ class Echo(nn.Module):
             "eos_token_id": self.eos_token_id,
         })
         return Config()
+
+
 def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
     from tokenizers import Tokenizer
     tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
@@ -994,6 +996,7 @@ def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
         sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
         ids = [id for id in ids if id not in sp_ids]
         return ids
+
     def bdec(ids_list, skip_special_tokens=True):
         results = []
         for ids in ids_list:
@@ -1004,8 +1007,13 @@ def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
                 ids = ids[:-1]
             results.append(tokenizer.decode(ids))
         return results
+
+    def save_pretrained(save_dir):
+        os.makedirs(save_dir, exist_ok=True)
+        tokenizer.save(f"{save_dir}/tokenizer.json")
     tokenizer.encode = enc
     tokenizer.batch_decode = bdec
+    tokenizer.save_pretrained = save_pretrained
     tokenizer.pad_token_id = 0
     tokenizer.bos_token_id = 1
     tokenizer.eos_token_id = 2
@@ -1016,22 +1024,7 @@ def extract_features(batch, tokenizer, sample_rate=16000, n_mels=128, n_fft=1024
     waveform = torch.tensor(audio["array"]).float()
     if waveform.dim() == 2:
         waveform = waveform.mean(dim=0)
-
-    # transform = torchaudio.transforms.MelSpectrogram(
-    #     f_max=fmax,
-    #     f_min=fmin,
-    #     n_mels=n_mels,
-    #     sample_rate=sr,
-    #     n_fft=n_fft,
-    #     hop_length=hop_length,
-    #     norm=norm,
-    #     normalized=normalized,
-    #     power=power,
-    #     center=center,
-    #     mel_scale=mel_scale,
-    #     window_fn=window_fn,
-    #     pad_mode=pad_mode)
-
+
     # mel_spectrogram = transform(wav)
     # log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
     # log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
@@ -1099,7 +1092,6 @@ class DataCollator:
             torch.nn.functional.pad(f0, (0, max_f0_len - f0.shape[-1])) for f0 in f0s
         ])
 
-        # Gather and pad input_ids/labels
        input_ids_list = [f["input_ids"] for f in features]
        # Ensure all are lists, not tensors
        input_ids_list = [ids.tolist() if isinstance(ids, torch.Tensor) else ids for ids in input_ids_list]
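
A quick usage sketch of the save_pretrained hook added in this commit, assuming setup_tokenizer returns the patched Tokenizer (its return statement falls outside these hunks), that tokenizer.json sits next to modelA.py, and that the token value and output directory below are placeholders:

    # Hypothetical usage; not part of the commit.
    from modelA import setup_tokenizer

    tokenizer = setup_tokenizer(token="", local_tokenizer_path="./")

    # bdec(): decodes each id sequence after trimming the final id (presumably <EOS>).
    texts = tokenizer.batch_decode([[1, 42, 99, 2]], skip_special_tokens=True)

    # New here: save_pretrained() creates the directory and writes
    # <save_dir>/tokenizer.json via Tokenizer.save().
    tokenizer.save_pretrained("./checkpoint")

Attaching save_pretrained presumably lets trainer code that expects the Hugging Face tokenizer interface (tokenizer.save_pretrained(output_dir)) work with the raw tokenizers.Tokenizer unchanged.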
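
The extract_features hunk removes the commented-out MelSpectrogram construction but keeps the comments that refer to its transform; for reference, a sketch of that log-mel pipeline reconstructed from the deleted comments, with illustrative parameter values rather than the ones modelA.py actually passes:

    import torch
    import torchaudio

    # Illustrative stand-in for the removed "# transform = torchaudio.transforms.MelSpectrogram(...)" block.
    transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000,
        n_fft=1024,
        hop_length=256,
        n_mels=128,
        f_min=0.0,
        f_max=8000.0,
        window_fn=torch.hann_window,
        power=2.0,
        norm="slaney",
        mel_scale="htk",
        center=True,
        pad_mode="reflect",
    )

    wav = torch.randn(16000)                                  # 1 s of dummy mono audio
    mel_spectrogram = transform(wav)                          # (n_mels, frames)
    log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
    log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)     # floor at 80 dB below the peak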