Sync from GitHub repo

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there.
Files changed:

- app.py +2 -1
- src/f5_tts/api.py +2 -1
- src/f5_tts/infer/infer_cli.py +4 -2
- src/f5_tts/infer/utils_infer.py +4 -4
- src/f5_tts/model/trainer.py +42 -4
- src/f5_tts/train/finetune_cli.py +11 -0
- src/f5_tts/train/finetune_gradio.py +79 -0
- src/f5_tts/train/train.py +1 -0
 
    	
app.py CHANGED

@@ -37,7 +37,7 @@ from f5_tts.infer.utils_infer import (
     save_spectrogram,
 )
 
-vocos = load_vocoder()
+vocoder = load_vocoder()
 
 
 # load models
@@ -94,6 +94,7 @@ def infer(
         ref_text,
         gen_text,
         ema_model,
+        vocoder,
         cross_fade_duration=cross_fade_duration,
         speed=speed,
         show_info=show_info,
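The app.py change is the consumer side of the refactor: the Gradio app now loads Vocos once at startup and passes it explicitly into infer_process(), instead of relying on a vocos global created when utils_infer was imported. A minimal sketch of the resulting call pattern (synthesize is a hypothetical stand-in for app.py's infer, and the argument values are illustrative):

from f5_tts.infer.utils_infer import load_vocoder, infer_process

vocoder = load_vocoder()  # loaded once at app startup, reused for every request


def synthesize(ref_audio, ref_text, gen_text, ema_model):
    # the vocoder is threaded through explicitly rather than read from a global
    audio, sample_rate, spectrogram = infer_process(
        ref_audio, ref_text, gen_text, ema_model, vocoder,
        cross_fade_duration=0.15, speed=1.0,
    )
    return sample_rate, audio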
    	
src/f5_tts/api.py CHANGED

@@ -47,7 +47,7 @@ class F5TTS:
         self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
 
     def load_vocoder_model(self, local_path):
-        self.vocos = load_vocoder(local_path is not None, local_path, self.device)
+        self.vocoder = load_vocoder(local_path is not None, local_path, self.device)
 
     def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
         if model_type == "F5-TTS":
@@ -102,6 +102,7 @@ class F5TTS:
             ref_text,
             gen_text,
             self.ema_model,
+            self.vocoder,
             show_info=show_info,
             progress=progress,
             target_rms=target_rms,
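In the API class the same pattern shows up as state: load_vocoder_model() stores the vocoder on self.vocoder and infer() forwards it on every call. A rough usage sketch; the keyword names below are assumptions inferred from api.py of this era, so check your checkout's signature:

from f5_tts.api import F5TTS

tts = F5TTS()  # the constructor calls load_vocoder_model(), setting self.vocoder

wav, sr, spect = tts.infer(
    ref_file="ref.wav",  # assumption: illustrative paths and texts
    ref_text="transcript of the reference audio",
    gen_text="Text to synthesize.",
)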
    	
src/f5_tts/infer/infer_cli.py CHANGED

@@ -113,7 +113,7 @@ wave_path = Path(output_dir) / "infer_cli_out.wav"
 # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
 vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
 
-vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
+vocoder = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
 
 
 # load models
@@ -175,7 +175,9 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence, speed
             ref_audio = voices[voice]["ref_audio"]
             ref_text = voices[voice]["ref_text"]
             print(f"Voice: {voice}")
-            audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj, speed=speed)
+            audio, final_sample_rate, spectragram = infer_process(
+                ref_audio, ref_text, gen_text, model_obj, vocoder, speed=speed
+            )
             generated_audio_segments.append(audio)
 
     if generated_audio_segments:
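infer_cli.py builds the vocoder itself, honoring the pre-existing --load_vocoder_from_local flag. The two paths through load_vocoder, using the keyword names from the diff:

from f5_tts.infer.utils_infer import load_vocoder

# default: fetch charactr/vocos-mel-24khz from the Hugging Face Hub
vocoder = load_vocoder()

# with --load_vocoder_from_local: read the checkpoint from disk instead
vocoder = load_vocoder(is_local=True, local_path="../checkpoints/charactr/vocos-mel-24khz")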
    	
src/f5_tts/infer/utils_infer.py CHANGED

@@ -29,9 +29,6 @@ _ref_audio_cache = {}
 
 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
-vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-
-
 # -----------------------------------------
 
 target_sample_rate = 24000
@@ -263,6 +260,7 @@ def infer_process(
     ref_text,
     gen_text,
     model_obj,
+    vocoder,
     show_info=print,
     progress=tqdm,
     target_rms=target_rms,
@@ -287,6 +285,7 @@ def infer_process(
         ref_text,
         gen_text_batches,
         model_obj,
+        vocoder,
         progress=progress,
         target_rms=target_rms,
         cross_fade_duration=cross_fade_duration,
@@ -307,6 +306,7 @@ def infer_batch_process(
     ref_text,
     gen_text_batches,
     model_obj,
+    vocoder,
     progress=tqdm,
     target_rms=0.1,
     cross_fade_duration=0.15,
@@ -362,7 +362,7 @@ def infer_batch_process(
         generated = generated.to(torch.float32)
         generated = generated[:, ref_audio_len:, :]
         generated_mel_spec = generated.permute(0, 2, 1)
-        generated_wave = vocos.decode(generated_mel_spec.cpu())
+        generated_wave = vocoder.decode(generated_mel_spec.cpu())
         if rms < target_rms:
             generated_wave = generated_wave * rms / target_rms
 
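utils_infer.py is where the global actually disappears: vocoder becomes a required argument of infer_process() and infer_batch_process(), and only the final decode step touches it. A self-contained sketch of that decode step with a dummy mel (F5-TTS mels have 100 bins at 24 kHz; the frame count here is arbitrary):

import torch
from vocos import Vocos

vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")

# infer_batch_process ends by decoding a (batch, n_mels, frames) mel to audio
generated_mel_spec = torch.randn(1, 100, 256)  # dummy input for illustration
generated_wave = vocoder.decode(generated_mel_spec.cpu())  # (1, samples) at 24 kHz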
    	
src/f5_tts/model/trainer.py CHANGED

@@ -6,6 +6,7 @@ from tqdm import tqdm
 import wandb
 
 import torch
+import torchaudio
 from torch.optim import AdamW
 from torch.utils.data import DataLoader, Dataset, SequentialSampler
 from torch.optim.lr_scheduler import LinearLR, SequentialLR
@@ -39,9 +40,11 @@ class Trainer:
         max_grad_norm=1.0,
         noise_scheduler: str | None = None,
         duration_predictor: torch.nn.Module | None = None,
+        logger: str | None = "wandb",  # "wandb" | "tensorboard" | None
         wandb_project="test_e2-tts",
         wandb_run_name="test_run",
         wandb_resume_id: str = None,
+        log_samples: bool = False,
         last_per_steps=None,
         accelerate_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
@@ -49,21 +52,25 @@ class Trainer:
     ):
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
 
-        logger = "wandb" if wandb.api.api_key else None
+        if logger == "wandb" and not wandb.api.api_key:
+            logger = None
         print(f"Using logger: {logger}")
+        self.log_samples = log_samples
 
         self.accelerator = Accelerator(
-            log_with=logger,
+            log_with=logger if logger == "wandb" else None,
             kwargs_handlers=[ddp_kwargs],
             gradient_accumulation_steps=grad_accumulation_steps,
             **accelerate_kwargs,
         )
 
-        if logger == "wandb":
+        self.logger = logger
+        if self.logger == "wandb":
             if exists(wandb_resume_id):
                 init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
             else:
                 init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
+
             self.accelerator.init_trackers(
                 project_name=wandb_project,
                 init_kwargs=init_kwargs,
@@ -81,11 +88,15 @@ class Trainer:
                 },
             )
 
+        elif self.logger == "tensorboard":
+            from torch.utils.tensorboard import SummaryWriter
+
+            self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")
+
         self.model = model
 
         if self.is_main:
             self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
-
             self.ema_model.to(self.accelerator.device)
 
         self.epochs = epochs
@@ -176,6 +187,14 @@ class Trainer:
         return step
 
     def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
+        if self.log_samples:
+            from f5_tts.infer.utils_infer import load_vocoder, nfe_step, cfg_strength, sway_sampling_coef
+
+            vocoder = load_vocoder()
+            target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.mel_stft.sample_rate
+            log_samples_path = f"{self.checkpoint_path}/samples"
+            os.makedirs(log_samples_path, exist_ok=True)
+
         if exists(resumable_with_seed):
             generator = torch.Generator()
             generator.manual_seed(resumable_with_seed)
@@ -286,12 +305,31 @@ class Trainer:
 
                 if self.accelerator.is_local_main_process:
                     self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
+                    if self.logger == "tensorboard":
+                        self.writer.add_scalar("loss", loss.item(), global_step)
+                        self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)
 
                 progress_bar.set_postfix(step=str(global_step), loss=loss.item())
 
                 if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
                     self.save_checkpoint(global_step)
 
+                    if self.log_samples and self.accelerator.is_local_main_process:
+                        ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0).cpu()), mel_lengths[0]
+                        torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
+                        with torch.inference_mode():
+                            generated, _ = self.accelerator.unwrap_model(self.model).sample(
+                                cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
+                                text=[text_inputs[0] + [" "] + text_inputs[0]],
+                                duration=ref_audio_len * 2,
+                                steps=nfe_step,
+                                cfg_strength=cfg_strength,
+                                sway_sampling_coef=sway_sampling_coef,
+                            )
+                        generated = generated.to(torch.float32)
+                        gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).cpu())
+                        torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
+
                 if global_step % self.last_per_steps == 0:
                     self.save_checkpoint(global_step, last=True)
 
    	
src/f5_tts/train/finetune_cli.py CHANGED

@@ -56,6 +56,14 @@ def parse_args():
         help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
     )
 
+    parser.add_argument(
+        "--log_samples",
+        type=bool,
+        default=False,
+        help="Log inferenced samples per ckpt save steps",
+    )
+    parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
+
     return parser.parse_args()
 
 
@@ -64,6 +72,7 @@ def parse_args():
 
 def main():
     args = parse_args()
+
     checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
 
     # Model parameters based on experiment name
@@ -132,9 +141,11 @@ def main():
         max_samples=args.max_samples,
         grad_accumulation_steps=args.grad_accumulation_steps,
         max_grad_norm=args.max_grad_norm,
+        logger=args.logger,
         wandb_project=args.dataset_name,
         wandb_run_name=args.exp_name,
         wandb_resume_id=wandb_resume_id,
+        log_samples=args.log_samples,
         last_per_steps=args.last_per_steps,
     )
 
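Combined on the command line, the new flags look like this (dataset and experiment names are placeholders; note that argparse's type=bool treats any non-empty string as truthy, so pass --log_samples True only when you actually want sample logging):

python src/f5_tts/train/finetune_cli.py \
    --exp_name F5TTS_Base \
    --dataset_name my_dataset \
    --logger tensorboard \
    --log_samples True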
    	
        src/f5_tts/train/finetune_gradio.py
    CHANGED
    
    | 
         @@ -69,6 +69,7 @@ def save_settings( 
     | 
|
| 69 | 
         
             
                tokenizer_type,
         
     | 
| 70 | 
         
             
                tokenizer_file,
         
     | 
| 71 | 
         
             
                mixed_precision,
         
     | 
| 
         | 
|
| 72 | 
         
             
            ):
         
     | 
| 73 | 
         
             
                path_project = os.path.join(path_project_ckpts, project_name)
         
     | 
| 74 | 
         
             
                os.makedirs(path_project, exist_ok=True)
         
     | 
| 
         @@ -91,6 +92,7 @@ def save_settings( 
     | 
|
| 91 | 
         
             
                    "tokenizer_type": tokenizer_type,
         
     | 
| 92 | 
         
             
                    "tokenizer_file": tokenizer_file,
         
     | 
| 93 | 
         
             
                    "mixed_precision": mixed_precision,
         
     | 
| 
         | 
|
| 94 | 
         
             
                }
         
     | 
| 95 | 
         
             
                with open(file_setting, "w") as f:
         
     | 
| 96 | 
         
             
                    json.dump(settings, f, indent=4)
         
     | 
| 
         @@ -121,6 +123,7 @@ def load_settings(project_name): 
     | 
|
| 121 | 
         
             
                        "tokenizer_type": "pinyin",
         
     | 
| 122 | 
         
             
                        "tokenizer_file": "",
         
     | 
| 123 | 
         
             
                        "mixed_precision": "none",
         
     | 
| 
         | 
|
| 124 | 
         
             
                    }
         
     | 
| 125 | 
         
             
                    return (
         
     | 
| 126 | 
         
             
                        settings["exp_name"],
         
     | 
| 
         @@ -139,6 +142,7 @@ def load_settings(project_name): 
     | 
|
| 139 | 
         
             
                        settings["tokenizer_type"],
         
     | 
| 140 | 
         
             
                        settings["tokenizer_file"],
         
     | 
| 141 | 
         
             
                        settings["mixed_precision"],
         
     | 
| 
         | 
|
| 142 | 
         
             
                    )
         
     | 
| 143 | 
         | 
| 144 | 
         
             
                with open(file_setting, "r") as f:
         
     | 
| 
         @@ -160,6 +164,7 @@ def load_settings(project_name): 
     | 
|
| 160 | 
         
             
                    settings["tokenizer_type"],
         
     | 
| 161 | 
         
             
                    settings["tokenizer_file"],
         
     | 
| 162 | 
         
             
                    settings["mixed_precision"],
         
     | 
| 
         | 
|
| 163 | 
         
             
                )
         
     | 
| 164 | 
         | 
| 165 | 
         | 
| 
         @@ -374,6 +379,7 @@ def start_training( 
     | 
|
| 374 | 
         
             
                tokenizer_file="",
         
     | 
| 375 | 
         
             
                mixed_precision="fp16",
         
     | 
| 376 | 
         
             
                stream=False,
         
     | 
| 
         | 
|
| 377 | 
         
             
            ):
         
     | 
| 378 | 
         
             
                global training_process, tts_api, stop_signal
         
     | 
| 379 | 
         | 
| 
         @@ -447,6 +453,8 @@ def start_training( 
     | 
|
| 447 | 
         | 
| 448 | 
         
             
                cmd += f" --tokenizer {tokenizer_type} "
         
     | 
| 449 | 
         | 
| 
         | 
|
| 
         | 
|
| 450 | 
         
             
                print(cmd)
         
     | 
| 451 | 
         | 
| 452 | 
         
             
                save_settings(
         
     | 
| 
         @@ -467,6 +475,7 @@ def start_training( 
     | 
|
| 467 | 
         
             
                    tokenizer_type,
         
     | 
| 468 | 
         
             
                    tokenizer_file,
         
     | 
| 469 | 
         
             
                    mixed_precision,
         
     | 
| 
         | 
|
| 470 | 
         
             
                )
         
     | 
| 471 | 
         | 
| 472 | 
         
             
                try:
         
     | 
| 
         @@ -1223,6 +1232,27 @@ def get_checkpoints_project(project_name, is_gradio=True): 
     | 
|
| 1223 | 
         
             
                return files_checkpoints, selelect_checkpoint
         
     | 
| 1224 | 
         | 
| 1225 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 1226 | 
         
             
            def get_gpu_stats():
         
     | 
| 1227 | 
         
             
                gpu_stats = ""
         
     | 
| 1228 | 
         | 
| 
         @@ -1290,6 +1320,17 @@ def get_combined_stats(): 
     | 
|
| 1290 | 
         
             
                return combined_stats
         
     | 
| 1291 | 
         | 
| 1292 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 1293 | 
         
             
            with gr.Blocks() as app:
         
     | 
| 1294 | 
         
             
                gr.Markdown(
         
     | 
| 1295 | 
         
             
                    """
         
     | 
| 
         @@ -1470,6 +1511,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle 
     | 
|
| 1470 | 
         | 
| 1471 | 
         
             
                        with gr.Row():
         
     | 
| 1472 | 
         
             
                            mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none")
         
     | 
| 
         | 
|
| 1473 | 
         
             
                            start_button = gr.Button("Start Training")
         
     | 
| 1474 | 
         
             
                            stop_button = gr.Button("Stop Training", interactive=False)
         
     | 
| 1475 | 
         | 
| 
         @@ -1491,6 +1533,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle 
     | 
|
| 1491 | 
         
             
                                tokenizer_typev,
         
     | 
| 1492 | 
         
             
                                tokenizer_filev,
         
     | 
| 1493 | 
         
             
                                mixed_precisionv,
         
     | 
| 
         | 
|
| 1494 | 
         
             
                            ) = load_settings(projects_selelect)
         
     | 
| 1495 | 
         
             
                            exp_name.value = exp_namev
         
     | 
| 1496 | 
         
             
                            learning_rate.value = learning_ratev
         
     | 
| 
         @@ -1508,9 +1551,43 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle 
     | 
|
| 1508 | 
         
             
                            tokenizer_type.value = tokenizer_typev
         
     | 
| 1509 | 
         
             
                            tokenizer_file.value = tokenizer_filev
         
     | 
| 1510 | 
         
             
                            mixed_precision.value = mixed_precisionv
         
     | 
| 
         | 
|
| 1511 | 
         | 
| 1512 | 
         
             
                        ch_stream = gr.Checkbox(label="stream output experiment.", value=True)
         
     | 
| 1513 | 
         
             
                        txt_info_train = gr.Text(label="info", value="")
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 1514 | 
         
             
                        start_button.click(
         
     | 
| 1515 | 
         
             
                            fn=start_training,
         
     | 
| 1516 | 
         
             
                            inputs=[
         
     | 
| 
         @@ -1532,6 +1609,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle 
     | 
|
| 1532 | 
         
             
                                tokenizer_file,
         
     | 
| 1533 | 
         
             
                                mixed_precision,
         
     | 
| 1534 | 
         
             
                                ch_stream,
         
     | 
| 
         | 
|
| 1535 | 
         
             
                            ],
         
     | 
| 1536 | 
         
             
                            outputs=[txt_info_train, start_button, stop_button],
         
     | 
| 1537 | 
         
             
                        )
         
     | 
| 
         @@ -1583,6 +1661,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle 
     | 
|
| 1583 | 
         
             
                                tokenizer_type,
         
     | 
| 1584 | 
         
             
                                tokenizer_file,
         
     | 
| 1585 | 
         
             
                                mixed_precision,
         
     | 
| 
         | 
|
| 1586 | 
         
             
                            ]
         
     | 
| 1587 | 
         | 
| 1588 | 
         
             
                            return output_components
         
     | 
| 
         | 
|
| 69 | 
         
             
                tokenizer_type,
         
     | 
| 70 | 
         
             
                tokenizer_file,
         
     | 
| 71 | 
         
             
                mixed_precision,
         
     | 
| 72 | 
         
            +
                logger,
         
     | 
| 73 | 
         
             
            ):
         
     | 
| 74 | 
         
             
                path_project = os.path.join(path_project_ckpts, project_name)
         
     | 
| 75 | 
         
             
                os.makedirs(path_project, exist_ok=True)
         
     | 
| 
         | 
|
| 92 | 
         
             
                    "tokenizer_type": tokenizer_type,
         
     | 
| 93 | 
         
             
                    "tokenizer_file": tokenizer_file,
         
     | 
| 94 | 
         
             
                    "mixed_precision": mixed_precision,
         
     | 
| 95 | 
         
            +
                    "logger": logger,
         
     | 
| 96 | 
         
             
                }
         
     | 
| 97 | 
         
             
                with open(file_setting, "w") as f:
         
     | 
| 98 | 
         
             
                    json.dump(settings, f, indent=4)
         
     | 
| 
         | 
|
| 123 | 
         
             
                        "tokenizer_type": "pinyin",
         
     | 
| 124 | 
         
             
                        "tokenizer_file": "",
         
     | 
| 125 | 
         
             
                        "mixed_precision": "none",
         
     | 
| 126 | 
         
            +
                        "logger": "wandb",
         
     | 
| 127 | 
         
             
                    }
         
     | 
| 128 | 
         
             
                    return (
         
     | 
| 129 | 
         
             
                        settings["exp_name"],
         
     | 
| 
         | 
|
| 142 | 
         
             
                        settings["tokenizer_type"],
         
     | 
| 143 | 
         
             
                        settings["tokenizer_file"],
         
     | 
| 144 | 
         
             
                        settings["mixed_precision"],
         
     | 
| 145 | 
         
            +
                        settings["logger"],
         
     | 
| 146 | 
         
             
                    )
         
     | 
| 147 | 
         | 
| 148 | 
         
             
                with open(file_setting, "r") as f:
         
     | 
| 
         | 
|
| 164 | 
         
             
                    settings["tokenizer_type"],
         
     | 
| 165 | 
         
             
                    settings["tokenizer_file"],
         
     | 
| 166 | 
         
             
                    settings["mixed_precision"],
         
     | 
| 167 | 
         
            +
                    settings["logger"],
         
     | 
| 168 | 
         
             
                )
         
     | 
| 169 | 
         | 
| 170 | 
         | 
| 
         | 
|
| 379 | 
         
             
                tokenizer_file="",
         
     | 
| 380 | 
         
             
                mixed_precision="fp16",
         
     | 
| 381 | 
         
             
                stream=False,
         
     | 
| 382 | 
         
            +
                logger="wandb",
         
     | 
| 383 | 
         
             
            ):
         
     | 
| 384 | 
         
             
                global training_process, tts_api, stop_signal
         
     | 
| 385 | 
         | 
| 
         | 
|
| 453 | 
         | 
| 454 | 
         
             
                cmd += f" --tokenizer {tokenizer_type} "
         
     | 
| 455 | 
         | 
| 456 | 
         
            +
                cmd += f" --log_samples True --logger {logger} "
         
     | 
| 457 | 
         
            +
             
     | 
| 458 | 
         
             
                print(cmd)
         
     | 
| 459 | 
         | 
| 460 | 
         
             
                save_settings(
         
     | 
| 
         | 
|
| 475 | 
         
             
                    tokenizer_type,
         
     | 
| 476 | 
         
             
                    tokenizer_file,
         
     | 
| 477 | 
         
             
                    mixed_precision,
         
     | 
| 478 | 
         
            +
                    logger,
         
     | 
| 479 | 
         
             
                )
         
     | 
| 480 | 
         | 
| 481 | 
         
             
                try:
         
     | 
| 
         | 
|
| 1232 | 
         
             
                return files_checkpoints, selelect_checkpoint
         
     | 
| 1233 | 
         | 
| 1234 | 
         | 
| 1235 | 
         
            +
            def get_audio_project(project_name, is_gradio=True):
         
     | 
| 1236 | 
         
            +
                if project_name is None:
         
     | 
| 1237 | 
         
            +
                    return [], ""
         
     | 
| 1238 | 
         
            +
                project_name = project_name.replace("_pinyin", "").replace("_char", "")
         
     | 
| 1239 | 
         
            +
             
     | 
| 1240 | 
         
            +
                if os.path.isdir(path_project_ckpts):
         
     | 
| 1241 | 
         
            +
                    files_audios = glob(os.path.join(path_project_ckpts, project_name, "samples", "*.wav"))
         
     | 
| 1242 | 
         
            +
                    files_audios = sorted(files_audios, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]))
         
     | 
| 1243 | 
         
            +
             
     | 
| 1244 | 
         
            +
                    files_audios = [item.replace("_gen.wav", "") for item in files_audios if item.endswith("_gen.wav")]
         
     | 
| 1245 | 
         
            +
                else:
         
     | 
| 1246 | 
         
            +
                    files_audios = []
         
     | 
| 1247 | 
         
            +
             
     | 
| 1248 | 
         
            +
                selelect_checkpoint = None if not files_audios else files_audios[0]
         
     | 
| 1249 | 
         
            +
             
     | 
| 1250 | 
         
            +
                if is_gradio:
         
     | 
| 1251 | 
         
            +
                    return gr.update(choices=files_audios, value=selelect_checkpoint)
         
     | 
| 1252 | 
         
            +
             
     | 
| 1253 | 
         
            +
                return files_audios, selelect_checkpoint
         
     | 
| 1254 | 
         
            +
             
     | 
| 1255 | 
         
            +
             
     | 
| 1256 | 
         
             
            def get_gpu_stats():
         
     | 
| 1257 | 
         
             
                gpu_stats = ""
         
     | 
| 1258 | 
         | 
| 
         | 
|
| 1320 | 
         
             
                return combined_stats
         
     | 
| 1321 | 
         | 
| 1322 | 
         | 
| 1323 | 
         
            +
            def get_audio_select(file_sample):
         
     | 
| 1324 | 
         
            +
                select_audio_ref = file_sample
         
     | 
| 1325 | 
         
            +
                select_audio_gen = file_sample
         
     | 
| 1326 | 
         
            +
             
     | 
| 1327 | 
         
            +
                if file_sample is not None:
         
     | 
| 1328 | 
         
            +
                    select_audio_ref += "_ref.wav"
         
     | 
| 1329 | 
         
            +
                    select_audio_gen += "_gen.wav"
         
     | 
| 1330 | 
         
            +
             
     | 
| 1331 | 
         
            +
                return select_audio_ref, select_audio_gen
         
     | 
| 1332 | 
         
            +
             
     | 
| 1333 | 
         
            +
             
     | 
| 1334 | 
         
             
            with gr.Blocks() as app:
         
     | 
| 1335 | 
         
             
                gr.Markdown(
         
     | 
| 1336 | 
         
             
                    """
         
     | 
| 
         | 
|
| 1511 | 
         | 
| 1512 | 
         
             
                        with gr.Row():
         
     | 
| 1513 | 
         
             
                            mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none")
         
     | 
| 1514 | 
         
            +
                            cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
         
     | 
| 1515 | 
         
             
                            start_button = gr.Button("Start Training")
         
     | 
| 1516 | 
         
             
                            stop_button = gr.Button("Stop Training", interactive=False)
         
     | 
| 1517 | 
         | 
| 
         | 
|
| 1533 |                   tokenizer_typev,
| 1534 |                   tokenizer_filev,
| 1535 |                   mixed_precisionv,
| 1536 | +                 cd_loggerv,
| 1537 |               ) = load_settings(projects_selelect)
| 1538 |               exp_name.value = exp_namev
| 1539 |               learning_rate.value = learning_ratev
|  ...  |   (lines 1540-1550 collapsed)
| 1551 |               tokenizer_type.value = tokenizer_typev
| 1552 |               tokenizer_file.value = tokenizer_filev
| 1553 |               mixed_precision.value = mixed_precisionv
| 1554 | +             cd_logger.value = cd_loggerv
| 1555 |
| 1556 |           ch_stream = gr.Checkbox(label="stream output experiment.", value=True)
| 1557 |           txt_info_train = gr.Text(label="info", value="")
| 1558 | +
| 1559 | +         list_audios, select_audio = get_audio_project(projects_selelect, False)
| 1560 | +
| 1561 | +         select_audio_ref = select_audio
| 1562 | +         select_audio_gen = select_audio
| 1563 | +
| 1564 | +         if select_audio is not None:
| 1565 | +             select_audio_ref += "_ref.wav"
| 1566 | +             select_audio_gen += "_gen.wav"
| 1567 | +
| 1568 | +         with gr.Row():
| 1569 | +             ch_list_audio = gr.Dropdown(
| 1570 | +                 choices=list_audios,
| 1571 | +                 value=select_audio,
| 1572 | +                 label="audios",
| 1573 | +                 allow_custom_value=True,
| 1574 | +                 scale=6,
| 1575 | +                 interactive=True,
| 1576 | +             )
| 1577 | +             bt_stream_audio = gr.Button("refresh", scale=1)
| 1578 | +             bt_stream_audio.click(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
| 1579 | +             cm_project.change(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
| 1580 | +
| 1581 | +         with gr.Row():
| 1582 | +             audio_ref_stream = gr.Audio(label="original", type="filepath", value=select_audio_ref)
| 1583 | +             audio_gen_stream = gr.Audio(label="generate", type="filepath", value=select_audio_gen)
| 1584 | +
| 1585 | +         ch_list_audio.change(
| 1586 | +             fn=get_audio_select,
| 1587 | +             inputs=[ch_list_audio],
| 1588 | +             outputs=[audio_ref_stream, audio_gen_stream],
| 1589 | +         )
| 1590 | +
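Taken together, the block above wires a dropdown of sample base names to two audio players through the "_ref.wav" / "_gen.wav" suffix convention. Here is a stripped-down, runnable sketch of the same wiring; the dropdown choices are hypothetical placeholders, not files the repo ships:

import gradio as gr


def get_audio_select(file_sample):
    # Map one dropdown base name to the pair of files it stands for.
    select_audio_ref = file_sample
    select_audio_gen = file_sample
    if file_sample is not None:
        select_audio_ref += "_ref.wav"
        select_audio_gen += "_gen.wav"
    return select_audio_ref, select_audio_gen


with gr.Blocks() as demo:
    ch_list_audio = gr.Dropdown(choices=["200", "400"], label="audios")
    audio_ref_stream = gr.Audio(label="original", type="filepath")
    audio_gen_stream = gr.Audio(label="generate", type="filepath")
    # Selecting an entry swaps both players to the matching file pair.
    ch_list_audio.change(
        fn=get_audio_select,
        inputs=[ch_list_audio],
        outputs=[audio_ref_stream, audio_gen_stream],
    )

if __name__ == "__main__":
    demo.launch()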
| 1591 |           start_button.click(
| 1592 |               fn=start_training,
| 1593 |               inputs=[
|  ...  |   (lines 1594-1608 collapsed)
| 1609 |                   tokenizer_file,
| 1610 |                   mixed_precision,
| 1611 |                   ch_stream,
| 1612 | +                 cd_logger,
| 1613 |               ],
| 1614 |               outputs=[txt_info_train, start_button, stop_button],
| 1615 |           )
|  ...  |   (lines 1616-1660 collapsed)
| 1661 |                   tokenizer_type,
| 1662 |                   tokenizer_file,
| 1663 |                   mixed_precision,
| 1664 | +                 cd_logger,
| 1665 |               ]
| 1666 |
| 1667 |               return output_components
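The new cd_logger radio threads a "wandb"-or-"tensorboard" string through load_settings, start_training, and the saved project settings. How that string is consumed downstream is not shown in this diff; the sketch below is one plausible shape for such a switch, illustrative only and not the repo's Trainer code (the project name and run directory are made up):

def make_logger(kind: str):
    # Return a callable that logs a scalar loss at a given step.
    if kind == "wandb":
        import wandb  # assumes wandb is installed and configured

        wandb.init(project="f5_tts_finetune")
        return lambda step, loss: wandb.log({"loss": loss}, step=step)
    if kind == "tensorboard":
        from torch.utils.tensorboard import SummaryWriter

        writer = SummaryWriter("runs/f5_tts_finetune")
        return lambda step, loss: writer.add_scalar("loss", loss, step)
    raise ValueError(f"unknown logger: {kind}")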
src/f5_tts/train/train.py
CHANGED

@@ -83,6 +83,7 @@ def main():
| 83 |           wandb_run_name=exp_name,
| 84 |           wandb_resume_id=wandb_resume_id,
| 85 |           last_per_steps=last_per_steps,
| 86 | +         log_samples=True,
| 87 |       )
| 88 |
| 89 |       train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
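Finally, train.py now passes log_samples=True into the Trainer, which is what makes per-checkpoint audio pairs appear for the Gradio preview above to pick up. A minimal sketch of the file-naming side of that contract, assuming samples arrive as already-encoded WAV bytes (the actual Trainer signature and output paths may differ):

import os


def save_sample_pair(sample_dir: str, step: int, ref_wav: bytes, gen_wav: bytes) -> None:
    # Write "<step>_ref.wav" and "<step>_gen.wav", the naming convention that
    # get_audio_project() and get_audio_select() rely on.
    os.makedirs(sample_dir, exist_ok=True)
    for suffix, payload in (("ref", ref_wav), ("gen", gen_wav)):
        with open(os.path.join(sample_dir, f"{step}_{suffix}.wav"), "wb") as f:
            f.write(payload)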