Update modelA.py
modelA.py  CHANGED
@@ -984,6 +984,8 @@ class Echo(nn.Module):
             "eos_token_id": self.eos_token_id,
         })
         return Config()
+
+
 def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
     from tokenizers import Tokenizer
     tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
@@ -994,6 +996,7 @@ def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
         sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
         ids = [id for id in ids if id not in sp_ids]
         return ids
+
     def bdec(ids_list, skip_special_tokens=True):
         results = []
         for ids in ids_list:
@@ -1004,8 +1007,13 @@ def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
             ids = ids[:-1]
             results.append(tokenizer.decode(ids))
         return results
+
+    def save_pretrained(save_dir):
+        os.makedirs(save_dir, exist_ok=True)
+        tokenizer.save(f"{save_dir}/tokenizer.json")
     tokenizer.encode = enc
     tokenizer.batch_decode = bdec
+    tokenizer.save_pretrained = save_pretrained
     tokenizer.pad_token_id = 0
     tokenizer.bos_token_id = 1
     tokenizer.eos_token_id = 2
@@ -1016,22 +1024,7 @@ def extract_features(batch, tokenizer, sample_rate=16000, n_mels=128, n_fft=1024
     waveform = torch.tensor(audio["array"]).float()
     if waveform.dim() == 2:
         waveform = waveform.mean(dim=0)
-
-    # transform = torchaudio.transforms.MelSpectrogram(
-    # f_max=fmax,
-    # f_min=fmin,
-    # n_mels=n_mels,
-    # sample_rate=sr,
-    # n_fft=n_fft,
-    # hop_length=hop_length,
-    # norm=norm,
-    # normalized=normalized,
-    # power=power,
-    # center=center,
-    # mel_scale=mel_scale,
-    # window_fn=window_fn,
-    # pad_mode=pad_mode)
-
+
     # mel_spectrogram = transform(wav)
     # log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
     # log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
@@ -1099,7 +1092,6 @@ class DataCollator:
             torch.nn.functional.pad(f0, (0, max_f0_len - f0.shape[-1])) for f0 in f0s
         ])
 
-        # Gather and pad input_ids/labels
         input_ids_list = [f["input_ids"] for f in features]
         # Ensure all are lists, not tensors
         input_ids_list = [ids.tolist() if isinstance(ids, torch.Tensor) else ids for ids in input_ids_list]
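For context, the main functional change in this commit is the save_pretrained shim attached to the raw tokenizers.Tokenizer inside setup_tokenizer, so the patched object can be saved in the same way as a Hugging Face tokenizer. Below is a minimal usage sketch, not part of the commit, assuming setup_tokenizer returns the patched tokenizer (the return statement falls outside these hunks), that os is imported at module level, and that the input text and paths are purely illustrative:

# Hypothetical usage of the patched tokenizer (assumptions noted above).
tok = setup_tokenizer(token="", local_tokenizer_path="./")  # expects ./tokenizer.json to exist
ids = tok.encode("hello world")                             # enc(): returns ids with <PAD>/<BOS>/<EOS> filtered out
texts = tok.batch_decode([ids])                             # bdec(): decodes a batch of id lists
tok.save_pretrained("./checkpoint")                         # new shim: writes ./checkpoint/tokenizer.json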