Update modelA.py
Browse files
modelA.py
CHANGED
@@ -984,7 +984,6 @@ class Echo(nn.Module):
|
|
984 |
"eos_token_id": self.eos_token_id,
|
985 |
})
|
986 |
return Config()
|
987 |
-
|
988 |
def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
|
989 |
from tokenizers import Tokenizer
|
990 |
tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
|
@@ -1017,10 +1016,33 @@ def extract_features(batch, tokenizer, sample_rate=16000, n_mels=128, n_fft=1024
|
|
1017 |
waveform = torch.tensor(audio["array"]).float()
|
1018 |
if waveform.dim() == 2:
|
1019 |
waveform = waveform.mean(dim=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1020 |
mel = torchaudio.transforms.MelSpectrogram(
|
1021 |
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
|
1022 |
)
|
1023 |
spec = mel(waveform)
|
|
|
1024 |
spec = torch.tensor(spec) if not isinstance(spec, torch.Tensor) else spec
|
1025 |
wav_np = waveform.numpy().astype(np.float64)
|
1026 |
f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length/sample_rate*1000)
|
@@ -1056,13 +1078,18 @@ def prepare_datasets(tokenizer, token: str, sample_rate=16000, n_mels=128, n_fft
|
|
1056 |
|
1057 |
@dataclass
|
1058 |
class DataCollator:
|
1059 |
-
tokenizer:
|
1060 |
-
|
|
|
1061 |
pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
|
|
|
|
|
|
|
|
|
1062 |
specs = [f["spectrogram"] for f in features]
|
1063 |
f0s = [f["f0"] for f in features]
|
1064 |
-
specs = [torch.tensor(
|
1065 |
-
f0s = [torch.tensor(
|
1066 |
max_spec_len = max(s.shape[-1] for s in specs)
|
1067 |
max_f0_len = max(f0.shape[-1] for f0 in f0s)
|
1068 |
padded_specs = torch.stack([
|
@@ -1071,11 +1098,18 @@ class DataCollator:
|
|
1071 |
padded_f0s = torch.stack([
|
1072 |
torch.nn.functional.pad(f0, (0, max_f0_len - f0.shape[-1])) for f0 in f0s
|
1073 |
])
|
1074 |
-
|
1075 |
-
|
1076 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1077 |
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
1078 |
-
labels =
|
|
|
1079 |
return {
|
1080 |
"spectrogram": padded_specs,
|
1081 |
"f0": padded_f0s,
|
@@ -1201,5 +1235,4 @@ def main():
|
|
1201 |
trainer.train()
|
1202 |
|
1203 |
if __name__ == "__main__":
|
1204 |
-
main()
|
1205 |
-
|
|
|
984 |
"eos_token_id": self.eos_token_id,
|
985 |
})
|
986 |
return Config()
|
|
|
987 |
def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
|
988 |
from tokenizers import Tokenizer
|
989 |
tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
|
|
|
1016 |
waveform = torch.tensor(audio["array"]).float()
|
1017 |
if waveform.dim() == 2:
|
1018 |
waveform = waveform.mean(dim=0)
|
1019 |
+
|
1020 |
+
# transform = torchaudio.transforms.MelSpectrogram(
|
1021 |
+
# f_max=fmax,
|
1022 |
+
# f_min=fmin,
|
1023 |
+
# n_mels=n_mels,
|
1024 |
+
# sample_rate=sr,
|
1025 |
+
# n_fft=n_fft,
|
1026 |
+
# hop_length=hop_length,
|
1027 |
+
# norm=norm,
|
1028 |
+
# normalized=normalized,
|
1029 |
+
# power=power,
|
1030 |
+
# center=center,
|
1031 |
+
# mel_scale=mel_scale,
|
1032 |
+
# window_fn=window_fn,
|
1033 |
+
# pad_mode=pad_mode)
|
1034 |
+
|
1035 |
+
# mel_spectrogram = transform(wav)
|
1036 |
+
# log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
|
1037 |
+
# log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
|
1038 |
+
# spec = (log_mel + 4.0) / 4.0
|
1039 |
+
# spec = torch.tensor(spec)
|
1040 |
+
|
1041 |
mel = torchaudio.transforms.MelSpectrogram(
|
1042 |
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
|
1043 |
)
|
1044 |
spec = mel(waveform)
|
1045 |
+
spec = torch.clamp(spec, min=1e-10).log10()
|
1046 |
spec = torch.tensor(spec) if not isinstance(spec, torch.Tensor) else spec
|
1047 |
wav_np = waveform.numpy().astype(np.float64)
|
1048 |
f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length/sample_rate*1000)
|
|
|
1078 |
|
1079 |
@dataclass
|
1080 |
class DataCollator:
|
1081 |
+
tokenizer: Any
|
1082 |
+
|
1083 |
+
def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
1084 |
pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
|
1085 |
+
bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
|
1086 |
+
eos_token_id = getattr(self.tokenizer, 'eos_token_id', 2)
|
1087 |
+
|
1088 |
+
# Gather and pad spectrograms and f0
|
1089 |
specs = [f["spectrogram"] for f in features]
|
1090 |
f0s = [f["f0"] for f in features]
|
1091 |
+
specs = [torch.tensor(s) if not isinstance(s, torch.Tensor) else s for s in specs]
|
1092 |
+
f0s = [torch.tensor(f0) if not isinstance(f0, torch.Tensor) else f0 for f0 in f0s]
|
1093 |
max_spec_len = max(s.shape[-1] for s in specs)
|
1094 |
max_f0_len = max(f0.shape[-1] for f0 in f0s)
|
1095 |
padded_specs = torch.stack([
|
|
|
1098 |
padded_f0s = torch.stack([
|
1099 |
torch.nn.functional.pad(f0, (0, max_f0_len - f0.shape[-1])) for f0 in f0s
|
1100 |
])
|
1101 |
+
|
1102 |
+
# Gather and pad input_ids/labels
|
1103 |
+
input_ids_list = [f["input_ids"] for f in features]
|
1104 |
+
# Ensure all are lists, not tensors
|
1105 |
+
input_ids_list = [ids.tolist() if isinstance(ids, torch.Tensor) else ids for ids in input_ids_list]
|
1106 |
+
max_len = max(len(ids) for ids in input_ids_list)
|
1107 |
+
# Add BOS to input_ids, EOS to labels, pad both to max_len+1
|
1108 |
+
input_ids = [[bos_token_id] + ids + [pad_token_id] * (max_len - len(ids)) for ids in input_ids_list]
|
1109 |
+
labels = [ids + [eos_token_id] + [pad_token_id] * (max_len - len(ids)) for ids in input_ids_list]
|
1110 |
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
1111 |
+
labels = torch.tensor(labels, dtype=torch.long)
|
1112 |
+
|
1113 |
return {
|
1114 |
"spectrogram": padded_specs,
|
1115 |
"f0": padded_f0s,
|
|
|
1235 |
trainer.train()
|
1236 |
|
1237 |
if __name__ == "__main__":
|
1238 |
+
main()
|
|