lpscr committed

Commit 549ee89 · unverified · 2 Parent(s): 3f3743e 18f526d

Merge branch 'SWivid:main' into main
README.md CHANGED
@@ -2,8 +2,9 @@
 
 [![python](https://img.shields.io/badge/Python-3.10-brightgreen)](https://github.com/SWivid/F5-TTS)
 [![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885)
-[![demo](https://img.shields.io/badge/GitHub-Demo%20page-blue.svg)](https://swivid.github.io/F5-TTS/)
-[![space](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
+[![demo](https://img.shields.io/badge/GitHub-Demo%20page-orange.svg)](https://swivid.github.io/F5-TTS/)
+[![hfspace](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
+[![msspace](https://img.shields.io/badge/🤖-Space%20demo-blue)](https://modelscope.cn/studios/modelscope/E2-F5-TTS)
 [![lab](https://img.shields.io/badge/X--LANCE-Lab-grey?labelColor=lightgrey)](https://x-lance.sjtu.edu.cn/)
 <img src="https://github.com/user-attachments/assets/12d7749c-071a-427c-81bf-b87b91def670" alt="Watermark" style="width: 40px; height: auto">
 
@@ -52,7 +53,7 @@ python scripts/prepare_emilia.py
 python scripts/prepare_wenetspeech4tts.py
 ```
 
-## Training
+## Training & Finetuning
 
 Once your datasets are prepared, you can start the training process.
 
@@ -64,9 +65,11 @@ accelerate launch train.py
 ```
 An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57).
 
+Gradio UI finetuning with `finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
+
 ## Inference
 
-The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [ Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or automatically downloaded with `inference-cli` and `gradio_app`.
+The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or automatically downloaded with `inference-cli` and `gradio_app`.
 
 Currently support 30s for a single generation, which is the **TOTAL** length of prompt audio and the generated. Batch inference with chunks is supported by `inference-cli` and `gradio_app`.
 - To avoid possible inference failures, make sure you have seen through the following instructions.
@@ -90,6 +93,9 @@ python inference-cli.py \
 --ref_audio "tests/ref_audio/test_zh_1_ref_short.wav" \
 --ref_text "对,这就是我,万人敬仰的太乙真人。" \
 --gen_text "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道,我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"
+
+# Multi voice
+python inference-cli.py -c samples/story.toml
 ```
 
 ### Gradio App
@@ -188,11 +194,13 @@ python scripts/eval_librispeech_test_clean.py
 - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
 - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
 - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
-- [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
 - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
 - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
+- [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
+- [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation of F5-TTS, with the MLX framework.
 
 ## Citation
+If our work and codebase is useful for you, please cite as:
 ```
 @article{chen-etal-2024-f5tts,
   title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching},
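Note: the multi-voice mode added here switches voices with inline `[voice]` tags in the generated text; tag names must match the `[voices.*]` tables in the TOML config passed via `-c` (see `samples/story.toml` added below). A hypothetical input illustrating the convention:

```text
[main] The narrator sets the scene.
[town] "Come and see how I fare!" said the Town Mouse.
[country] "I prefer my quiet fields," replied the Country Mouse.
[main] And so the two mice set off together.
```

Untagged leading text falls back to the main voice.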
finetune_gradio.py CHANGED
@@ -339,7 +339,7 @@ def transcribe(file_audio,language="english"):
     )["text"].strip()
     return text_transcribe
 
-def transcribe_all(name_project,audio_file,language,user=False,progress=gr.Progress()):
+def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Progress()):
     name_project+="_pinyin"
     path_project= os.path.join(path_data,name_project)
     path_dataset = os.path.join(path_project,"dataset")
@@ -357,7 +357,7 @@ def transcribe_all(name_project,audio_file,language,user=False,progress=gr.Progress()):
     if user:
         file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))]
     else:
-        file_audios = [audio_file]
+        file_audios = audio_files
 
     print([file_audios])
 
@@ -580,7 +580,7 @@ with gr.Blocks() as app:
     ...
     ```""",visible=False)
 
-    audio_speaker = gr.Audio(label="voice",type="filepath")
+    audio_speaker = gr.File(label="voice",type="filepath",file_count="multiple")
     txt_lang = gr.Text(label="Language",value="english")
     bt_transcribe=bt_create=gr.Button("transcribe")
     txt_info_transcribe=gr.Text(label="info",value="")
@@ -686,4 +686,4 @@ if __name__ == "__main__":
     #transcribe_all(name)
     #create_metadata(name)
 
-    main()
+    main()
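Note: this change swaps the single `gr.Audio` input for a `gr.File` with `file_count="multiple"`, so the transcribe handler now receives a list of file paths rather than one path. A minimal sketch of the pattern (the component and function names here are illustrative, not the app's):

```python
import gradio as gr

def transcribe_all_demo(audio_files):
    # With file_count="multiple", Gradio passes a list of filepath strings,
    # so the handler iterates directly instead of wrapping one path in a list.
    return "\n".join(f"would transcribe: {path}" for path in audio_files or [])

with gr.Blocks() as demo:
    audio_speaker = gr.File(label="voice", type="filepath", file_count="multiple")
    txt_info = gr.Text(label="info")
    gr.Button("transcribe").click(transcribe_all_demo, inputs=audio_speaker, outputs=txt_info)

# demo.launch()  # uncomment to try it locally
```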
gradio_app.py CHANGED
@@ -532,8 +532,8 @@ with gr.Blocks() as app_emotional:
     regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
     regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)
 
-    # Additional speech types (up to 9 more)
-    max_speech_types = 10
+    # Additional speech types (up to 99 more)
+    max_speech_types = 100
     speech_type_names = []
     speech_type_audios = []
     speech_type_ref_texts = []
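Note: Gradio cannot create new components after the UI is built, so the app pre-allocates a fixed pool of speech-type rows and only toggles their visibility; raising `max_speech_types` from 10 to 100 simply enlarges that pool. A minimal sketch of the pre-allocation pattern, with hypothetical labels:

```python
import gradio as gr

max_speech_types = 100  # fixed ceiling; components cannot be added at runtime

with gr.Blocks() as demo:
    speech_type_names, speech_type_audios, speech_type_ref_texts = [], [], []
    for _ in range(max_speech_types - 1):  # "- 1": the regular voice has its own row
        with gr.Row(visible=False):  # hidden until the user adds a speech type
            name = gr.Textbox(label='Speech Type Name')
            audio = gr.Audio(label='Reference Audio', type='filepath')
            ref_text = gr.Textbox(label='Reference Text', lines=2)
        speech_type_names.append(name)
        speech_type_audios.append(audio)
        speech_type_ref_texts.append(ref_text)
```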
inference-cli.py CHANGED
@@ -282,29 +282,12 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
 
         final_wave = new_wave
 
-    with open(wave_path, "wb") as f:
-        sf.write(f.name, final_wave, target_sample_rate)
-        # Remove silence
-        if remove_silence:
-            aseg = AudioSegment.from_file(f.name)
-            non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
-            non_silent_wave = AudioSegment.silent(duration=0)
-            for non_silent_seg in non_silent_segs:
-                non_silent_wave += non_silent_seg
-            aseg = non_silent_wave
-            aseg.export(f.name, format="wav")
-        print(f.name)
-
     # Create a combined spectrogram
     combined_spectrogram = np.concatenate(spectrograms, axis=1)
-    save_spectrogram(combined_spectrogram, spectrogram_path)
-    print(spectrogram_path)
 
+    return final_wave, combined_spectrogram
 
-def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
-
-    print(gen_text)
 
+def process_voice(ref_audio_orig, ref_text):
     print("Converting audio...")
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
@@ -340,7 +323,10 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
         print("Finished transcription")
     else:
         print("Using custom reference text...")
+    return ref_audio, ref_text
 
+def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
+    print(gen_text)
     # Add the functionality to ensure it ends with ". "
     if not ref_text.endswith(". ") and not ref_text.endswith("。"):
         if ref_text.endswith("."):
@@ -360,4 +346,47 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
     return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
 
 
-infer(ref_audio, ref_text, gen_text, model, remove_silence)
+def process(ref_audio, ref_text, text_gen, model, remove_silence):
+    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
+    if "voices" not in config:
+        voices = {"main": main_voice}
+    else:
+        voices = config["voices"]
+        voices["main"] = main_voice
+    for voice in voices:
+        voices[voice]['ref_audio'], voices[voice]['ref_text'] = process_voice(voices[voice]['ref_audio'], voices[voice]['ref_text'])
+
+    generated_audio_segments = []
+    reg1 = r'(?=\[\w+\])'
+    chunks = re.split(reg1, text_gen)
+    reg2 = r'\[(\w+)\]'
+    for text in chunks:
+        match = re.match(reg2, text)
+        if not match or voice not in voices:
+            voice = "main"
+        else:
+            voice = match[1]
+        text = re.sub(reg2, "", text)
+        gen_text = text.strip()
+        ref_audio = voices[voice]['ref_audio']
+        ref_text = voices[voice]['ref_text']
+        print(f"Voice: {voice}")
+        audio, spectragram = infer(ref_audio, ref_text, gen_text, model, remove_silence)
+        generated_audio_segments.append(audio)
+
+    if generated_audio_segments:
+        final_wave = np.concatenate(generated_audio_segments)
+        with open(wave_path, "wb") as f:
+            sf.write(f.name, final_wave, target_sample_rate)
+            # Remove silence
+            if remove_silence:
+                aseg = AudioSegment.from_file(f.name)
+                non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
+                non_silent_wave = AudioSegment.silent(duration=0)
+                for non_silent_seg in non_silent_segs:
+                    non_silent_wave += non_silent_seg
+                aseg = non_silent_wave
+                aseg.export(f.name, format="wav")
+            print(f.name)
+
+
+process(ref_audio, ref_text, gen_text, model, remove_silence)
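Note: the new `process()` splits `gen_text` on inline `[voice]` tags using a lookahead regex, so each tag stays attached to its chunk, then strips the tag and renders the chunk with the matching voice. A standalone sketch of just that chunking step, reusing `reg1`/`reg2` from the diff (with `voice` initialized up front, so a leading tag cannot hit an unbound variable):

```python
import re

reg1 = r'(?=\[\w+\])'  # zero-width split: break *before* each [tag], keeping it in the chunk
reg2 = r'\[(\w+)\]'    # capture the tag name

text_gen = 'He said [town] "Hello!" [main] and walked on.'
voices = {"main": {}, "town": {}, "country": {}}

voice = "main"
for chunk in re.split(reg1, text_gen):
    match = re.match(reg2, chunk)
    voice = match[1] if match and match[1] in voices else "main"
    gen_text = re.sub(reg2, "", chunk).strip()
    print(f"Voice: {voice} -> {gen_text!r}")
# Voice: main -> 'He said'
# Voice: town -> '"Hello!"'
# Voice: main -> 'and walked on.'
```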
model/trainer.py CHANGED
@@ -140,7 +140,7 @@ class Trainer:
         else:
             latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
         # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
-        checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
+        checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
 
         if self.is_main:
             self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
model/utils.py CHANGED
@@ -509,7 +509,7 @@ def run_sim(args):
     device = f"cuda:{rank}"
 
     model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
-    state_dict = torch.load(ckpt_dir, map_location=lambda storage, loc: storage)
+    state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
     model.load_state_dict(state_dict['model'], strict=False)
 
     use_gpu=True if torch.cuda.is_available() else False
@@ -565,7 +565,7 @@ def load_checkpoint(model, ckpt_path, device, use_ema = True):
         from safetensors.torch import load_file
         checkpoint = load_file(ckpt_path, device=device)
     else:
-        checkpoint = torch.load(ckpt_path, map_location=device)
+        checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
 
     if use_ema == True:
         ema_model = EMA(model, include_online_model = False).to(device)
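Note: the recurring change in this commit (here, and again in `scripts/eval_infer_batch.py` and `speech_edit.py` below) is adding `weights_only=True` to `torch.load`, which restricts unpickling to tensors and primitive containers so a tampered checkpoint cannot execute arbitrary code on load. A minimal round trip showing the flag (file name hypothetical):

```python
import torch

# Save a plain state_dict-style checkpoint of tensors...
ckpt = {"model": {"weight": torch.randn(4, 4)}}
torch.save(ckpt, "demo_ckpt.pt")

# ...and load it back with the safer setting: weights_only=True refuses to
# unpickle arbitrary Python objects, accepting only tensor/primitive payloads.
loaded = torch.load("demo_ckpt.pt", weights_only=True, map_location="cpu")
print(loaded["model"]["weight"].shape)  # torch.Size([4, 4])
```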
samples/country.flac ADDED
Binary file (180 kB).
 
samples/main.flac ADDED
Binary file (279 kB).
 
samples/story.toml ADDED
@@ -0,0 +1,19 @@
+# F5-TTS | E2-TTS
+model = "F5-TTS"
+ref_audio = "samples/main.flac"
+# If an empty "", transcribes the reference audio automatically.
+ref_text = ""
+gen_text = ""
+# File with text to generate. Ignores the text above.
+gen_file = "samples/story.txt"
+remove_silence = true
+output_dir = "samples"
+
+[voices.town]
+ref_audio = "samples/town.flac"
+ref_text = ""
+
+[voices.country]
+ref_audio = "samples/country.flac"
+ref_text = ""
+
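Note: each `[voices.<name>]` table declares a reference clip for one tag, and `<name>` is what goes inside the square brackets in `gen_text`/`gen_file`. A small sketch of reading the config, assuming Python 3.11+ for the standard-library `tomllib` (older interpreters need a third-party TOML parser):

```python
import tomllib  # stdlib since Python 3.11

with open("samples/story.toml", "rb") as f:  # tomllib requires binary mode
    config = tomllib.load(f)

print(config["model"])                 # "F5-TTS"
print(list(config.get("voices", {})))  # ['town', 'country'] -> usable as [town] / [country] tags
```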
samples/story.txt ADDED
@@ -0,0 +1 @@
+A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] “My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land.” [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] “Goodbye,” [main] said he, [country] “I’m off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace.”
samples/town.flac ADDED
Binary file (229 kB).
 
scripts/eval_infer_batch.py CHANGED
@@ -127,7 +127,7 @@ local = False
 if local:
     vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
     vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
-    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
+    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
     vocos.load_state_dict(state_dict)
     vocos.eval()
 else:
scripts/prepare_csv_wavs.py ADDED
@@ -0,0 +1,132 @@
+import sys, os
+sys.path.append(os.getcwd())
+
+from pathlib import Path
+import json
+import shutil
+import argparse
+
+import csv
+import torchaudio
+from tqdm import tqdm
+from datasets.arrow_writer import ArrowWriter
+
+from model.utils import (
+    convert_char_to_pinyin,
+)
+
+PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
+
+def is_csv_wavs_format(input_dataset_dir):
+    fpath = Path(input_dataset_dir)
+    metadata = fpath / "metadata.csv"
+    wavs = fpath / 'wavs'
+    return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
+
+
+def prepare_csv_wavs_dir(input_dir):
+    assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
+    input_dir = Path(input_dir)
+    metadata_path = input_dir / "metadata.csv"
+    audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
+
+    sub_result, durations = [], []
+    vocab_set = set()
+    polyphone = True
+    for audio_path, text in audio_path_text_pairs:
+        if not Path(audio_path).exists():
+            print(f"audio {audio_path} not found, skipping")
+            continue
+        audio_duration = get_audio_duration(audio_path)
+        # assume tokenizer = "pinyin"  ("pinyin" | "char")
+        text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
+        sub_result.append({"audio_path": audio_path, "text": text, "duration": audio_duration})
+        durations.append(audio_duration)
+        vocab_set.update(list(text))
+
+    return sub_result, durations, vocab_set
+
+def get_audio_duration(audio_path):
+    audio, sample_rate = torchaudio.load(audio_path)
+    num_channels = audio.shape[0]
+    return audio.shape[1] / (sample_rate * num_channels)
+
+def read_audio_text_pairs(csv_file_path):
+    audio_text_pairs = []
+
+    parent = Path(csv_file_path).parent
+    with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csvfile:
+        reader = csv.reader(csvfile, delimiter='|')
+        next(reader)  # Skip the header row
+        for row in reader:
+            if len(row) >= 2:
+                audio_file = row[0].strip()  # First column: audio file path
+                text = row[1].strip()  # Second column: text
+                audio_file_path = parent / audio_file
+                audio_text_pairs.append((audio_file_path.as_posix(), text))
+
+    return audio_text_pairs
+
+
+def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
+    out_dir = Path(out_dir)
+    # save preprocessed dataset to disk
+    out_dir.mkdir(exist_ok=True, parents=True)
+    print(f"\nSaving to {out_dir} ...")
+
+    # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})  # oom
+    # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
+    raw_arrow_path = out_dir / "raw.arrow"
+    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
+        for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
+            writer.write(line)
+
+    # dup a json separately saving duration in case for DynamicBatchSampler ease
+    dur_json_path = out_dir / "duration.json"
+    with open(dur_json_path.as_posix(), 'w', encoding='utf-8') as f:
+        json.dump({"duration": duration_list}, f, ensure_ascii=False)
+
+    # vocab map, i.e. tokenizer
+    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
+    # if tokenizer == "pinyin":
+    #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
+    voca_out_path = out_dir / "vocab.txt"
+    with open(voca_out_path.as_posix(), "w") as f:
+        for vocab in sorted(text_vocab_set):
+            f.write(vocab + "\n")
+
+    if is_finetune:
+        file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
+        shutil.copy2(file_vocab_finetune, voca_out_path)
+    else:
+        with open(voca_out_path, "w") as f:
+            for vocab in sorted(text_vocab_set):
+                f.write(vocab + "\n")
+
+    dataset_name = out_dir.stem
+    print(f"\nFor {dataset_name}, sample count: {len(result)}")
+    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
+    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
+
+
+def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
+    if is_finetune:
+        assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
+    sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir)
+    save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
+
+
+def cli():
+    # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
+    # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
+    parser = argparse.ArgumentParser(description="Prepare and save dataset.")
+    parser.add_argument('inp_dir', type=str, help="Input directory containing the data.")
+    parser.add_argument('out_dir', type=str, help="Output directory to save the prepared data.")
+    parser.add_argument('--pretrain', action='store_true', help="Enable for new pretrain, otherwise is a fine-tune")
+
+    args = parser.parse_args()
+
+    prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
+
+if __name__ == "__main__":
+    cli()
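Note: per `is_csv_wavs_format()` and `read_audio_text_pairs()`, the script expects a directory holding `metadata.csv` plus a `wavs/` folder, with the CSV pipe-delimited, carrying a header row (which is skipped), and listing audio paths relative to the CSV. A sketch that lays out such a dataset (directory and file names hypothetical):

```python
from pathlib import Path

root = Path("my_dataset")
(root / "wavs").mkdir(parents=True, exist_ok=True)  # wavs/ must exist alongside metadata.csv

# Pipe-delimited rows; the first line is a header that read_audio_text_pairs() skips.
rows = [
    "audio_file|text",
    "wavs/utt_0001.wav|Hello there, this is a sample sentence.",
    "wavs/utt_0002.wav|And a second line of training text.",
]
(root / "metadata.csv").write_text("\n".join(rows) + "\n", encoding="utf-8")

# Finetune prep:  python scripts/prepare_csv_wavs.py my_dataset data/my_dataset_pinyin
# Pretrain prep:  python scripts/prepare_csv_wavs.py my_dataset data/my_dataset_pinyin --pretrain
```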
speech_edit.py CHANGED
@@ -85,8 +85,9 @@ local = False
 if local:
     vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
     vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
-    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
+    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
     vocos.load_state_dict(state_dict)
+
     vocos.eval()
 else:
     vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")