lpscr committed
Commit 4172c49 · unverified · 1 Parent(s): 22c553f

gradio_finetune
Files changed (2)
  1. finetune-cli.py +93 -0
  2. finetune_gradio.py +560 -0
finetune-cli.py ADDED
@@ -0,0 +1,93 @@
+ import argparse
+ from model import CFM, UNetT, DiT, MMDiT, Trainer
+ from model.utils import get_tokenizer
+ from model.dataset import load_dataset
+
+ # -------------------------- Dataset Settings --------------------------- #
+ target_sample_rate = 24000
+ n_mel_channels = 100
+ hop_length = 256
+
+ tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
+ tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+
+ # -------------------------- Argument Parsing --------------------------- #
+ def parse_args():
+     parser = argparse.ArgumentParser(description='Train CFM Model')
+
+     parser.add_argument('--exp_name', type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help='Experiment name')
+     parser.add_argument('--dataset_name', type=str, default="Emilia_ZH_EN", help='Name of the dataset to use')
+     parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate for training')
+     parser.add_argument('--batch_size_per_gpu', type=int, default=400, help='Batch size per GPU')
+     parser.add_argument('--batch_size_type', type=str, default="frame", choices=["frame", "sample"], help='Batch size type')
+     parser.add_argument('--max_samples', type=int, default=64, help='Max sequences per batch')
+     parser.add_argument('--grad_accumulation_steps', type=int, default=1, help='Gradient accumulation steps')
+     parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
+     parser.add_argument('--epochs', type=int, default=11, help='Number of training epochs')
+     parser.add_argument('--num_warmup_updates', type=int, default=200, help='Warmup steps')
+     parser.add_argument('--save_per_updates', type=int, default=800, help='Save checkpoint every X steps')
+     parser.add_argument('--last_per_steps', type=int, default=400, help='Save last checkpoint every X steps')
+
+     return parser.parse_args()
+
+ # -------------------------- Training Settings -------------------------- #
+
+ def main():
+     args = parse_args()
+
+     # Model parameters based on experiment name
+     if args.exp_name == "F5TTS_Base":
+         wandb_resume_id = None
+         model_cls = DiT
+         model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+     elif args.exp_name == "E2TTS_Base":
+         wandb_resume_id = None
+         model_cls = UNetT
+         model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+
+     # Use the dataset_name provided on the command line
+     tokenizer_path = args.dataset_name if tokenizer != "custom" else tokenizer_path
+     vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
+
+     mel_spec_kwargs = dict(
+         target_sample_rate=target_sample_rate,
+         n_mel_channels=n_mel_channels,
+         hop_length=hop_length,
+     )
+
+     e2tts = CFM(
+         transformer=model_cls(
+             **model_cfg,
+             text_num_embeds=vocab_size,
+             mel_dim=n_mel_channels
+         ),
+         mel_spec_kwargs=mel_spec_kwargs,
+         vocab_char_map=vocab_char_map,
+     )
+
+     trainer = Trainer(
+         e2tts,
+         args.epochs,
+         args.learning_rate,
+         num_warmup_updates=args.num_warmup_updates,
+         save_per_updates=args.save_per_updates,
+         checkpoint_path=f'ckpts/{args.exp_name}',
+         batch_size=args.batch_size_per_gpu,
+         batch_size_type=args.batch_size_type,
+         max_samples=args.max_samples,
+         grad_accumulation_steps=args.grad_accumulation_steps,
+         max_grad_norm=args.max_grad_norm,
+         wandb_project="CFM-TTS",
+         wandb_run_name=args.exp_name,
+         wandb_resume_id=wandb_resume_id,
+         last_per_steps=args.last_per_steps,
+     )
+
+     train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
+     trainer.train(train_dataset,
+                   resumable_with_seed=666  # seed for shuffling dataset
+                   )
+
+
+ if __name__ == '__main__':
+     main()
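
For reference, here is a minimal sketch of how this CLI might be invoked once a dataset has been prepared under `data/`. The flag values mirror the argparse defaults above; the project name `my_speak` is a hypothetical example, and this is essentially the command that `finetune_gradio.py` (added below) assembles for its "train Data" tab.

```python
# Hypothetical invocation of finetune-cli.py; values mirror the argparse defaults above.
import subprocess
import sys

cmd = (
    f"{sys.executable} finetune-cli.py "
    "--exp_name F5TTS_Base "
    "--dataset_name my_speak "  # the Gradio app passes the project name without the _pinyin suffix
    "--learning_rate 1e-4 "
    "--batch_size_per_gpu 400 --batch_size_type frame --max_samples 64 "
    "--grad_accumulation_steps 1 --max_grad_norm 1.0 "
    "--epochs 11 --num_warmup_updates 200 "
    "--save_per_updates 800 --last_per_steps 400"
)
subprocess.run(cmd, shell=True, check=True)
```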
finetune_gradio.py ADDED
@@ -0,0 +1,560 @@
+ import os, sys
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))  # run from the F5-TTS repository root, where this script lives
+
+ from pydub import silence, AudioSegment
+ from transformers import pipeline
+ import gradio as gr
+ import torch
+ import click
+ import tempfile
+ import torchaudio
+ from glob import glob
+ import librosa
+ import numpy as np
+ from scipy.io import wavfile
+ from tqdm import tqdm
+ import shutil
+ import time
+
+ import json
+ from datasets import Dataset
+ from model.utils import convert_char_to_pinyin
+ import signal
+ import psutil
+ import platform
+ import subprocess
+ from subprocess import Popen
+
+ training_process = None
+ system = platform.system()
+ python_executable = sys.executable or "python"
+
+ path_data = "data"
+
+ device = (
+     "cuda"
+     if torch.cuda.is_available()
+     else "mps" if torch.backends.mps.is_available() else "cpu"
+ )
+
+ pipe = None
+
+ # Load metadata
+ def get_audio_duration(audio_path):
+     """Calculate the duration of an audio file in seconds."""
+     audio, sample_rate = torchaudio.load(audio_path)
+     # torchaudio returns a (channels, frames) tensor, so frames / sample_rate is the duration
+     return audio.shape[1] / sample_rate
+
+ def clear_text(text):
+     """Clean and prepare text by lowering the case and stripping whitespace."""
+     return text.lower().strip()
+
+ def get_rms(y, frame_length=2048, hop_length=512, pad_mode="constant"):  # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
+     padding = (int(frame_length // 2), int(frame_length // 2))
+     y = np.pad(y, padding, mode=pad_mode)
+
+     axis = -1
+     # put our new within-frame axis at the end for now
+     out_strides = y.strides + tuple([y.strides[axis]])
+     # Reduce the shape on the framing axis
+     x_shape_trimmed = list(y.shape)
+     x_shape_trimmed[axis] -= frame_length - 1
+     out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
+     xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
+     if axis < 0:
+         target_axis = axis - 1
+     else:
+         target_axis = axis + 1
+     xw = np.moveaxis(xw, -1, target_axis)
+     # Downsample along the target axis
+     slices = [slice(None)] * xw.ndim
+     slices[axis] = slice(0, None, hop_length)
+     x = xw[tuple(slices)]
+
+     # Calculate power
+     power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
+
+     return np.sqrt(power)
+
+ class Slicer:  # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
+     def __init__(
+         self,
+         sr: int,
+         threshold: float = -40.0,
+         min_length: int = 5000,
+         min_interval: int = 300,
+         hop_size: int = 20,
+         max_sil_kept: int = 5000,
+     ):
+         if not min_length >= min_interval >= hop_size:
+             raise ValueError(
+                 "The following condition must be satisfied: min_length >= min_interval >= hop_size"
+             )
+         if not max_sil_kept >= hop_size:
+             raise ValueError(
+                 "The following condition must be satisfied: max_sil_kept >= hop_size"
+             )
+         min_interval = sr * min_interval / 1000
+         self.threshold = 10 ** (threshold / 20.0)
+         self.hop_size = round(sr * hop_size / 1000)
+         self.win_size = min(round(min_interval), 4 * self.hop_size)
+         self.min_length = round(sr * min_length / 1000 / self.hop_size)
+         self.min_interval = round(min_interval / self.hop_size)
+         self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
+
+     def _apply_slice(self, waveform, begin, end):
+         if len(waveform.shape) > 1:
+             return waveform[
+                 :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)
+             ]
+         else:
+             return waveform[
+                 begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)
+             ]
+
+     # @timeit
+     def slice(self, waveform):
+         if len(waveform.shape) > 1:
+             samples = waveform.mean(axis=0)
+         else:
+             samples = waveform
+         if samples.shape[0] <= self.min_length:
+             return [waveform]
+         rms_list = get_rms(
+             y=samples, frame_length=self.win_size, hop_length=self.hop_size
+         ).squeeze(0)
+         sil_tags = []
+         silence_start = None
+         clip_start = 0
+         for i, rms in enumerate(rms_list):
+             # Keep looping while frame is silent.
+             if rms < self.threshold:
+                 # Record start of silent frames.
+                 if silence_start is None:
+                     silence_start = i
+                 continue
+             # Keep looping while frame is not silent and silence start has not been recorded.
+             if silence_start is None:
+                 continue
+             # Clear recorded silence start if interval is not enough or clip is too short
+             is_leading_silence = silence_start == 0 and i > self.max_sil_kept
+             need_slice_middle = (
+                 i - silence_start >= self.min_interval
+                 and i - clip_start >= self.min_length
+             )
+             if not is_leading_silence and not need_slice_middle:
+                 silence_start = None
+                 continue
+             # Need slicing. Record the range of silent frames to be removed.
+             if i - silence_start <= self.max_sil_kept:
+                 pos = rms_list[silence_start : i + 1].argmin() + silence_start
+                 if silence_start == 0:
+                     sil_tags.append((0, pos))
+                 else:
+                     sil_tags.append((pos, pos))
+                 clip_start = pos
+             elif i - silence_start <= self.max_sil_kept * 2:
+                 pos = rms_list[
+                     i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
+                 ].argmin()
+                 pos += i - self.max_sil_kept
+                 pos_l = (
+                     rms_list[
+                         silence_start : silence_start + self.max_sil_kept + 1
+                     ].argmin()
+                     + silence_start
+                 )
+                 pos_r = (
+                     rms_list[i - self.max_sil_kept : i + 1].argmin()
+                     + i
+                     - self.max_sil_kept
+                 )
+                 if silence_start == 0:
+                     sil_tags.append((0, pos_r))
+                     clip_start = pos_r
+                 else:
+                     sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
+                     clip_start = max(pos_r, pos)
+             else:
+                 pos_l = (
+                     rms_list[
+                         silence_start : silence_start + self.max_sil_kept + 1
+                     ].argmin()
+                     + silence_start
+                 )
+                 pos_r = (
+                     rms_list[i - self.max_sil_kept : i + 1].argmin()
+                     + i
+                     - self.max_sil_kept
+                 )
+                 if silence_start == 0:
+                     sil_tags.append((0, pos_r))
+                 else:
+                     sil_tags.append((pos_l, pos_r))
+                 clip_start = pos_r
+             silence_start = None
+         # Deal with trailing silence.
+         total_frames = rms_list.shape[0]
+         if (
+             silence_start is not None
+             and total_frames - silence_start >= self.min_interval
+         ):
+             silence_end = min(total_frames, silence_start + self.max_sil_kept)
+             pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
+             sil_tags.append((pos, total_frames + 1))
+         # Apply and return slices.
+         # Each chunk is [audio, start sample, end sample].
+         if len(sil_tags) == 0:
+             return [[waveform, 0, int(total_frames * self.hop_size)]]
+         else:
+             chunks = []
+             if sil_tags[0][0] > 0:
+                 chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)])
+             for i in range(len(sil_tags) - 1):
+                 chunks.append(
+                     [self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]), int(sil_tags[i][1] * self.hop_size), int(sil_tags[i + 1][0] * self.hop_size)]
+                 )
+             if sil_tags[-1][1] < total_frames:
+                 chunks.append(
+                     [self._apply_slice(waveform, sil_tags[-1][1], total_frames), int(sil_tags[-1][1] * self.hop_size), int(total_frames * self.hop_size)]
+                 )
+             return chunks
+
+ # terminal
+ def terminate_process_tree(pid, including_parent=True):
+     try:
+         parent = psutil.Process(pid)
+     except psutil.NoSuchProcess:
+         # Process already terminated
+         return
+
+     children = parent.children(recursive=True)
+     for child in children:
+         try:
+             os.kill(child.pid, signal.SIGTERM)  # or signal.SIGKILL
+         except OSError:
+             pass
+     if including_parent:
+         try:
+             os.kill(parent.pid, signal.SIGTERM)  # or signal.SIGKILL
+         except OSError:
+             pass
+
+ def terminate_process(pid):
+     if system == "Windows":
+         cmd = f"taskkill /t /f /pid {pid}"
+         os.system(cmd)
+     else:
+         terminate_process_tree(pid)
+
+
+ def start_training(
+     dataset_name="",
+     exp_name="F5TTS_Base",        # Default experiment name
+     learning_rate=1e-4,           # Default learning rate
+     batch_size_per_gpu=400,       # Default batch size per GPU
+     batch_size_type="frame",      # Default batch size type
+     max_samples=64,               # Default max sequences per batch
+     grad_accumulation_steps=1,    # Default gradient accumulation steps
+     max_grad_norm=1.0,            # Default max gradient norm
+     epochs=11,                    # Default number of training epochs
+     num_warmup_updates=200,       # Default number of warmup updates
+     save_per_updates=400,         # Default save interval for checkpoints
+     last_per_steps=800,           # Default save interval for the last checkpoint
+ ):
+
+     global training_process
+
+     # Check if a training process is already running
+     if training_process is not None:
+         yield "Training is already running!", gr.update(interactive=False), gr.update(interactive=True)
+         return
+
+     yield "Starting training...", gr.update(interactive=False), gr.update(interactive=False)
+
+     # Command to run the training script with the specified arguments
+     cmd = f"{python_executable} finetune-cli.py --exp_name {exp_name} " \
+           f"--learning_rate {learning_rate} " \
+           f"--batch_size_per_gpu {batch_size_per_gpu} " \
+           f"--batch_size_type {batch_size_type} " \
+           f"--max_samples {max_samples} " \
+           f"--grad_accumulation_steps {grad_accumulation_steps} " \
+           f"--max_grad_norm {max_grad_norm} " \
+           f"--epochs {epochs} " \
+           f"--num_warmup_updates {num_warmup_updates} " \
+           f"--save_per_updates {save_per_updates} " \
+           f"--last_per_steps {last_per_steps} " \
+           f"--dataset_name {dataset_name}"
+
+     try:
+         # Start the training process
+         training_process = subprocess.Popen(cmd, shell=True)
+
+         time.sleep(5)
+         yield "Training started, check the terminal for wandb output", gr.update(interactive=False), gr.update(interactive=True)
+
+         # Wait for the training process to finish
+         training_process.wait()
+         time.sleep(1)
+
+         if training_process is None:
+             text_info = "Training stopped"
+         else:
+             text_info = "Training complete!"
+
+     except Exception as e:  # Catch all exceptions
+         # Ensure that we reset the training process variable in case of an error
+         text_info = f"An error occurred: {str(e)}"
+
+     training_process = None
+
+     yield text_info, gr.update(interactive=True), gr.update(interactive=False)
+
+ def stop_training():
+     global training_process
+     if training_process is None:
+         return "Training is not running!", gr.update(interactive=True), gr.update(interactive=False)
+     terminate_process_tree(training_process.pid)
+     training_process = None
+     return "Training stopped", gr.update(interactive=True), gr.update(interactive=False)
+
+ def create_data_project(name):
+     name += "_pinyin"
+     os.makedirs(os.path.join(path_data, name), exist_ok=True)
+     os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
+
+ def transcribe(file_audio, language="english"):
+     global pipe
+
+     if pipe is None:
+         pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3-turbo", torch_dtype=torch.float16, device=device)
+
+     text_transcribe = pipe(
+         file_audio,
+         chunk_length_s=30,
+         batch_size=128,
+         generate_kwargs={"task": "transcribe", "language": language},
+         return_timestamps=False,
+     )["text"].strip()
+     return text_transcribe
+
+ def transcribe_all(name_project, audio_file, language, user=False):
+     name_project += "_pinyin"
+     path_project = os.path.join(path_data, name_project)
+     path_dataset = os.path.join(path_project, "dataset")
+     path_project_wavs = os.path.join(path_project, "wavs")
+     file_metadata = os.path.join(path_project, "metadata.csv")
+
+     if os.path.isdir(path_project_wavs):
+         shutil.rmtree(path_project_wavs)
+
+     if os.path.isfile(file_metadata):
+         os.remove(file_metadata)
+
+     os.makedirs(path_project_wavs, exist_ok=True)
+
+     if user:
+         file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))]
+     else:
+         file_audios = [audio_file]
+
+     print([file_audios])
+
+     alpha = 0.5
+     _max = 1.0
+     slicer = Slicer(24000)
+
+     num = 0
+     data = ""
+     for file_audio in tqdm(file_audios, desc="transcribe files", total=len(file_audios)):
+
+         audio, _ = librosa.load(file_audio, sr=24000, mono=True)
+
+         list_slicer = slicer.slice(audio)
+         for chunk, start, end in tqdm(list_slicer, total=len(list_slicer), desc="slicer files"):
+             name_segment = f"segment_{num}"
+             file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")
+
+             # Normalize the chunk before writing it as 16-bit PCM
+             tmp_max = np.abs(chunk).max()
+             if tmp_max > 1:
+                 chunk /= tmp_max
+             chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
+             wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16))
+
+             text = transcribe(file_segment, language)
+             text = text.lower().strip().replace('"', "")
+
+             data += f"{name_segment}|{text}\n"
+
+             num += 1
+
+     with open(file_metadata, "w", encoding="utf-8") as f:
+         f.write(data)
+
+     return f"Transcription complete: {num} samples written to {path_project_wavs}"
+
+ def create_metadata(name_project):
+     name_project += "_pinyin"
+     path_project = os.path.join(path_data, name_project)
+     path_project_wavs = os.path.join(path_project, "wavs")
+     path_raw = os.path.join(path_project, "raw")
+     file_metadata = os.path.join(path_project, "metadata.csv")
+     file_duration = os.path.join(path_project, "duration.json")
+     file_vocab = os.path.join(path_project, "vocab.txt")
+
+     with open(file_metadata, "r", encoding="utf-8") as f:
+         data = f.read()
+
+     audio_path_list = []
+     text_list = []
+     duration_list = []
+
+     for line in data.split("\n"):
+         sp_line = line.split("|")
+         if len(sp_line) != 2:
+             continue
+         name_audio, text = sp_line[:2]
+         file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
+         duration = get_audio_duration(file_audio)
+         # keep only clips between 2 and 15 seconds with at least 4 characters of text
+         if duration < 2 or duration > 15:
+             continue
+         if len(text) < 4:
+             continue
+
+         text = clear_text(text)
+
+         audio_path_list.append(file_audio)
+         duration_list.append(duration)
+         text_list.append(text)
+
+     tokenizer = "pinyin"
+     polyphone = True
+     if tokenizer == "pinyin":
+         text_list = [convert_char_to_pinyin([text], polyphone=polyphone)[0] for text in text_list]
+
+     dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
+     dataset.save_to_disk(path_raw, max_shard_size="2GB")  # arrow format
+
+     with open(file_duration, 'w', encoding='utf-8') as f:
+         json.dump({"duration": duration_list}, f, ensure_ascii=False)
+
+     file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
+     shutil.copy2(file_vocab_finetune, file_vocab)
+
+     return f"Prepare complete: {len(text_list)} samples written to {path_raw}"
+
+ def check_user(value):
+     return gr.update(visible=not value), gr.update(visible=value)
+
+ with gr.Blocks() as app:
+
+     with gr.Row():
+         project_name = gr.Textbox(label="project name", value="my_speak")
+         bt_create = gr.Button("create new project")
+
+     bt_create.click(fn=create_data_project, inputs=[project_name])
+
+     with gr.Tabs():
+
+         with gr.TabItem("transcribe Data"):
+
+             ch_manual = gr.Checkbox(label="user", value=False)
+
+             mark_info_transcribe = gr.Markdown(
+                 """```plaintext
+     Place your audio files in the "{your_project_name}/dataset" directory.
+
+     my_speak/
+     │
+     └── dataset/
+         ├── audio1.wav
+         └── audio2.wav
+         ...
+     ```""", visible=False)
+
+             audio_speaker = gr.Audio(label="voice", type="filepath")
+             txt_lang = gr.Text(label="Language", value="english")
+             bt_transcribe = gr.Button("transcribe")
+             txt_info_transcribe = gr.Text(label="info", value="")
+             bt_transcribe.click(fn=transcribe_all, inputs=[project_name, audio_speaker, txt_lang, ch_manual], outputs=[txt_info_transcribe])
+             ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
+
+         with gr.TabItem("prepare Data"):
+             gr.Markdown(
+                 """```plaintext
+     Place your "wavs" folder and "metadata.csv" file in the "{your_project_name}" directory.
+
+     my_speak/
+     │
+     ├── wavs/
+     │   ├── audio1.wav
+     │   └── audio2.wav
+     │   ...
+     │
+     └── metadata.csv
+
+     metadata.csv file format:
+
+     audio1|text1
+     audio2|text2
+     ...
+
+     ```""")
+
+             bt_prepare = gr.Button("prepare")
+             txt_info_prepare = gr.Text(label="info", value="")
+             bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
+
+         with gr.TabItem("train Data"):
+
+             exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
+             learning_rate = gr.Number(label="Learning Rate", value=1e-4, step=1e-5)
+
+             with gr.Row():
+                 batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=408)
+                 max_samples = gr.Number(label="Max Samples", value=64)
+                 batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
+
+             with gr.Row():
+                 grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
+                 max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
+
+             with gr.Row():
+                 epochs = gr.Number(label="Epochs", value=11)
+                 num_warmup_updates = gr.Number(label="Warmup Updates", value=200)
+
+             with gr.Row():
+                 save_per_updates = gr.Number(label="Save per Updates", value=400)
+                 last_per_steps = gr.Number(label="Last per Steps", value=800)
+
+             with gr.Row():
+                 start_button = gr.Button("Start Training")
+                 stop_button = gr.Button("Stop Training", interactive=False)
+
+             txt_info_train = gr.Text(label="info", value="")
+             start_button.click(fn=start_training, inputs=[project_name, exp_name, learning_rate, batch_size_per_gpu, batch_size_type, max_samples, grad_accumulation_steps, max_grad_norm, epochs, num_warmup_updates, save_per_updates, last_per_steps], outputs=[txt_info_train, start_button, stop_button])
+             stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
+
+
+ @click.command()
+ @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
+ @click.option("--host", "-H", default=None, help="Host to run the app on")
+ @click.option(
+     "--share",
+     "-s",
+     default=False,
+     is_flag=True,
+     help="Share the app via Gradio share link",
+ )
+ @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
+ def main(port, host, share, api):
+     global app
+     print("Starting app...")
+     app.queue(api_open=api).launch(
+         server_name=host, server_port=port, share=share, show_api=api
+     )
+
+ if __name__ == "__main__":
+     name = "my_speak"
+
+     # create_data_project(name)
+     # transcribe_all(name)
+     # create_metadata(name)
+
+     main()
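
The commented-out calls in the `__main__` block hint at running the same pipeline without the UI. Below is a minimal sketch of that headless flow, assuming the repository layout from this commit; `speaker.wav` is a hypothetical input recording, and importing the module only builds the Gradio Blocks, it does not launch the server.

```python
# Hypothetical headless use of the helpers defined in finetune_gradio.py above.
from finetune_gradio import create_data_project, transcribe_all, create_metadata

name = "my_speak"
create_data_project(name)                              # creates data/my_speak_pinyin/ and its dataset/ folder
print(transcribe_all(name, "speaker.wav", "english"))  # slice the recording and transcribe it with Whisper
print(create_metadata(name))                           # write the arrow dataset, duration.json and vocab.txt
# training can then be started from the "train Data" tab or directly with finetune-cli.py
```

Note that `create_metadata` copies the base vocabulary from `data/Emilia_ZH_EN_pinyin/vocab.txt`, so that file needs to be present before the prepare step.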