Commit 2ccfbdc · committed by SWivid
Parents: 5f6dcc7 d5c307b

Merge branch 'main' of github.com:lpscr/F5-TTS into lpscr-main
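
In brief: this merge adds a selectable training logger ("wandb", "tensorboard", or "none") plus a log_dir option to Trainer, optional export of original/generated audio and mel-spectrogram samples during fine-tuning, matching --logger and --export_samples flags in finetune_cli.py, and a sample-browser panel (audio players and mel images) in finetune_gradio.py.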

src/f5_tts/model/trainer.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import os
 import gc
 from tqdm import tqdm
-import wandb
+

 import torch
 from torch.optim import AdamW
@@ -19,7 +19,6 @@ from f5_tts.model import CFM
 from f5_tts.model.utils import exists, default
 from f5_tts.model.dataset import DynamicBatchSampler, collate_fn

-
 # trainer


@@ -39,6 +38,8 @@ class Trainer:
         max_grad_norm=1.0,
         noise_scheduler: str | None = None,
         duration_predictor: torch.nn.Module | None = None,
+        logger: str = "wandb",  # Add logger parameter wandb,tensorboard , none
+        log_dir: str = "logs",  # Add log directory parameter
         wandb_project="test_e2-tts",
         wandb_run_name="test_run",
         wandb_resume_id: str = None,
@@ -46,24 +47,29 @@ class Trainer:
         accelerate_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
         bnb_optimizer: bool = False,
+        export_samples=False,
     ):
-        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+        # export audio and mel
+        self.export_samples = export_samples
+        if export_samples:
+            self.path_ckpts_project = checkpoint_path

-        logger = "wandb" if wandb.api.api_key else None
-        print(f"Using logger: {logger}")
+        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

-        self.accelerator = Accelerator(
-            log_with=logger,
-            kwargs_handlers=[ddp_kwargs],
-            gradient_accumulation_steps=grad_accumulation_steps,
-            **accelerate_kwargs,
-        )
-
-        if logger == "wandb":
+        self.logger = logger
+        if self.logger == "wandb":
+            self.accelerator = Accelerator(
+                log_with="wandb",
+                kwargs_handlers=[ddp_kwargs],
+                gradient_accumulation_steps=grad_accumulation_steps,
+                **accelerate_kwargs,
+            )
+
             if exists(wandb_resume_id):
                 init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
             else:
                 init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
+
             self.accelerator.init_trackers(
                 project_name=wandb_project,
                 init_kwargs=init_kwargs,
@@ -80,12 +86,29 @@ class Trainer:
                     "noise_scheduler": noise_scheduler,
                 },
             )
+        elif self.logger == "tensorboard":
+            from torch.utils.tensorboard import SummaryWriter
+
+            self.accelerator = Accelerator(
+                kwargs_handlers=[ddp_kwargs],
+                gradient_accumulation_steps=grad_accumulation_steps,
+                **accelerate_kwargs,
+            )
+            if self.is_main:
+                path_log_dir = os.path.join(log_dir, wandb_project)
+                os.makedirs(path_log_dir, exist_ok=True)
+                existing_folders = [folder for folder in os.listdir(path_log_dir) if folder.startswith("exp")]
+                next_number = len(existing_folders) + 2
+                folder_name = f"exp{next_number}"
+                folder_path = os.path.join(path_log_dir, folder_name)
+                os.makedirs(folder_path, exist_ok=True)
+
+                self.writer = SummaryWriter(log_dir=folder_path)

         self.model = model

         if self.is_main:
             self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
-
             self.ema_model.to(self.accelerator.device)

         self.epochs = epochs
@@ -175,7 +198,32 @@ class Trainer:
             gc.collect()
         return step

+    def log(self, metrics, step):
+        """Unified logging method for both WandB and TensorBoard"""
+        if self.logger == "none":
+            return
+        if self.logger == "wandb":
+            self.accelerator.log(metrics, step=step)
+        elif self.is_main:
+            for key, value in metrics.items():
+                self.writer.add_scalar(key, value, step)
+
     def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
+        # import only when export_sample True
+        if self.export_samples:
+            from f5_tts.infer.utils_infer import (
+                target_sample_rate,
+                hop_length,
+                nfe_step,
+                cfg_strength,
+                sway_sampling_coef,
+                vocos,
+            )
+            from f5_tts.model.utils import get_sample
+
+            self.file_path_samples = os.path.join(self.path_ckpts_project, "samples")
+            os.makedirs(self.file_path_samples, exist_ok=True)
+
         if exists(resumable_with_seed):
             generator = torch.Generator()
             generator.manual_seed(resumable_with_seed)
@@ -259,6 +307,7 @@ class Trainer:
             for batch in progress_bar:
                 with self.accelerator.accumulate(self.model):
                     text_inputs = batch["text"]
+
                     mel_spec = batch["mel"].permute(0, 2, 1)
                     mel_lengths = batch["mel_lengths"]

@@ -270,6 +319,40 @@ class Trainer:
                     loss, cond, pred = self.model(
                         mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
                     )
+
+                    # save 4 audio per save step
+                    if (
+                        self.accelerator.is_local_main_process
+                        and self.export_samples
+                        and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0
+                    ):
+                        try:
+                            wave_org, wave_gen, mel_org, mel_gen = get_sample(
+                                vocos,
+                                self.model,
+                                self.file_path_samples,
+                                global_step,
+                                batch["mel"][0],
+                                text_inputs,
+                                target_sample_rate,
+                                hop_length,
+                                nfe_step,
+                                cfg_strength,
+                                sway_sampling_coef,
+                            )
+
+                            if self.logger == "tensorboard":
+                                self.writer.add_audio(
+                                    "Audio/original", wave_org, global_step, sample_rate=target_sample_rate
+                                )
+                                self.writer.add_audio(
+                                    "Audio/generate", wave_gen, global_step, sample_rate=target_sample_rate
+                                )
+                                self.writer.add_image("Mel/original", mel_org, global_step, dataformats="CHW")
+                                self.writer.add_image("Mel/generate", mel_gen, global_step, dataformats="CHW")
+                        except Exception as e:
+                            print("An error occurred:", e)
+
                     self.accelerator.backward(loss)

                     if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
@@ -285,7 +368,7 @@ class Trainer:
                 global_step += 1

                 if self.accelerator.is_local_main_process:
-                    self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
+                    self.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)

                 progress_bar.set_postfix(step=str(global_step), loss=loss.item())
 
src/f5_tts/model/utils.py CHANGED
@@ -11,6 +11,10 @@ from torch.nn.utils.rnn import pad_sequence
 import jieba
 from pypinyin import lazy_pinyin, Style

+import numpy as np
+import matplotlib.pyplot as plt
+import soundfile as sf
+import torchaudio

 # seed everything

@@ -183,3 +187,73 @@ def repetition_found(text, length=2, tolerance=10):
         if count > tolerance:
             return True
     return False
+
+
+def normalize_and_colorize_spectrogram(mel_org):
+    mel_min, mel_max = mel_org.min(), mel_org.max()
+    mel_norm = (mel_org - mel_min) / (mel_max - mel_min + 1e-8)
+    mel_colored = plt.get_cmap("viridis")(mel_norm.detach().cpu().numpy())[:, :, :3]
+    mel_colored = np.transpose(mel_colored, (2, 0, 1))
+    return mel_colored
+
+
+def export_audio(file_out, wav, target_sample_rate):
+    sf.write(file_out, wav, samplerate=target_sample_rate)
+
+
+def export_mel(mel_colored_hwc, file_out):
+    plt.imsave(file_out, mel_colored_hwc)
+
+
+def gen_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef):
+    audio, sr = torchaudio.load(file_wav_org)
+    audio = audio.to("cuda")
+    ref_audio_len = audio.shape[-1] // hop_length
+    text = [text_inputs[0] + [" . "] + text_inputs[0]]
+    duration = int((audio.shape[1] / 256) * 2.0)
+    with torch.inference_mode():
+        generated_gen, _ = model.sample(
+            cond=audio,
+            text=text,
+            duration=duration,
+            steps=nfe_step,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+        )
+        generated_gen = generated_gen.to(torch.float32)
+        generated_gen = generated_gen[:, ref_audio_len:, :]
+        generated_mel_spec_gen = generated_gen.permute(0, 2, 1)
+        generated_wave_gen = vocos.decode(generated_mel_spec_gen.cpu())
+        generated_wave_gen = generated_wave_gen.squeeze().cpu().numpy()
+    return generated_wave_gen, generated_mel_spec_gen
+
+
+def get_sample(
+    vocos,
+    model,
+    file_path_samples,
+    global_step,
+    mel_org,
+    text_inputs,
+    target_sample_rate,
+    hop_length,
+    nfe_step,
+    cfg_strength,
+    sway_sampling_coef,
+):
+    generated_wave_org = vocos.decode(mel_org.unsqueeze(0).cpu())
+    generated_wave_org = generated_wave_org.squeeze().cpu().numpy()
+    file_wav_org = os.path.join(file_path_samples, f"step_{global_step}_org.wav")
+    export_audio(file_wav_org, generated_wave_org, target_sample_rate)
+    generated_wave_gen, generated_mel_spec_gen = gen_sample(
+        model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef
+    )
+    file_wav_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.wav")
+    export_audio(file_wav_gen, generated_wave_gen, target_sample_rate)
+    mel_org = normalize_and_colorize_spectrogram(mel_org)
+    mel_gen = normalize_and_colorize_spectrogram(generated_mel_spec_gen[0])
+    file_gen_org = os.path.join(file_path_samples, f"step_{global_step}_org.png")
+    export_mel(np.transpose(mel_org, (1, 2, 0)), file_gen_org)
+    file_gen_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.png")
+    export_mel(np.transpose(mel_gen, (1, 2, 0)), file_gen_gen)
+    return generated_wave_org, generated_wave_gen, mel_org, mel_gen
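
A quick, hypothetical smoke test (random data, not from the repo) for the new colorization helper, confirming the output contract the trainer relies on when it calls writer.add_image(..., dataformats="CHW"): an RGB channel-first array with values in [0, 1].

    import torch
    from f5_tts.model.utils import normalize_and_colorize_spectrogram

    mel = torch.randn(100, 512)  # (n_mels, frames), arbitrary value range
    img = normalize_and_colorize_spectrogram(mel)
    assert img.shape == (3, 100, 512)  # viridis RGB, channels first
    assert 0.0 <= img.min() and img.max() <= 1.0
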
src/f5_tts/train/finetune_cli.py CHANGED
@@ -56,6 +56,14 @@ def parse_args():
         help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
     )

+    parser.add_argument(
+        "--export_samples",
+        type=bool,
+        default=False,
+        help="Export 4 audio and spect samples for the checkpoint audio, per step.",
+    )
+    parser.add_argument("--logger", type=str, default="wandb", choices=["none", "wandb", "tensorboard"], help="logger")
+
     return parser.parse_args()


@@ -64,6 +72,7 @@ def parse_args():

 def main():
     args = parse_args()
+
     checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))

     # Model parameters based on experiment name
@@ -136,6 +145,8 @@ def main():
         wandb_run_name=args.exp_name,
         wandb_resume_id=wandb_resume_id,
         last_per_steps=args.last_per_steps,
+        logger=args.logger,
+        export_samples=args.export_samples,
     )

     train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
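
One caveat on the new flag: argparse's type=bool does not parse boolean strings, it simply calls bool() on the raw value, and any non-empty string is truthy. finetune_gradio.py only ever passes --export_samples True, so the pipeline behaves as intended, but an explicit --export_samples False would also enable export:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--export_samples", type=bool, default=False)

    print(parser.parse_args([]).export_samples)                              # False (default)
    print(parser.parse_args(["--export_samples", "True"]).export_samples)    # True
    print(parser.parse_args(["--export_samples", "False"]).export_samples)   # True, since bool("False") is truthy
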
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -69,6 +69,7 @@ def save_settings(
     tokenizer_type,
     tokenizer_file,
     mixed_precision,
+    logger,
 ):
     path_project = os.path.join(path_project_ckpts, project_name)
     os.makedirs(path_project, exist_ok=True)
@@ -91,6 +92,7 @@ def save_settings(
         "tokenizer_type": tokenizer_type,
         "tokenizer_file": tokenizer_file,
         "mixed_precision": mixed_precision,
+        "logger": logger,
     }
     with open(file_setting, "w") as f:
         json.dump(settings, f, indent=4)
@@ -121,6 +123,7 @@ def load_settings(project_name):
             "tokenizer_type": "pinyin",
             "tokenizer_file": "",
             "mixed_precision": "none",
+            "logger": "wandb",
         }
         return (
             settings["exp_name"],
@@ -139,6 +142,7 @@ def load_settings(project_name):
             settings["tokenizer_type"],
             settings["tokenizer_file"],
             settings["mixed_precision"],
+            settings["logger"],
         )

     with open(file_setting, "r") as f:
@@ -160,6 +164,7 @@ def load_settings(project_name):
         settings["tokenizer_type"],
         settings["tokenizer_file"],
         settings["mixed_precision"],
+        settings["logger"],
     )


@@ -374,6 +379,7 @@ def start_training(
     tokenizer_file="",
     mixed_precision="fp16",
     stream=False,
+    logger="wandb",
 ):
     global training_process, tts_api, stop_signal

@@ -447,6 +453,8 @@ def start_training(

     cmd += f" --tokenizer {tokenizer_type} "

+    cmd += f" --export_samples True --logger {logger} "
+
     print(cmd)

     save_settings(
@@ -467,6 +475,7 @@ def start_training(
         tokenizer_type,
         tokenizer_file,
         mixed_precision,
+        logger,
     )

     try:
@@ -1223,6 +1232,27 @@ def get_checkpoints_project(project_name, is_gradio=True):
     return files_checkpoints, selelect_checkpoint


+def get_audio_project(project_name, is_gradio=True):
+    if project_name is None:
+        return [], ""
+    project_name = project_name.replace("_pinyin", "").replace("_char", "")
+
+    if os.path.isdir(path_project_ckpts):
+        files_audios = glob(os.path.join(path_project_ckpts, project_name, "samples", "*.wav"))
+        files_audios = sorted(files_audios, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]))
+
+        files_audios = [item.replace("_gen.wav", "") for item in files_audios if item.endswith("_gen.wav")]
+    else:
+        files_audios = []
+
+    selelect_checkpoint = None if not files_audios else files_audios[0]
+
+    if is_gradio:
+        return gr.update(choices=files_audios, value=selelect_checkpoint)
+
+    return files_audios, selelect_checkpoint
+
+
 def get_gpu_stats():
     gpu_stats = ""

@@ -1290,6 +1320,21 @@ def get_combined_stats():
     return combined_stats


+def get_audio_select(file_sample):
+    select_audio_org = file_sample
+    select_audio_gen = file_sample
+    select_image_org = file_sample
+    select_image_gen = file_sample
+
+    if file_sample is not None:
+        select_audio_org += "_org.wav"
+        select_audio_gen += "_gen.wav"
+        select_image_org += "_org.png"
+        select_image_gen += "_gen.png"
+
+    return select_audio_org, select_audio_gen, select_image_org, select_image_gen
+
+
 with gr.Blocks() as app:
     gr.Markdown(
         """
@@ -1470,6 +1515,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle

         with gr.Row():
             mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none")
+            cd_logger = gr.Radio(label="logger", choices=["none", "wandb", "tensorboard"], value="wandb")
             start_button = gr.Button("Start Training")
             stop_button = gr.Button("Stop Training", interactive=False)

@@ -1491,6 +1537,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
             tokenizer_typev,
             tokenizer_filev,
             mixed_precisionv,
+            cd_loggerv,
         ) = load_settings(projects_selelect)
         exp_name.value = exp_namev
         learning_rate.value = learning_ratev
@@ -1508,9 +1555,51 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
         tokenizer_type.value = tokenizer_typev
         tokenizer_file.value = tokenizer_filev
         mixed_precision.value = mixed_precisionv
+        cd_logger.value = cd_loggerv

         ch_stream = gr.Checkbox(label="stream output experiment.", value=True)
         txt_info_train = gr.Text(label="info", value="")
+
+        list_audios, select_audio = get_audio_project(projects_selelect, False)
+
+        select_audio_org = select_audio
+        select_audio_gen = select_audio
+        select_image_org = select_audio
+        select_image_gen = select_audio
+
+        if select_audio is not None:
+            select_audio_org += "_org.wav"
+            select_audio_gen += "_gen.wav"
+            select_image_org += "_org.png"
+            select_image_gen += "_gen.png"
+
+        with gr.Row():
+            ch_list_audio = gr.Dropdown(
+                choices=list_audios,
+                value=select_audio,
+                label="audios",
+                allow_custom_value=True,
+                scale=6,
+                interactive=True,
+            )
+            bt_stream_audio = gr.Button("refresh", scale=1)
+            bt_stream_audio.click(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
+            cm_project.change(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
+
+        with gr.Row():
+            audio_org_stream = gr.Audio(label="original", type="filepath", value=select_audio_org)
+            mel_org_stream = gr.Image(label="original", type="filepath", value=select_image_org)
+
+        with gr.Row():
+            audio_gen_stream = gr.Audio(label="generate", type="filepath", value=select_audio_gen)
+            mel_gen_stream = gr.Image(label="generate", type="filepath", value=select_image_gen)
+
+        ch_list_audio.change(
+            fn=get_audio_select,
+            inputs=[ch_list_audio],
+            outputs=[audio_org_stream, audio_gen_stream, mel_org_stream, mel_gen_stream],
+        )
+
         start_button.click(
             fn=start_training,
             inputs=[
@@ -1532,6 +1621,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                 tokenizer_file,
                 mixed_precision,
                 ch_stream,
+                cd_logger,
             ],
             outputs=[txt_info_train, start_button, stop_button],
         )
@@ -1583,6 +1673,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
         tokenizer_type,
         tokenizer_file,
         mixed_precision,
+        cd_logger,
     ]

     return output_components
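
For reference, the sample browser above assumes the file layout written by get_sample(): four artifacts per logged step under <ckpts>/<project>/samples/, paired by a shared step_<N> prefix. A small sketch of the naming convention (the path is hypothetical):

    # get_audio_project() globs *_gen.wav and strips the suffix, so each
    # dropdown entry is a bare prefix; get_audio_select() re-derives all four:
    prefix = "ckpts/my_project/samples/step_1000"
    org_wav, gen_wav = prefix + "_org.wav", prefix + "_gen.wav"
    org_png, gen_png = prefix + "_org.png", prefix + "_gen.png"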