SWivid committed
Commit 1489cdc · 1 Parent(s): 23a1101

final structure. prepared to solve dependencies
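The restructuring splits helpers by role: inference utilities move under f5_tts.infer, evaluation utilities under f5_tts.eval, while model definitions and text helpers stay under f5_tts.model. A minimal sketch of the import layout implied by the diff below (a reading aid, not an official API listing):

# Import layout after this commit, as implied by the changed files below (sketch only).
from f5_tts.model import DiT, UNetT                        # model definitions stay in f5_tts.model
from f5_tts.model.utils import convert_char_to_pinyin      # text/tokenizer helpers stay in f5_tts.model.utils
from f5_tts.infer.utils_infer import load_vocoder, load_model, load_checkpoint, save_spectrogram
from f5_tts.eval.utils_eval import get_inference_prompt, run_asr_wer, run_sim
from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL        # speaker-similarity backbone, moved under eval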

src/f5_tts/{model → eval}/ecapa_tdnn.py RENAMED
File without changes
src/f5_tts/eval/utils_eval.py ADDED
@@ -0,0 +1,397 @@
+ import math
+ import os
+ import random
+ import string
+ from tqdm import tqdm
+
+ import torch
+ import torch.nn.functional as F
+ import torchaudio
+
+ from f5_tts.model.modules import MelSpec
+ from f5_tts.model.utils import convert_char_to_pinyin
+ from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
+
+
+ # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
+ def get_seedtts_testset_metainfo(metalst):
+     f = open(metalst)
+     lines = f.readlines()
+     f.close()
+     metainfo = []
+     for line in lines:
+         if len(line.strip().split("|")) == 5:
+             utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
+         elif len(line.strip().split("|")) == 4:
+             utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
+             gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
+         if not os.path.isabs(prompt_wav):
+             prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
+         metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
+     return metainfo
+
+
+ # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
+ def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
+     f = open(metalst)
+     lines = f.readlines()
+     f.close()
+     metainfo = []
+     for line in lines:
+         ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
+
+         # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
+         ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
+         ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
+
+         # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
+         gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
+         gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
+
+         metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
+
+     return metainfo
+
+
+ # padded to max length mel batch
+ def padded_mel_batch(ref_mels):
+     max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
+     padded_ref_mels = []
+     for mel in ref_mels:
+         padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
+         padded_ref_mels.append(padded_ref_mel)
+     padded_ref_mels = torch.stack(padded_ref_mels)
+     padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
+     return padded_ref_mels
+
+
+ # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
+
+
+ def get_inference_prompt(
+     metainfo,
+     speed=1.0,
+     tokenizer="pinyin",
+     polyphone=True,
+     target_sample_rate=24000,
+     n_mel_channels=100,
+     hop_length=256,
+     target_rms=0.1,
+     use_truth_duration=False,
+     infer_batch_size=1,
+     num_buckets=200,
+     min_secs=3,
+     max_secs=40,
+ ):
+     prompts_all = []
+
+     min_tokens = min_secs * target_sample_rate // hop_length
+     max_tokens = max_secs * target_sample_rate // hop_length
+
+     batch_accum = [0] * num_buckets
+     utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
+         [[] for _ in range(num_buckets)] for _ in range(6)
+     )
+
+     mel_spectrogram = MelSpec(
+         target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
+     )
+
+     for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
+         # Audio
+         ref_audio, ref_sr = torchaudio.load(prompt_wav)
+         ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
+         if ref_rms < target_rms:
+             ref_audio = ref_audio * target_rms / ref_rms
+         assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
+         if ref_sr != target_sample_rate:
+             resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
+             ref_audio = resampler(ref_audio)
+
+         # Text
+         if len(prompt_text[-1].encode("utf-8")) == 1:
+             prompt_text = prompt_text + " "
+         text = [prompt_text + gt_text]
+         if tokenizer == "pinyin":
+             text_list = convert_char_to_pinyin(text, polyphone=polyphone)
+         else:
+             text_list = text
+
+         # Duration, mel frame length
+         ref_mel_len = ref_audio.shape[-1] // hop_length
+         if use_truth_duration:
+             gt_audio, gt_sr = torchaudio.load(gt_wav)
+             if gt_sr != target_sample_rate:
+                 resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
+                 gt_audio = resampler(gt_audio)
+             total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
+
+             # # test vocoder resynthesis
+             # ref_audio = gt_audio
+         else:
+             ref_text_len = len(prompt_text.encode("utf-8"))
+             gen_text_len = len(gt_text.encode("utf-8"))
+             total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
+
+         # to mel spectrogram
+         ref_mel = mel_spectrogram(ref_audio)
+         ref_mel = ref_mel.squeeze(0)
+
+         # deal with batch
+         assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
+         assert (
+             min_tokens <= total_mel_len <= max_tokens
+         ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
+         bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
+
+         utts[bucket_i].append(utt)
+         ref_rms_list[bucket_i].append(ref_rms)
+         ref_mels[bucket_i].append(ref_mel)
+         ref_mel_lens[bucket_i].append(ref_mel_len)
+         total_mel_lens[bucket_i].append(total_mel_len)
+         final_text_list[bucket_i].extend(text_list)
+
+         batch_accum[bucket_i] += total_mel_len
+
+         if batch_accum[bucket_i] >= infer_batch_size:
+             # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
+             prompts_all.append(
+                 (
+                     utts[bucket_i],
+                     ref_rms_list[bucket_i],
+                     padded_mel_batch(ref_mels[bucket_i]),
+                     ref_mel_lens[bucket_i],
+                     total_mel_lens[bucket_i],
+                     final_text_list[bucket_i],
+                 )
+             )
+             batch_accum[bucket_i] = 0
+             (
+                 utts[bucket_i],
+                 ref_rms_list[bucket_i],
+                 ref_mels[bucket_i],
+                 ref_mel_lens[bucket_i],
+                 total_mel_lens[bucket_i],
+                 final_text_list[bucket_i],
+             ) = [], [], [], [], [], []
+
+     # add residual
+     for bucket_i, bucket_frames in enumerate(batch_accum):
+         if bucket_frames > 0:
+             prompts_all.append(
+                 (
+                     utts[bucket_i],
+                     ref_rms_list[bucket_i],
+                     padded_mel_batch(ref_mels[bucket_i]),
+                     ref_mel_lens[bucket_i],
+                     total_mel_lens[bucket_i],
+                     final_text_list[bucket_i],
+                 )
+             )
+     # not only leave easy work for last workers
+     random.seed(666)
+     random.shuffle(prompts_all)
+
+     return prompts_all
+
+
+ # get wav_res_ref_text of seed-tts test metalst
+ # https://github.com/BytedanceSpeech/seed-tts-eval
+
+
+ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
+     f = open(metalst)
+     lines = f.readlines()
+     f.close()
+
+     test_set_ = []
+     for line in tqdm(lines):
+         if len(line.strip().split("|")) == 5:
+             utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
+         elif len(line.strip().split("|")) == 4:
+             utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
+
+         if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")):
+             continue
+         gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
+         if not os.path.isabs(prompt_wav):
+             prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
+
+         test_set_.append((gen_wav, prompt_wav, gt_text))
+
+     num_jobs = len(gpus)
+     if num_jobs == 1:
+         return [(gpus[0], test_set_)]
+
+     wav_per_job = len(test_set_) // num_jobs + 1
+     test_set = []
+     for i in range(num_jobs):
+         test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
+
+     return test_set
+
+
+ # get librispeech test-clean cross sentence test
+
+
+ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
+     f = open(metalst)
+     lines = f.readlines()
+     f.close()
+
+     test_set_ = []
+     for line in tqdm(lines):
+         ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
+
+         if eval_ground_truth:
+             gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
+             gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
+         else:
+             if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
+                 raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
+             gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")
+
+         ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
+         ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
+
+         test_set_.append((gen_wav, ref_wav, gen_txt))
+
+     num_jobs = len(gpus)
+     if num_jobs == 1:
+         return [(gpus[0], test_set_)]
+
+     wav_per_job = len(test_set_) // num_jobs + 1
+     test_set = []
+     for i in range(num_jobs):
+         test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
+
+     return test_set
+
+
+ # load asr model
+
+
+ def load_asr_model(lang, ckpt_dir=""):
+     if lang == "zh":
+         from funasr import AutoModel
+
+         model = AutoModel(
+             model=os.path.join(ckpt_dir, "paraformer-zh"),
+             # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
+             # punc_model = os.path.join(ckpt_dir, "ct-punc"),
+             # spk_model = os.path.join(ckpt_dir, "cam++"),
+             disable_update=True,
+         )  # following seed-tts setting
+     elif lang == "en":
+         from faster_whisper import WhisperModel
+
+         model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
+         model = WhisperModel(model_size, device="cuda", compute_type="float16")
+     return model
+
+
+ # WER Evaluation, the way Seed-TTS does
+
+
+ def run_asr_wer(args):
+     rank, lang, test_set, ckpt_dir = args
+
+     if lang == "zh":
+         import zhconv
+
+         torch.cuda.set_device(rank)
+     elif lang == "en":
+         os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
+     else:
+         raise NotImplementedError(
+             "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
+         )
+
+     asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)
+
+     from zhon.hanzi import punctuation
+
+     punctuation_all = punctuation + string.punctuation
+     wers = []
+
+     from jiwer import compute_measures
+
+     for gen_wav, prompt_wav, truth in tqdm(test_set):
+         if lang == "zh":
+             res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
+             hypo = res[0]["text"]
+             hypo = zhconv.convert(hypo, "zh-cn")
+         elif lang == "en":
+             segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
+             hypo = ""
+             for segment in segments:
+                 hypo = hypo + " " + segment.text
+
+         # raw_truth = truth
+         # raw_hypo = hypo
+
+         for x in punctuation_all:
+             truth = truth.replace(x, "")
+             hypo = hypo.replace(x, "")
+
+         truth = truth.replace("  ", " ")
+         hypo = hypo.replace("  ", " ")
+
+         if lang == "zh":
+             truth = " ".join([x for x in truth])
+             hypo = " ".join([x for x in hypo])
+         elif lang == "en":
+             truth = truth.lower()
+             hypo = hypo.lower()
+
+         measures = compute_measures(truth, hypo)
+         wer = measures["wer"]
+
+         # ref_list = truth.split(" ")
+         # subs = measures["substitutions"] / len(ref_list)
+         # dele = measures["deletions"] / len(ref_list)
+         # inse = measures["insertions"] / len(ref_list)
+
+         wers.append(wer)
+
+     return wers
+
+
+ # SIM Evaluation
+
+
+ def run_sim(args):
+     rank, test_set, ckpt_dir = args
+     device = f"cuda:{rank}"
+
+     model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
+     state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
+     model.load_state_dict(state_dict["model"], strict=False)
+
+     use_gpu = True if torch.cuda.is_available() else False
+     if use_gpu:
+         model = model.cuda(device)
+     model.eval()
+
+     sim_list = []
+     for wav1, wav2, truth in tqdm(test_set):
+         wav1, sr1 = torchaudio.load(wav1)
+         wav2, sr2 = torchaudio.load(wav2)
+
+         resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
+         resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
+         wav1 = resample1(wav1)
+         wav2 = resample2(wav2)
+
+         if use_gpu:
+             wav1 = wav1.cuda(device)
+             wav2 = wav2.cuda(device)
+         with torch.no_grad():
+             emb1 = model(wav1)
+             emb2 = model(wav2)
+
+         sim = F.cosine_similarity(emb1, emb2)[0].item()
+         # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
+         sim_list.append(sim)
+
+     return sim_list
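For orientation, a minimal sketch of how these helpers chain together in a Seed-TTS-style evaluation run. The metalst path, output directory, GPU list, and WavLM checkpoint below are placeholders (this commit ships no eval driver), and the loop is sequential where a real setup would fan jobs out per GPU:

# Hypothetical driver around the utilities above; all paths and the GPU list are placeholders.
from f5_tts.eval.utils_eval import (
    get_seedtts_testset_metainfo,
    get_inference_prompt,
    get_seed_tts_test,
    run_asr_wer,
    run_sim,
)

metalst = "data/seedtts_testset/zh/meta.lst"      # assumed seed-tts metalst location
metainfo = get_seedtts_testset_metainfo(metalst)  # (utt, prompt_text, prompt_wav, gt_text, gt_wav) tuples
prompts = get_inference_prompt(metainfo, tokenizer="pinyin", infer_batch_size=1)  # bucketed batches for synthesis

# after synthesis has written <utt>.wav files into gen_wav_dir:
gen_wav_dir = "results/seedtts_zh"                # placeholder output directory
jobs = get_seed_tts_test(metalst, gen_wav_dir, gpus=[0, 1])  # [(gpu_id, sub_test_set), ...]
wers = [w for gpu, subset in jobs for w in run_asr_wer((gpu, "zh", subset, ""))]
sims = [s for gpu, subset in jobs for s in run_sim((gpu, subset, "ckpts/wavlm_large_finetune.pth"))]
print(f"WER: {sum(wers) / len(wers):.4f}  SIM: {sum(sims) / len(sims):.4f}")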
src/f5_tts/infer/infer_cli.py CHANGED
@@ -11,7 +11,7 @@ import tomli
  from cached_path import cached_path
  
  from f5_tts.model import DiT, UNetT
- from f5_tts.model.utils_infer import (
+ from f5_tts.infer.utils_infer import (
      load_vocoder,
      load_model,
      preprocess_ref_audio_text,
src/f5_tts/infer/infer_gradio.py CHANGED
@@ -28,15 +28,13 @@ def gpu_decorator(func):
  
  
  from f5_tts.model import DiT, UNetT
- from f5_tts.model.utils import (
-     save_spectrogram,
- )
- from f5_tts.model.utils_infer import (
+ from f5_tts.infer.utils_infer import (
      load_vocoder,
      load_model,
      preprocess_ref_audio_text,
      infer_process,
      remove_silence_for_generated_wav,
+     save_spectrogram,
  )
  
  vocos = load_vocoder()
src/f5_tts/infer/speech_edit.py CHANGED
@@ -10,8 +10,8 @@ from f5_tts.model.utils import (
      load_checkpoint,
      get_tokenizer,
      convert_char_to_pinyin,
-     save_spectrogram,
  )
+ from f5_tts.infer.utils_infer import save_spectrogram
  
  device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
  
src/f5_tts/{model → infer}/utils_infer.py RENAMED
@@ -4,6 +4,11 @@
  import re
  import tempfile
  
+ import matplotlib
+
+ matplotlib.use("Agg")
+
+ import matplotlib.pylab as plt
  import numpy as np
  import torch
  import torchaudio
@@ -14,7 +19,6 @@ from vocos import Vocos
  
  from f5_tts.model import CFM
  from f5_tts.model.utils import (
-     load_checkpoint,
      get_tokenizer,
      convert_char_to_pinyin,
  )
@@ -104,6 +108,38 @@ def initialize_asr_pipeline(device=device):
      )
  
  
+ # load model checkpoint for inference
+
+
+ def load_checkpoint(model, ckpt_path, device, use_ema=True):
+     if device == "cuda":
+         model = model.half()
+
+     ckpt_type = ckpt_path.split(".")[-1]
+     if ckpt_type == "safetensors":
+         from safetensors.torch import load_file
+
+         checkpoint = load_file(ckpt_path)
+     else:
+         checkpoint = torch.load(ckpt_path, weights_only=True)
+
+     if use_ema:
+         if ckpt_type == "safetensors":
+             checkpoint = {"ema_model_state_dict": checkpoint}
+         checkpoint["model_state_dict"] = {
+             k.replace("ema_model.", ""): v
+             for k, v in checkpoint["ema_model_state_dict"].items()
+             if k not in ["initted", "step"]
+         }
+         model.load_state_dict(checkpoint["model_state_dict"])
+     else:
+         if ckpt_type == "safetensors":
+             checkpoint = {"model_state_dict": checkpoint}
+         model.load_state_dict(checkpoint["model_state_dict"])
+
+     return model.to(device)
+
+
  # load model for inference
  
  
@@ -355,3 +391,14 @@ def remove_silence_for_generated_wav(filename):
          non_silent_wave += non_silent_seg
      aseg = non_silent_wave
      aseg.export(filename, format="wav")
+
+
+ # save spectrogram
+
+
+ def save_spectrogram(spectrogram, path):
+     plt.figure(figsize=(12, 4))
+     plt.imshow(spectrogram, origin="lower", aspect="auto")
+     plt.colorbar()
+     plt.savefig(path)
+     plt.close()
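With load_checkpoint and save_spectrogram now living in f5_tts.infer.utils_infer, downstream code imports them from there. A small runnable sketch (the dummy mel array is made up just to exercise the plotting helper, and the commented load_checkpoint call uses a placeholder path):

# Sketch only: exercises save_spectrogram with dummy data; the checkpoint path is a placeholder.
import numpy as np

from f5_tts.infer.utils_infer import load_checkpoint, save_spectrogram  # new import location after this commit

# model = load_checkpoint(model, "path/to/model.safetensors", device="cuda", use_ema=True)  # signature per the hunk above

fake_mel = np.random.rand(100, 800)  # [n_mel_channels=100, frames] dummy spectrogram
save_spectrogram(fake_mel, "spectrogram_demo.png")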
src/f5_tts/model/utils.py CHANGED
@@ -1,29 +1,16 @@
  from __future__ import annotations
  
  import os
- import math
  import random
- import string
  from importlib.resources import files
- from tqdm import tqdm
  from collections import defaultdict
  
- import matplotlib
-
- matplotlib.use("Agg")
- import matplotlib.pylab as plt
-
  import torch
- import torch.nn.functional as F
  from torch.nn.utils.rnn import pad_sequence
- import torchaudio
  
  import jieba
  from pypinyin import lazy_pinyin, Style
  
- from f5_tts.model.ecapa_tdnn import ECAPA_TDNN_SMALL
- from f5_tts.model.modules import MelSpec
-
  
  # seed everything
  
@@ -183,399 +170,6 @@ def convert_char_to_pinyin(text_list, polyphone=True):
      return final_text_list
  
  
- # save spectrogram
- def save_spectrogram(spectrogram, path):
-     plt.figure(figsize=(12, 4))
-     plt.imshow(spectrogram, origin="lower", aspect="auto")
-     plt.colorbar()
-     plt.savefig(path)
-     plt.close()
-
-
- # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
- def get_seedtts_testset_metainfo(metalst):
-     f = open(metalst)
-     lines = f.readlines()
-     f.close()
-     metainfo = []
-     for line in lines:
-         if len(line.strip().split("|")) == 5:
-             utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
-         elif len(line.strip().split("|")) == 4:
-             utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
-             gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
-         if not os.path.isabs(prompt_wav):
-             prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
-         metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
-     return metainfo
-
-
- # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
- def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
-     f = open(metalst)
-     lines = f.readlines()
-     f.close()
-     metainfo = []
-     for line in lines:
-         ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
-
-         # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
-         ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
-         ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
-
-         # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
-         gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
-         gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
-
-         metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
-
-     return metainfo
-
-
- # padded to max length mel batch
- def padded_mel_batch(ref_mels):
-     max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
-     padded_ref_mels = []
-     for mel in ref_mels:
-         padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
-         padded_ref_mels.append(padded_ref_mel)
-     padded_ref_mels = torch.stack(padded_ref_mels)
-     padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
-     return padded_ref_mels
-
-
- # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
-
-
- def get_inference_prompt(
-     metainfo,
-     speed=1.0,
-     tokenizer="pinyin",
-     polyphone=True,
-     target_sample_rate=24000,
-     n_mel_channels=100,
-     hop_length=256,
-     target_rms=0.1,
-     use_truth_duration=False,
-     infer_batch_size=1,
-     num_buckets=200,
-     min_secs=3,
-     max_secs=40,
- ):
-     prompts_all = []
-
-     min_tokens = min_secs * target_sample_rate // hop_length
-     max_tokens = max_secs * target_sample_rate // hop_length
-
-     batch_accum = [0] * num_buckets
-     utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
-         [[] for _ in range(num_buckets)] for _ in range(6)
-     )
-
-     mel_spectrogram = MelSpec(
-         target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
-     )
-
-     for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
-         # Audio
-         ref_audio, ref_sr = torchaudio.load(prompt_wav)
-         ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
-         if ref_rms < target_rms:
-             ref_audio = ref_audio * target_rms / ref_rms
-         assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
-         if ref_sr != target_sample_rate:
-             resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
-             ref_audio = resampler(ref_audio)
-
-         # Text
-         if len(prompt_text[-1].encode("utf-8")) == 1:
-             prompt_text = prompt_text + " "
-         text = [prompt_text + gt_text]
-         if tokenizer == "pinyin":
-             text_list = convert_char_to_pinyin(text, polyphone=polyphone)
-         else:
-             text_list = text
-
-         # Duration, mel frame length
-         ref_mel_len = ref_audio.shape[-1] // hop_length
-         if use_truth_duration:
-             gt_audio, gt_sr = torchaudio.load(gt_wav)
-             if gt_sr != target_sample_rate:
-                 resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
-                 gt_audio = resampler(gt_audio)
-             total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
-
-             # # test vocoder resynthesis
-             # ref_audio = gt_audio
-         else:
-             ref_text_len = len(prompt_text.encode("utf-8"))
-             gen_text_len = len(gt_text.encode("utf-8"))
-             total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
-
-         # to mel spectrogram
-         ref_mel = mel_spectrogram(ref_audio)
-         ref_mel = ref_mel.squeeze(0)
-
-         # deal with batch
-         assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
-         assert (
-             min_tokens <= total_mel_len <= max_tokens
-         ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
-         bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
-
-         utts[bucket_i].append(utt)
-         ref_rms_list[bucket_i].append(ref_rms)
-         ref_mels[bucket_i].append(ref_mel)
-         ref_mel_lens[bucket_i].append(ref_mel_len)
-         total_mel_lens[bucket_i].append(total_mel_len)
-         final_text_list[bucket_i].extend(text_list)
-
-         batch_accum[bucket_i] += total_mel_len
-
-         if batch_accum[bucket_i] >= infer_batch_size:
-             # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
-             prompts_all.append(
-                 (
-                     utts[bucket_i],
-                     ref_rms_list[bucket_i],
-                     padded_mel_batch(ref_mels[bucket_i]),
-                     ref_mel_lens[bucket_i],
-                     total_mel_lens[bucket_i],
-                     final_text_list[bucket_i],
-                 )
-             )
-             batch_accum[bucket_i] = 0
-             (
-                 utts[bucket_i],
-                 ref_rms_list[bucket_i],
-                 ref_mels[bucket_i],
-                 ref_mel_lens[bucket_i],
-                 total_mel_lens[bucket_i],
-                 final_text_list[bucket_i],
-             ) = [], [], [], [], [], []
-
-     # add residual
-     for bucket_i, bucket_frames in enumerate(batch_accum):
-         if bucket_frames > 0:
-             prompts_all.append(
-                 (
-                     utts[bucket_i],
-                     ref_rms_list[bucket_i],
-                     padded_mel_batch(ref_mels[bucket_i]),
-                     ref_mel_lens[bucket_i],
-                     total_mel_lens[bucket_i],
-                     final_text_list[bucket_i],
-                 )
-             )
-     # not only leave easy work for last workers
-     random.seed(666)
-     random.shuffle(prompts_all)
-
-     return prompts_all
-
-
- # get wav_res_ref_text of seed-tts test metalst
- # https://github.com/BytedanceSpeech/seed-tts-eval
-
-
- def get_seed_tts_test(metalst, gen_wav_dir, gpus):
-     f = open(metalst)
-     lines = f.readlines()
-     f.close()
-
-     test_set_ = []
-     for line in tqdm(lines):
-         if len(line.strip().split("|")) == 5:
-             utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
-         elif len(line.strip().split("|")) == 4:
-             utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
-
-         if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")):
-             continue
-         gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
-         if not os.path.isabs(prompt_wav):
-             prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
-
-         test_set_.append((gen_wav, prompt_wav, gt_text))
-
-     num_jobs = len(gpus)
-     if num_jobs == 1:
-         return [(gpus[0], test_set_)]
-
-     wav_per_job = len(test_set_) // num_jobs + 1
-     test_set = []
-     for i in range(num_jobs):
-         test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
-
-     return test_set
-
-
- # get librispeech test-clean cross sentence test
-
-
- def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
-     f = open(metalst)
-     lines = f.readlines()
-     f.close()
-
-     test_set_ = []
-     for line in tqdm(lines):
-         ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
-
-         if eval_ground_truth:
-             gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
-             gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
-         else:
-             if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
-                 raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
-             gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")
-
-         ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
-         ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
-
-         test_set_.append((gen_wav, ref_wav, gen_txt))
-
-     num_jobs = len(gpus)
-     if num_jobs == 1:
-         return [(gpus[0], test_set_)]
-
-     wav_per_job = len(test_set_) // num_jobs + 1
-     test_set = []
-     for i in range(num_jobs):
-         test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
-
-     return test_set
-
-
- # load asr model
-
-
- def load_asr_model(lang, ckpt_dir=""):
-     if lang == "zh":
-         from funasr import AutoModel
-
-         model = AutoModel(
-             model=os.path.join(ckpt_dir, "paraformer-zh"),
-             # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
-             # punc_model = os.path.join(ckpt_dir, "ct-punc"),
-             # spk_model = os.path.join(ckpt_dir, "cam++"),
-             disable_update=True,
-         )  # following seed-tts setting
-     elif lang == "en":
-         from faster_whisper import WhisperModel
-
-         model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
-         model = WhisperModel(model_size, device="cuda", compute_type="float16")
-     return model
-
-
- # WER Evaluation, the way Seed-TTS does
-
-
- def run_asr_wer(args):
-     rank, lang, test_set, ckpt_dir = args
-
-     if lang == "zh":
-         import zhconv
-
-         torch.cuda.set_device(rank)
-     elif lang == "en":
-         os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
-     else:
-         raise NotImplementedError(
-             "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
-         )
-
-     asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)
-
-     from zhon.hanzi import punctuation
-
-     punctuation_all = punctuation + string.punctuation
-     wers = []
-
-     from jiwer import compute_measures
-
-     for gen_wav, prompt_wav, truth in tqdm(test_set):
-         if lang == "zh":
-             res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
-             hypo = res[0]["text"]
-             hypo = zhconv.convert(hypo, "zh-cn")
-         elif lang == "en":
-             segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
-             hypo = ""
-             for segment in segments:
-                 hypo = hypo + " " + segment.text
-
-         # raw_truth = truth
-         # raw_hypo = hypo
-
-         for x in punctuation_all:
-             truth = truth.replace(x, "")
-             hypo = hypo.replace(x, "")
-
-         truth = truth.replace("  ", " ")
-         hypo = hypo.replace("  ", " ")
-
-         if lang == "zh":
-             truth = " ".join([x for x in truth])
-             hypo = " ".join([x for x in hypo])
-         elif lang == "en":
-             truth = truth.lower()
-             hypo = hypo.lower()
-
-         measures = compute_measures(truth, hypo)
-         wer = measures["wer"]
-
-         # ref_list = truth.split(" ")
-         # subs = measures["substitutions"] / len(ref_list)
-         # dele = measures["deletions"] / len(ref_list)
-         # inse = measures["insertions"] / len(ref_list)
-
-         wers.append(wer)
-
-     return wers
-
-
- # SIM Evaluation
-
-
- def run_sim(args):
-     rank, test_set, ckpt_dir = args
-     device = f"cuda:{rank}"
-
-     model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
-     state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
-     model.load_state_dict(state_dict["model"], strict=False)
-
-     use_gpu = True if torch.cuda.is_available() else False
-     if use_gpu:
-         model = model.cuda(device)
-     model.eval()
-
-     sim_list = []
-     for wav1, wav2, truth in tqdm(test_set):
-         wav1, sr1 = torchaudio.load(wav1)
-         wav2, sr2 = torchaudio.load(wav2)
-
-         resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
-         resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
-         wav1 = resample1(wav1)
-         wav2 = resample2(wav2)
-
-         if use_gpu:
-             wav1 = wav1.cuda(device)
-             wav2 = wav2.cuda(device)
-         with torch.no_grad():
-             emb1 = model(wav1)
-             emb2 = model(wav2)
-
-         sim = F.cosine_similarity(emb1, emb2)[0].item()
-         # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
-         sim_list.append(sim)
-
-     return sim_list
-
-
  # filter func for dirty data with many repetitions
  
  
@@ -588,35 +182,3 @@ def repetition_found(text, length=2, tolerance=10):
          if count > tolerance:
              return True
      return False
-
-
- # load model checkpoint for inference
-
-
- def load_checkpoint(model, ckpt_path, device, use_ema=True):
-     if device == "cuda":
-         model = model.half()
-
-     ckpt_type = ckpt_path.split(".")[-1]
-     if ckpt_type == "safetensors":
-         from safetensors.torch import load_file
-
-         checkpoint = load_file(ckpt_path)
-     else:
-         checkpoint = torch.load(ckpt_path, weights_only=True)
-
-     if use_ema:
-         if ckpt_type == "safetensors":
-             checkpoint = {"ema_model_state_dict": checkpoint}
-         checkpoint["model_state_dict"] = {
-             k.replace("ema_model.", ""): v
-             for k, v in checkpoint["ema_model_state_dict"].items()
-             if k not in ["initted", "step"]
-         }
-         model.load_state_dict(checkpoint["model_state_dict"])
-     else:
-         if ckpt_type == "safetensors":
-             checkpoint = {"model_state_dict": checkpoint}
-         model.load_state_dict(checkpoint["model_state_dict"])
-
-     return model.to(device)
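After these removals, f5_tts.model.utils keeps only the generic text and data helpers (per the hunks above: the tokenizer utilities, convert_char_to_pinyin, and repetition_found), while evaluation and checkpoint code has moved out. A short sketch of the surviving surface (the sample sentence is illustrative only):

# Helpers still importable from f5_tts.model.utils after this commit (sketch).
from f5_tts.model.utils import get_tokenizer, convert_char_to_pinyin, repetition_found

texts = convert_char_to_pinyin(["小猫很可爱。 The cat is cute."], polyphone=True)
print(texts[0][:8])                                # mixed pinyin / ascii character list
print(repetition_found("ababababababababababab"))  # True: a 2-gram repeated beyond the default tolerance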
src/f5_tts/{scripts → train/datasets}/prepare_csv_wavs.py RENAMED
@@ -1,138 +1,138 @@
import sys
import os

sys.path.append(os.getcwd())

from pathlib import Path
import json
import shutil
import argparse

import csv
import torchaudio
from tqdm import tqdm
from datasets.arrow_writer import ArrowWriter

from f5_tts.model.utils import (
    convert_char_to_pinyin,
)

PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"


def is_csv_wavs_format(input_dataset_dir):
    fpath = Path(input_dataset_dir)
    metadata = fpath / "metadata.csv"
    wavs = fpath / "wavs"
    return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()


def prepare_csv_wavs_dir(input_dir):
    assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
    input_dir = Path(input_dir)
    metadata_path = input_dir / "metadata.csv"
    audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())

    sub_result, durations = [], []
    vocab_set = set()
    polyphone = True
    for audio_path, text in audio_path_text_pairs:
        if not Path(audio_path).exists():
            print(f"audio {audio_path} not found, skipping")
            continue
        audio_duration = get_audio_duration(audio_path)
        # assume tokenizer = "pinyin" ("pinyin" | "char")
        text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
        sub_result.append({"audio_path": audio_path, "text": text, "duration": audio_duration})
        durations.append(audio_duration)
        vocab_set.update(list(text))

    return sub_result, durations, vocab_set


def get_audio_duration(audio_path):
    audio, sample_rate = torchaudio.load(audio_path)
    num_channels = audio.shape[0]
    return audio.shape[1] / (sample_rate * num_channels)


def read_audio_text_pairs(csv_file_path):
    audio_text_pairs = []

    parent = Path(csv_file_path).parent
    with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
        reader = csv.reader(csvfile, delimiter="|")
        next(reader)  # Skip the header row
        for row in reader:
            if len(row) >= 2:
                audio_file = row[0].strip()  # First column: audio file path
                text = row[1].strip()  # Second column: text
                audio_file_path = parent / audio_file
                audio_text_pairs.append((audio_file_path.as_posix(), text))

    return audio_text_pairs


def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
    out_dir = Path(out_dir)
    # save preprocessed dataset to disk
    out_dir.mkdir(exist_ok=True, parents=True)
    print(f"\nSaving to {out_dir} ...")

    # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})  # oom
    # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
    raw_arrow_path = out_dir / "raw.arrow"
    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    dur_json_path = out_dir / "duration.json"
    with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # vocab map, i.e. tokenizer
    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
    # if tokenizer == "pinyin":
    #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
    voca_out_path = out_dir / "vocab.txt"
    with open(voca_out_path.as_posix(), "w") as f:
        for vocab in sorted(text_vocab_set):
            f.write(vocab + "\n")

    if is_finetune:
        file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
        shutil.copy2(file_vocab_finetune, voca_out_path)
    else:
        with open(voca_out_path, "w") as f:
            for vocab in sorted(text_vocab_set):
                f.write(vocab + "\n")

    dataset_name = out_dir.stem
    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")


def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
    if is_finetune:
        assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
    sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir)
    save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)


def cli():
    # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
    # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
    parser = argparse.ArgumentParser(description="Prepare and save dataset.")
    parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
    parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
    parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")

    args = parser.parse_args()

    prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)


if __name__ == "__main__":
    cli()
 
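Note that the comments inside cli() still reference the old scripts/ location; after this rename the module sits under train/datasets. A hedged sketch of programmatic use from the new location (directory names are placeholders, and the exact packaging or entry point may still change while dependencies are being sorted out):

# Hypothetical call into the relocated dataset-prep module; input/output dirs are placeholders.
from f5_tts.train.datasets.prepare_csv_wavs import prepare_and_save_set

# expects <inp_dir>/metadata.csv with "audio|text" rows and an <inp_dir>/wavs/ folder, per is_csv_wavs_format()
prepare_and_save_set(
    "data/my_corpus_csv_wavs",   # placeholder input dir
    "data/my_corpus_pinyin",     # placeholder output dir for raw.arrow / duration.json / vocab.txt
    is_finetune=True,            # reuse the pretrained Emilia_ZH_EN_pinyin vocab.txt
)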
src/f5_tts/{scripts → train/datasets}/prepare_emilia.py RENAMED
File without changes
src/f5_tts/{scripts → train/datasets}/prepare_wenetspeech4tts.py RENAMED
File without changes