Update BigVGAN vocoder and F5-TTS BigVGAN checkpoint, trained on Emilia ZH & EN, 1.25M updates
- .gitmodules +3 -0
- README.md +12 -1
- src/f5_tts/api.py +9 -14
- src/f5_tts/eval/eval_infer_batch.py +36 -30
- src/f5_tts/eval/utils_eval.py +11 -3
- src/f5_tts/infer/infer_cli.py +18 -11
- src/f5_tts/infer/speech_edit.py +28 -27
- src/f5_tts/infer/utils_infer.py +53 -29
- src/f5_tts/model/cfm.py +9 -13
- src/f5_tts/model/dataset.py +23 -5
- src/f5_tts/model/modules.py +142 -30
- src/f5_tts/model/trainer.py +12 -12
- src/f5_tts/train/train.py +11 -5
- src/third_party/BigVGAN +1 -0
    	
.gitmodules ADDED

@@ -0,0 +1,3 @@
+[submodule "src/third_party/BigVGAN"]
+	path = src/third_party/BigVGAN
+	url = https://github.com/NVIDIA/BigVGAN.git
    	
README.md CHANGED

@@ -46,7 +46,18 @@ cd F5-TTS
 pip install -e .
 ```
 
-### 3. Docker usage
+### 3. Init submodule (optional, if you want to change the vocoder from vocos to bigvgan)
+
+```bash
+git submodule update --init --recursive
+```
+After that, you need to change `src/third_party/BigVGAN/bigvgan.py` by adding the following code at the beginning of the file:
+```python
+import sys
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+```
+
+### 4. Docker usage
 ```bash
 # Build from Dockerfile
 docker build -t f5tts:v1 .
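Note that the snippet added to the README calls `os` but only imports `sys`. Assuming `bigvgan.py` does not already import `os` at that point (an assumption on our part, not stated in the diff), the full prelude would look like this sketch:

```python
# Sketch of the complete prelude for src/third_party/BigVGAN/bigvgan.py
# (assumes `os` is not already imported earlier in that file).
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
```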
    	
src/f5_tts/api.py CHANGED

@@ -1,24 +1,18 @@
 import random
 import sys
-import tqdm
 from importlib.resources import files
 
 import soundfile as sf
 import torch
+import tqdm
 from cached_path import cached_path
 
+from f5_tts.infer.utils_infer import (hop_length, infer_process, load_model,
+                                      load_vocoder, preprocess_ref_audio_text,
+                                      remove_silence_for_generated_wav,
+                                      save_spectrogram, target_sample_rate)
 from f5_tts.model import DiT, UNetT
 from f5_tts.model.utils import seed_everything
-from f5_tts.infer.utils_infer import (
-    load_vocoder,
-    load_model,
-    infer_process,
-    remove_silence_for_generated_wav,
-    save_spectrogram,
-    preprocess_ref_audio_text,
-    target_sample_rate,
-    hop_length,
-)
 
 
 class F5TTS:
@@ -29,6 +23,7 @@ class F5TTS:
         vocab_file="",
         ode_method="euler",
         use_ema=True,
+        vocoder_name="vocos",
         local_path=None,
         device=None,
     ):
@@ -44,11 +39,11 @@ class F5TTS:
         )
 
         # Load models
-        self.load_vocoder_model(local_path)
+        self.load_vocoder_model(vocoder_name, local_path)
         self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
 
-    def load_vocoder_model(self, local_path):
-        self.vocoder = load_vocoder(local_path is not None, local_path, self.device)
+    def load_vocoder_model(self, vocoder_name, local_path):
+        self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
 
     def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
         if model_type == "F5-TTS":
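For illustration, a hypothetical call through the Python API with the new `vocoder_name` argument might look like the sketch below; the checkpoint path and the `infer()` keyword names are assumptions for illustration, not taken from this diff:

```python
# Hypothetical usage sketch of the new vocoder_name argument in f5_tts.api.F5TTS.
# The ckpt_file path and the infer() keywords are assumed, not part of this diff.
from f5_tts.api import F5TTS

tts = F5TTS(
    ckpt_file="ckpts/F5TTS_Base_bigvgan/model_1250000.pt",  # assumed local checkpoint path
    vocoder_name="bigvgan",  # new argument; "vocos" remains the default
)
wav, sr, spect = tts.infer(
    ref_file="ref.wav",
    ref_text="reference transcript",
    gen_text="text to synthesize",
)
```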
    	
src/f5_tts/eval/eval_infer_batch.py CHANGED

@@ -1,26 +1,23 @@
-import sys
 import os
+import sys
 
 sys.path.append(os.getcwd())
 
-import time
-from tqdm import tqdm
 import argparse
+import time
 from importlib.resources import files
 
 import torch
 import torchaudio
 from accelerate import Accelerator
-from vocos import Vocos
+from tqdm import tqdm
 
-from f5_tts.model import CFM, DiT, UNetT
+from f5_tts.eval.utils_eval import (get_inference_prompt,
+                                    get_librispeech_test_clean_metainfo,
+                                    get_seedtts_testset_metainfo)
+from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
+from f5_tts.model import CFM, DiT, UNetT
 from f5_tts.model.utils import get_tokenizer
-from f5_tts.infer.utils_infer import load_checkpoint
-from f5_tts.eval.utils_eval import (
-    get_seedtts_testset_metainfo,
-    get_librispeech_test_clean_metainfo,
-    get_inference_prompt,
-)
 
 accelerator = Accelerator()
 device = f"cuda:{accelerator.process_index}"
@@ -31,8 +28,12 @@ device = f"cuda:{accelerator.process_index}"
 target_sample_rate = 24000
 n_mel_channels = 100
 hop_length = 256
+win_length = 1024
+n_fft = 1024
+extract_backend = "bigvgan"  # 'vocos' or 'bigvgan'
 target_rms = 0.1
 
+
 tokenizer = "pinyin"
 rel_path = str(files("f5_tts").joinpath("../../"))
 
@@ -123,14 +124,11 @@ def main():
 
     # Vocoder model
     local = False
-    if local:
-        vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
-        vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
-        state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
-        vocos.load_state_dict(state_dict)
-        vocos.eval()
-    else:
-        vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+    if extract_backend == "vocos":
+        vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
+    elif extract_backend == "bigvgan":
+        vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
+    vocoder = load_vocoder(vocoder_name=extract_backend, is_local=local, local_path=vocoder_local_path)
 
     # Tokenizer
     vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
@@ -139,9 +137,12 @@ def main():
     model = CFM(
         transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
         mel_spec_kwargs=dict(
-            target_sample_rate=target_sample_rate,
-            n_mel_channels=n_mel_channels,
+            n_fft=n_fft,
             hop_length=hop_length,
+            win_length=win_length,
+            n_mel_channels=n_mel_channels,
+            target_sample_rate=target_sample_rate,
+            extract_backend=extract_backend,
         ),
         odeint_kwargs=dict(
             method=ode_method,
@@ -149,7 +150,8 @@ def main():
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
+    dtype = torch.float16 if extract_backend == "vocos" else torch.float32
+    model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
 
     if not os.path.exists(output_dir) and accelerator.is_main_process:
         os.makedirs(output_dir)
@@ -178,14 +180,18 @@ def main():
                     no_ref_audio=no_ref_audio,
                     seed=seed,
                 )
-                # Final result
-                for i, gen in enumerate(generated):
-                    gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
-                    gen_mel_spec = gen.permute(0, 2, 1)
-                    generated_wave = vocos.decode(gen_mel_spec.cpu())
-                    if ref_rms_list[i] < target_rms:
-                        generated_wave = generated_wave * ref_rms_list[i] / target_rms
-                    torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
+                # Final result
+                for i, gen in enumerate(generated):
+                    gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
+                    gen_mel_spec = gen.permute(0, 2, 1)
+                    if extract_backend == "vocos":
+                        generated_wave = vocoder.decode(gen_mel_spec.cpu())
+                    elif extract_backend == "bigvgan":
+                        generated_wave = vocoder(gen_mel_spec)
+
+                    if ref_rms_list[i] < target_rms:
+                        generated_wave = generated_wave * ref_rms_list[i] / target_rms
+                    torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
 
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
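The checkpoint precision now follows the vocoder backend; a minimal sketch of that convention (assuming, as the diff implies, that the Vocos-based checkpoint is stored in fp16 and the BigVGAN-based one in fp32):

```python
# Minimal sketch of the backend-dependent precision choice used in this commit
# (assumption: the Vocos checkpoint is fp16, the BigVGAN checkpoint is fp32).
import torch

extract_backend = "bigvgan"  # or "vocos"
dtype = torch.float16 if extract_backend == "vocos" else torch.float32
# model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
```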
    	
src/f5_tts/eval/utils_eval.py CHANGED

@@ -2,15 +2,15 @@ import math
 import os
 import random
 import string
-from tqdm import tqdm
 
 import torch
 import torch.nn.functional as F
 import torchaudio
+from tqdm import tqdm
 
+from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
 from f5_tts.model.modules import MelSpec
 from f5_tts.model.utils import convert_char_to_pinyin
-from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
 
 
 # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
@@ -74,8 +74,11 @@ def get_inference_prompt(
     tokenizer="pinyin",
     polyphone=True,
     target_sample_rate=24000,
+    n_fft=1024,
+    win_length=1024,
     n_mel_channels=100,
     hop_length=256,
+    extract_backend="bigvgan",
     target_rms=0.1,
     use_truth_duration=False,
     infer_batch_size=1,
@@ -94,7 +97,12 @@ def get_inference_prompt(
     )
 
     mel_spectrogram = MelSpec(
-        target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
+        n_fft=n_fft,
+        hop_length=hop_length,
+        win_length=win_length,
+        n_mel_channels=n_mel_channels,
+        target_sample_rate=target_sample_rate,
+        extract_backend=extract_backend,
     )
 
     for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
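The mel-spectrogram extractor is now configured with an explicit STFT setup plus a backend switch; a minimal construction sketch with the defaults used throughout this commit (parameter names as in the diff; `extract_backend` selects between Vocos- and BigVGAN-style mel extraction):

```python
# Sketch: building the mel extractor with the settings this commit standardizes on.
# Parameter names follow the diff; extract_backend switches the mel convention.
from f5_tts.model.modules import MelSpec

mel_spec = MelSpec(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    extract_backend="bigvgan",  # or "vocos"
)
```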
    	
src/f5_tts/infer/infer_cli.py CHANGED

@@ -2,23 +2,18 @@ import argparse
 import codecs
 import os
 import re
-from pathlib import Path
 from importlib.resources import files
+from pathlib import Path
 
 import numpy as np
 import soundfile as sf
 import tomli
 from cached_path import cached_path
 
+from f5_tts.infer.utils_infer import (infer_process, load_model, load_vocoder,
+                                      preprocess_ref_audio_text,
+                                      remove_silence_for_generated_wav)
 from f5_tts.model import DiT, UNetT
-from f5_tts.infer.utils_infer import (
-    load_vocoder,
-    load_model,
-    preprocess_ref_audio_text,
-    infer_process,
-    remove_silence_for_generated_wav,
-)
-
 
 parser = argparse.ArgumentParser(
     prog="python3 infer-cli.py",
@@ -70,6 +65,7 @@ parser.add_argument(
     "--remove_silence",
     help="Remove silence.",
 )
+parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
 parser.add_argument(
     "--load_vocoder_from_local",
     action="store_true",
@@ -111,9 +107,14 @@ remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
 speed = args.speed
 wave_path = Path(output_dir) / "infer_cli_out.wav"
 # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
-vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
+if args.vocoder_name == "vocos":
+    vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
+elif args.vocoder_name == "bigvgan":
+    vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
 
-vocoder = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
+vocoder = load_vocoder(
+    vocoder_name=args.vocoder_name, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path
+)
 
 
 # load models
@@ -136,6 +137,12 @@ elif model == "E2-TTS":
         ckpt_step = 1200000
         ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
         # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
+    elif args.vocoder_name == "bigvgan":  # TODO: need to test
+        repo_name = "F5-TTS"
+        exp_name = "F5TTS_Base_bigvgan"
+        ckpt_step = 1250000
+        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
+
 
 print(f"Using {model}...")
 ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
    	
src/f5_tts/infer/speech_edit.py CHANGED

@@ -3,17 +3,11 @@ import os
 import torch
 import torch.nn.functional as F
 import torchaudio
-from vocos import Vocos
-
-from f5_tts.model import CFM, DiT, UNetT
-from f5_tts.model.utils import (
-    get_tokenizer,
-    convert_char_to_pinyin,
-)
-from f5_tts.infer.utils_infer import (
-    load_checkpoint,
-    save_spectrogram,
-)
+
+from f5_tts.infer.utils_infer import (load_checkpoint, load_vocoder,
+                                      save_spectrogram)
+from f5_tts.model import CFM, DiT, UNetT
+from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
 
 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
@@ -23,6 +17,9 @@ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 target_sample_rate = 24000
 n_mel_channels = 100
 hop_length = 256
+win_length = 1024
+n_fft = 1024
+extract_backend = "bigvgan"  # 'vocos' or 'bigvgan'
 target_rms = 0.1
 
 tokenizer = "pinyin"
@@ -89,15 +86,11 @@ if not os.path.exists(output_dir):
 
 # Vocoder model
 local = False
-if local:
-    vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
-    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
-    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
-    vocos.load_state_dict(state_dict)
-    vocos.eval()
-else:
-    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+if extract_backend == "vocos":
+    vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
+elif extract_backend == "bigvgan":
+    vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
+vocoder = load_vocoder(vocoder_name=extract_backend, is_local=local, local_path=vocoder_local_path)
 
 # Tokenizer
 vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
@@ -106,9 +99,12 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
 model = CFM(
     transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
     mel_spec_kwargs=dict(
-        target_sample_rate=target_sample_rate,
-        n_mel_channels=n_mel_channels,
+        n_fft=n_fft,
         hop_length=hop_length,
+        win_length=win_length,
+        n_mel_channels=n_mel_channels,
+        target_sample_rate=target_sample_rate,
+        extract_backend=extract_backend,
     ),
     odeint_kwargs=dict(
         method=ode_method,
@@ -116,7 +112,8 @@ model = CFM(
     vocab_char_map=vocab_char_map,
 ).to(device)
 
-model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
+dtype = torch.float16 if extract_backend == "vocos" else torch.float32
+model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
 
 # Audio
 audio, sr = torchaudio.load(audio_to_edit)
@@ -181,11 +178,15 @@ print(f"Generated mel: {generated.shape}")
 # Final result
 generated = generated.to(torch.float32)
 generated = generated[:, ref_audio_len:, :]
-gen_mel_spec = generated.permute(0, 2, 1)
-generated_wave = vocos.decode(gen_mel_spec.cpu())
+gen_mel_spec = generated.permute(0, 2, 1)
+if extract_backend == "vocos":
+    generated_wave = vocoder.decode(gen_mel_spec.cpu())
+elif extract_backend == "bigvgan":
+    generated_wave = vocoder(gen_mel_spec)
+
 if rms < target_rms:
     generated_wave = generated_wave * rms / target_rms
 
-save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
-torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
+save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
+torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
 print(f"Generated wav: {generated_wave.shape}")
    	
src/f5_tts/infer/utils_infer.py CHANGED

@@ -1,6 +1,10 @@
 # A unified script for inference process
 # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
+import os
+import sys
 
+sys.path.append(f"../../{os.path.dirname(os.path.abspath(__file__))}/third_party/BigVGAN/")
+from third_party.BigVGAN import bigvgan
 import hashlib
 import re
 import tempfile
@@ -34,6 +38,9 @@ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 target_sample_rate = 24000
 n_mel_channels = 100
 hop_length = 256
+win_length = 1024
+n_fft = 1024
+extract_backend = "bigvgan"  # 'vocos' or 'bigvgan'
 target_rms = 0.1
 cross_fade_duration = 0.15
 ode_method = "euler"
@@ -80,17 +87,28 @@ def chunk_text(text, max_chars=135):
 
 
 # load vocoder
-def load_vocoder(is_local=False, local_path="", device=device):
-    if is_local:
-        print(f"Load vocos from local path {local_path}")
-        vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
-        state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location=device)
-        vocos.load_state_dict(state_dict)
-        vocos.eval()
-    else:
-        print("Download Vocos from huggingface charactr/vocos-mel-24khz")
-        vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-    return vocos
+def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device):
+    if vocoder_name == "vocos":
+        if is_local:
+            print(f"Load vocos from local path {local_path}")
+            vocoder = Vocos.from_hparams(f"{local_path}/config.yaml")
+            state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location="cpu")
+            vocoder.load_state_dict(state_dict)
+            vocoder.eval()
+            vocoder = vocoder.eval().to(device)
+        else:
+            print("Download Vocos from huggingface charactr/vocos-mel-24khz")
+            vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+    elif vocoder_name == "bigvgan":
+        if is_local:
+            """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
+            vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
+        else:
+            vocoder = bigvgan.BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False)
+
+        vocoder.remove_weight_norm()
+        vocoder = vocoder.eval().to(device)
+    return vocoder
@@ -111,9 +129,8 @@ def initialize_asr_pipeline(device=device):
 # load model checkpoint for inference
 
 
-def load_checkpoint(model, ckpt_path, device, use_ema=True):
-    if device == "cuda":
-        model = model.half()
+def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
+    model = model.to(dtype)
 
     ckpt_type = ckpt_path.split(".")[-1]
     if ckpt_type == "safetensors":
@@ -156,9 +173,12 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method
     model = CFM(
         transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
         mel_spec_kwargs=dict(
-            target_sample_rate=target_sample_rate,
-            n_mel_channels=n_mel_channels,
+            n_fft=n_fft,
             hop_length=hop_length,
+            win_length=win_length,
+            n_mel_channels=n_mel_channels,
+            target_sample_rate=target_sample_rate,
+            extract_backend=extract_backend,
         ),
         odeint_kwargs=dict(
             method=ode_method,
@@ -166,7 +186,8 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
+    dtype = torch.float16 if extract_backend == "vocos" else torch.float32
+    model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
 
     return model
 
@@ -359,18 +380,21 @@ def infer_batch_process(
                 sway_sampling_coef=sway_sampling_coef,
             )
 
-            generated = generated.to(torch.float32)
-            generated = generated[:, ref_audio_len:, :]
-            generated_mel_spec = generated.permute(0, 2, 1)
-            generated_wave = vocoder.decode(generated_mel_spec.cpu())
-            if rms < target_rms:
-                generated_wave = generated_wave * rms / target_rms
-
-            # wav -> numpy
-            generated_wave = generated_wave.squeeze().cpu().numpy()
-
-            generated_waves.append(generated_wave)
-            spectrograms.append(generated_mel_spec[0].cpu().numpy())
+            generated = generated.to(torch.float32)
+            generated = generated[:, ref_audio_len:, :]
+            generated_mel_spec = generated.permute(0, 2, 1)
+            if extract_backend == "vocos":
+                generated_wave = vocoder.decode(generated_mel_spec.cpu())
+            elif extract_backend == "bigvgan":
+                generated_wave = vocoder(generated_mel_spec)
+            if rms < target_rms:
+                generated_wave = generated_wave * rms / target_rms
+
+            # wav -> numpy
+            generated_wave = generated_wave.squeeze().cpu().numpy()
+
+            generated_waves.append(generated_wave)
+            spectrograms.append(generated_mel_spec[0].cpu().numpy())
 
     # Combine all generated waves with cross-fading
     if cross_fade_duration <= 0:
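A short usage sketch of the updated `load_vocoder`: the two backends are driven differently, Vocos through `.decode()` and BigVGAN by calling the module directly on a mel batch of shape [batch, n_mel_channels, frames] (the random mel below is only a placeholder, and running the BigVGAN branch assumes the submodule setup from the README has been done):

```python
# Usage sketch for the updated load_vocoder; the mel tensor is a placeholder.
import torch

from f5_tts.infer.utils_infer import load_vocoder

mel = torch.randn(1, 100, 256)  # [batch, n_mel_channels, frames]

vocos = load_vocoder(vocoder_name="vocos", device="cpu")
wave_vocos = vocos.decode(mel)  # Vocos path

bigvgan_vocoder = load_vocoder(vocoder_name="bigvgan", device="cpu")
with torch.inference_mode():
    wave_bigvgan = bigvgan_vocoder(mel)  # BigVGAN path
```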
    	
        src/f5_tts/model/cfm.py
    CHANGED
    
@@ -8,25 +8,19 @@ d - dimension
 """
 
 from __future__ import annotations
-from typing import Callable
+
 from random import random
+from typing import Callable
 
 import torch
-from torch import nn
 import torch.nn.functional as F
+from torch import nn
 from torch.nn.utils.rnn import pad_sequence
-
 from torchdiffeq import odeint
 
 from f5_tts.model.modules import MelSpec
-from f5_tts.model.utils import (
-    default,
-    exists,
-    list_str_to_idx,
-    list_str_to_tensor,
-    lens_to_mask,
-    mask_from_frac_lengths,
-)
+from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
+                                list_str_to_tensor, mask_from_frac_lengths)
 
 
 class CFM(nn.Module):
@@ -99,8 +93,10 @@ class CFM(nn.Module):
     ):
         self.eval()
 
-        if next(self.parameters()).dtype == torch.float16:
-            cond = cond.half()
+        assert next(self.parameters()).dtype == torch.float32 or next(self.parameters()).dtype == torch.float16, print(
+            "Only support fp16 and fp32 inference currently"
+        )
+        cond = cond.to(next(self.parameters()).dtype)
 
         # raw wave
 
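Editor's note: `CFM.sample` now casts the conditioning audio to whatever dtype the model itself uses (fp32 or fp16) instead of unconditionally halving it. A minimal sketch of fp16 inference relying on this, assuming `model` is an already-constructed `CFM` instance and `ref_audio` a float32 waveform tensor (both hypothetical here; other arguments keep their defaults):

```python
import torch

model = model.half().eval()  # fp32 would work the same way
with torch.inference_mode():
    out, trajectory = model.sample(
        cond=ref_audio,          # cast to the model dtype inside sample()
        text=["some reference text. some generated text."],
        duration=1024,           # target length in mel frames
        steps=32,
        cfg_strength=2.0,
        sway_sampling_coef=-1.0,
    )
```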
    	
        src/f5_tts/model/dataset.py
    CHANGED
    
@@ -1,15 +1,15 @@
 import json
 import random
 from importlib.resources import files
-from tqdm import tqdm
 
 import torch
 import torch.nn.functional as F
 import torchaudio
+from datasets import Dataset as Dataset_
+from datasets import load_from_disk
 from torch import nn
 from torch.utils.data import Dataset, Sampler
-from datasets import load_from_disk
-from datasets import Dataset as Dataset_
+from tqdm import tqdm
 
 from f5_tts.model.modules import MelSpec
 from f5_tts.model.utils import default
@@ -22,12 +22,21 @@ class HFDataset(Dataset):
         target_sample_rate=24_000,
         n_mel_channels=100,
         hop_length=256,
+        n_fft=1024,
+        win_length=1024,
+        extract_backend="vocos",
     ):
         self.data = hf_dataset
         self.target_sample_rate = target_sample_rate
         self.hop_length = hop_length
+
         self.mel_spectrogram = MelSpec(
-            target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels
+            n_fft=n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            n_mel_channels=n_mel_channels,
+            target_sample_rate=target_sample_rate,
+            extract_backend=extract_backend,
         )
 
     def get_frame_len(self, index):
@@ -79,6 +88,9 @@ class CustomDataset(Dataset):
         target_sample_rate=24_000,
         hop_length=256,
         n_mel_channels=100,
+        n_fft=1024,
+        win_length=1024,
+        extract_backend="vocos",
         preprocessed_mel=False,
         mel_spec_module: nn.Module | None = None,
     ):
@@ -86,15 +98,21 @@ class CustomDataset(Dataset):
         self.durations = durations
         self.target_sample_rate = target_sample_rate
         self.hop_length = hop_length
+        self.n_fft = n_fft
+        self.win_length = win_length
+        self.extract_backend = extract_backend
         self.preprocessed_mel = preprocessed_mel
 
         if not preprocessed_mel:
             self.mel_spectrogram = default(
                 mel_spec_module,
                 MelSpec(
-                    target_sample_rate=target_sample_rate,
+                    n_fft=n_fft,
                     hop_length=hop_length,
+                    win_length=win_length,
                     n_mel_channels=n_mel_channels,
+                    target_sample_rate=target_sample_rate,
+                    extract_backend=extract_backend,
                 ),
             )
 
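Editor's note: the STFT settings and the mel backend are now fixed at dataset construction time, so features extracted for training match the vocoder that will be used later. A minimal sketch using the existing `load_dataset` entry point (the dataset name is an assumption; it must already be preprocessed the way the prepare scripts expect):

```python
from f5_tts.model.dataset import load_dataset

mel_spec_kwargs = dict(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    n_mel_channels=100,
    target_sample_rate=24_000,
    extract_backend="bigvgan",  # or "vocos" (the default)
)
train_dataset = load_dataset("Emilia_ZH_EN", "pinyin", mel_spec_kwargs=mel_spec_kwargs)
```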
    	
        src/f5_tts/model/modules.py
    CHANGED
    
@@ -8,61 +8,173 @@ d - dimension
 """
 
 from __future__ import annotations
-from typing import Optional
+
 import math
+from typing import Optional
 
 import torch
-from torch import nn
 import torch.nn.functional as F
 import torchaudio
+from librosa.filters import mel as librosa_mel_fn
+from torch import nn
 from x_transformers.x_transformers import apply_rotary_pos_emb
 
-
 # raw wav to mel spec
 
 
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+    return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+    return dynamic_range_compression_torch(magnitudes)
+
+
+mel_basis_cache = {}
+hann_window_cache = {}
+
+
+# BigVGAN extract mel spectrogram
+def mel_spectrogram(
+    y: torch.Tensor,
+    n_fft: int,
+    num_mels: int,
+    sampling_rate: int,
+    hop_size: int,
+    win_size: int,
+    fmin: int,
+    fmax: int = None,
+    center: bool = False,
+) -> torch.Tensor:
+    """Copy from https://github.com/NVIDIA/BigVGAN/tree/main"""
+    device = y.device
+    key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
+
+    if key not in mel_basis_cache:
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+        mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)  # TODO: why they need .float()?
+        hann_window_cache[key] = torch.hann_window(win_size).to(device)
+
+    mel_basis = mel_basis_cache[key]
+    hann_window = hann_window_cache[key]
+
+    padding = (n_fft - hop_size) // 2
+    y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
+
+    spec = torch.stft(
+        y,
+        n_fft,
+        hop_length=hop_size,
+        win_length=win_size,
+        window=hann_window,
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=True,
+    )
+    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
+
+    mel_spec = torch.matmul(mel_basis, spec)
+    mel_spec = spectral_normalize_torch(mel_spec)
+
+    return mel_spec
+
+
+def get_bigvgan_mel_spectrogram(
+    waveform,
+    n_fft=1024,
+    n_mel_channels=100,
+    target_sample_rate=24000,
+    hop_length=256,
+    win_length=1024,
+):
+    return mel_spectrogram(
+        waveform,
+        n_fft,  # 1024
+        n_mel_channels,  # 100
+        target_sample_rate,  # 24000
+        hop_length,  # 256
+        win_length,  # 1024
+        fmin=0,  # 0
+        fmax=None,  # null
+    )
+
+
+def get_vocos_mel_spectrogram(
+    waveform,
+    n_fft=1024,
+    n_mel_channels=100,
+    target_sample_rate=24000,
+    hop_length=256,
+    win_length=1024,
+):
+    mel_stft = torchaudio.transforms.MelSpectrogram(
+        sample_rate=target_sample_rate,
+        n_fft=n_fft,
+        win_length=win_length,
+        hop_length=hop_length,
+        n_mels=n_mel_channels,
+        power=1,
+        center=True,
+        normalized=False,
+        norm=None,
+    )
+    if len(waveform.shape) == 3:
+        waveform = waveform.squeeze(1)  # 'b 1 nw -> b nw'
+
+    assert len(waveform.shape) == 2
+
+    mel = mel_stft(waveform)
+    mel = mel.clamp(min=1e-5).log()
+    return mel
+
+
 class MelSpec(nn.Module):
     def __init__(
         self,
-        filter_length=1024,
+        n_fft=1024,
         hop_length=256,
         win_length=1024,
         n_mel_channels=100,
         target_sample_rate=24_000,
-        normalize=False,
-        power=1,
-        norm=None,
-        center=True,
+        extract_backend="vocos",
     ):
         super().__init__()
-        self.n_mel_channels = n_mel_channels
-
-        self.mel_stft = torchaudio.transforms.MelSpectrogram(
-            sample_rate=target_sample_rate,
-            n_fft=filter_length,
-            win_length=win_length,
-            hop_length=hop_length,
-            n_mels=n_mel_channels,
-            power=power,
-            center=center,
-            normalized=normalize,
-            norm=norm,
+        assert extract_backend in ["vocos", "bigvgan"], print(
+            "We only support two extract mel backend: vocos or bigvgan"
         )
 
-        self.register_buffer("dummy", torch.tensor(0), persistent=False)
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.n_mel_channels = n_mel_channels
+        self.target_sample_rate = target_sample_rate
 
-    def forward(self, inp):
-        if len(inp.shape) == 3:
-            inp = inp.squeeze(1)  # 'b 1 nw -> b nw'
+        if extract_backend == "vocos":
+            self.extractor = get_vocos_mel_spectrogram
+        elif extract_backend == "bigvgan":
+            self.extractor = get_bigvgan_mel_spectrogram
 
-        assert len(inp.shape) == 2
+        self.register_buffer("dummy", torch.tensor(0), persistent=False)
 
-        if self.dummy.device != inp.device:
-            self.to(inp.device)
+    def forward(self, wav):
+        if self.dummy.device != wav.device:
+            self.to(wav.device)
 
-        mel = self.mel_stft(inp)
-        mel = mel.clamp(min=1e-5).log()
+        mel = self.extractor(
+            waveform=wav,
+            n_fft=self.n_fft,
+            n_mel_channels=self.n_mel_channels,
+            target_sample_rate=self.target_sample_rate,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+        )
+
         return mel
 
 
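Editor's note: the practical difference between the two backends is how the mel features are computed. `get_vocos_mel_spectrogram` uses `torchaudio`'s centered STFT, while the BigVGAN path uses a librosa mel basis with reflect padding and `center=False`, matching the features the BigVGAN vocoder expects. A minimal sketch of selecting the backend, assuming a 24 kHz mono waveform tensor of shape (batch, samples):

```python
import torch
from f5_tts.model.modules import MelSpec

wav = torch.randn(1, 24_000 * 2)  # assumed: 2 s of 24 kHz mono audio

mel_vocos = MelSpec(extract_backend="vocos")(wav)      # torchaudio MelSpectrogram, center=True
mel_bigvgan = MelSpec(extract_backend="bigvgan")(wav)  # librosa mel basis, center=False

# Both return (batch, n_mel_channels, frames) log-mel features, but their values
# differ, so the model and the vocoder must be trained/run with the same backend.
print(mel_vocos.shape, mel_bigvgan.shape)
```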
    	
        src/f5_tts/model/trainer.py
    CHANGED
    
@@ -1,25 +1,22 @@
 from __future__ import annotations
 
-import os
 import gc
-from tqdm import tqdm
-import wandb
+import os
 
 import torch
 import torchaudio
-from torch.optim import AdamW
-from torch.utils.data import DataLoader, Dataset, SequentialSampler
-from torch.optim.lr_scheduler import LinearLR, SequentialLR
-
+import wandb
 from accelerate import Accelerator
 from accelerate.utils import DistributedDataParallelKwargs
-
 from ema_pytorch import EMA
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import LinearLR, SequentialLR
+from torch.utils.data import DataLoader, Dataset, SequentialSampler
+from tqdm import tqdm
 
 from f5_tts.model import CFM
-from f5_tts.model.utils import exists, default
 from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
-
+from f5_tts.model.utils import default, exists
 
 # trainer
 
@@ -49,6 +46,7 @@ class Trainer:
         accelerate_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
         bnb_optimizer: bool = False,
+        extract_backend: str = "vocos",  # "vocos" | "bigvgan"
     ):
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
 
@@ -110,6 +108,7 @@ class Trainer:
         self.max_samples = max_samples
         self.grad_accumulation_steps = grad_accumulation_steps
         self.max_grad_norm = max_grad_norm
+        self.vocoder_name = extract_backend
 
         self.noise_scheduler = noise_scheduler
 
@@ -188,9 +187,10 @@ class Trainer:
 
     def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
         if self.log_samples:
-            from f5_tts.infer.utils_infer import load_vocoder, nfe_step, cfg_strength, sway_sampling_coef
+            from f5_tts.infer.utils_infer import (cfg_strength, load_vocoder,
+                                                  nfe_step, sway_sampling_coef)
 
-            vocoder = load_vocoder()
+            vocoder = load_vocoder(vocoder_name=self.vocoder_name)
             target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.mel_stft.sample_rate
             log_samples_path = f"{self.checkpoint_path}/samples"
             os.makedirs(log_samples_path, exist_ok=True)
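Editor's note: the new `extract_backend` argument tells the trainer which vocoder to load (via `load_vocoder(vocoder_name=...)`) when logging audio samples, so it should match the backend used to extract the training mels. A minimal sketch, with the remaining Trainer arguments elided and the values below purely illustrative:

```python
from f5_tts.model import Trainer

# Assumption: `model` is a CFM built with mel settings for the "bigvgan" backend.
trainer = Trainer(
    model,
    epochs=11,                                    # illustrative, not from this diff
    learning_rate=7.5e-5,
    checkpoint_path="ckpts/F5TTS_Base_bigvgan",   # assumed path
    batch_size=38400,
    batch_size_type="frame",
    log_samples=True,
    extract_backend="bigvgan",  # keep consistent with the dataset's MelSpec
)
```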
    	
        src/f5_tts/train/train.py
    CHANGED
    
@@ -2,16 +2,18 @@
 
 from importlib.resources import files
 
-from f5_tts.model import CFM, UNetT, DiT, Trainer
-from f5_tts.model.utils import get_tokenizer
+from f5_tts.model import CFM, DiT, Trainer, UNetT
 from f5_tts.model.dataset import load_dataset
-
+from f5_tts.model.utils import get_tokenizer
 
 # -------------------------- Dataset Settings --------------------------- #
 
 target_sample_rate = 24000
 n_mel_channels = 100
 hop_length = 256
+win_length = 1024
+n_fft = 1024
+extract_backend = "bigvgan"  # 'vocos' or 'bigvgan'
 
 tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
 tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
@@ -56,9 +58,12 @@ def main():
     vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
 
     mel_spec_kwargs = dict(
-        target_sample_rate=target_sample_rate,
-        n_mel_channels=n_mel_channels,
+        n_fft=n_fft,
         hop_length=hop_length,
+        win_length=win_length,
+        n_mel_channels=n_mel_channels,
+        target_sample_rate=target_sample_rate,
+        extract_backend=extract_backend,
     )
 
     model = CFM(
@@ -84,6 +89,7 @@ def main():
         wandb_resume_id=wandb_resume_id,
         last_per_steps=last_per_steps,
         log_samples=True,
+        extract_backend=extract_backend,
     )
 
     train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
        src/third_party/BigVGAN
    ADDED
    
    | @@ -0,0 +1 @@ | |
|  | 
|  | |
| 1 | 
            +
            Subproject commit 7d2b454564a6c7d014227f635b7423881f14bdac
         | 
