Sync from GitHub repo
Browse files
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- api.py +13 -11
- model/backbones/dit.py +1 -1
- model/backbones/unett.py +1 -1
- model/utils_infer.py +27 -23
    	
        api.py
    CHANGED
    
    | @@ -69,6 +69,10 @@ class F5TTS: | |
| 69 | 
             
                    ref_file,
         | 
| 70 | 
             
                    ref_text,
         | 
| 71 | 
             
                    gen_text,
         | 
|  | |
|  | |
|  | |
|  | |
| 72 | 
             
                    sway_sampling_coef=-1,
         | 
| 73 | 
             
                    cfg_strength=2,
         | 
| 74 | 
             
                    nfe_step=32,
         | 
| @@ -77,23 +81,21 @@ class F5TTS: | |
| 77 | 
             
                    remove_silence=False,
         | 
| 78 | 
             
                    file_wave=None,
         | 
| 79 | 
             
                    file_spect=None,
         | 
| 80 | 
            -
                    cross_fade_duration=0.15,
         | 
| 81 | 
            -
                    show_info=print,
         | 
| 82 | 
            -
                    progress=tqdm,
         | 
| 83 | 
             
                ):
         | 
| 84 | 
             
                    wav, sr, spect = infer_process(
         | 
| 85 | 
             
                        ref_file,
         | 
| 86 | 
             
                        ref_text,
         | 
| 87 | 
             
                        gen_text,
         | 
| 88 | 
             
                        self.ema_model,
         | 
| 89 | 
            -
                         | 
| 90 | 
            -
                         | 
| 91 | 
            -
                         | 
| 92 | 
            -
                         | 
| 93 | 
            -
                        nfe_step,
         | 
| 94 | 
            -
                        cfg_strength,
         | 
| 95 | 
            -
                        sway_sampling_coef,
         | 
| 96 | 
            -
                         | 
|  | |
| 97 | 
             
                    )
         | 
| 98 |  | 
| 99 | 
             
                    if file_wave is not None:
         | 
|  | |
| 69 | 
             
                    ref_file,
         | 
| 70 | 
             
                    ref_text,
         | 
| 71 | 
             
                    gen_text,
         | 
| 72 | 
            +
                    show_info=print,
         | 
| 73 | 
            +
                    progress=tqdm,
         | 
| 74 | 
            +
                    target_rms=0.1,
         | 
| 75 | 
            +
                    cross_fade_duration=0.15,
         | 
| 76 | 
             
                    sway_sampling_coef=-1,
         | 
| 77 | 
             
                    cfg_strength=2,
         | 
| 78 | 
             
                    nfe_step=32,
         | 
|  | |
| 81 | 
             
                    remove_silence=False,
         | 
| 82 | 
             
                    file_wave=None,
         | 
| 83 | 
             
                    file_spect=None,
         | 
|  | |
|  | |
|  | |
| 84 | 
             
                ):
         | 
| 85 | 
             
                    wav, sr, spect = infer_process(
         | 
| 86 | 
             
                        ref_file,
         | 
| 87 | 
             
                        ref_text,
         | 
| 88 | 
             
                        gen_text,
         | 
| 89 | 
             
                        self.ema_model,
         | 
| 90 | 
            +
                        show_info=show_info,
         | 
| 91 | 
            +
                        progress=progress,
         | 
| 92 | 
            +
                        target_rms=target_rms,
         | 
| 93 | 
            +
                        cross_fade_duration=cross_fade_duration,
         | 
| 94 | 
            +
                        nfe_step=nfe_step,
         | 
| 95 | 
            +
                        cfg_strength=cfg_strength,
         | 
| 96 | 
            +
                        sway_sampling_coef=sway_sampling_coef,
         | 
| 97 | 
            +
                        speed=speed,
         | 
| 98 | 
            +
                        fix_duration=fix_duration,
         | 
| 99 | 
             
                    )
         | 
| 100 |  | 
| 101 | 
             
                    if file_wave is not None:
         | 
    	
        model/backbones/dit.py
    CHANGED
    
    | @@ -45,9 +45,9 @@ class TextEmbedding(nn.Module): | |
| 45 | 
             
                        self.extra_modeling = False
         | 
| 46 |  | 
| 47 | 
             
                def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
         | 
| 48 | 
            -
                    batch, text_len = text.shape[0], text.shape[1]
         | 
| 49 | 
             
                    text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
         | 
| 50 | 
             
                    text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
         | 
|  | |
| 51 | 
             
                    text = F.pad(text, (0, seq_len - text_len), value=0)
         | 
| 52 |  | 
| 53 | 
             
                    if drop_text:  # cfg for text
         | 
|  | |
| 45 | 
             
                        self.extra_modeling = False
         | 
| 46 |  | 
| 47 | 
             
                def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
         | 
|  | |
| 48 | 
             
                    text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
         | 
| 49 | 
             
                    text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
         | 
| 50 | 
            +
                    batch, text_len = text.shape[0], text.shape[1]
         | 
| 51 | 
             
                    text = F.pad(text, (0, seq_len - text_len), value=0)
         | 
| 52 |  | 
| 53 | 
             
                    if drop_text:  # cfg for text
         | 
    	
        model/backbones/unett.py
    CHANGED
    
    | @@ -48,9 +48,9 @@ class TextEmbedding(nn.Module): | |
| 48 | 
             
                        self.extra_modeling = False
         | 
| 49 |  | 
| 50 | 
             
                def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
         | 
| 51 | 
            -
                    batch, text_len = text.shape[0], text.shape[1]
         | 
| 52 | 
             
                    text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
         | 
| 53 | 
             
                    text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
         | 
|  | |
| 54 | 
             
                    text = F.pad(text, (0, seq_len - text_len), value=0)
         | 
| 55 |  | 
| 56 | 
             
                    if drop_text:  # cfg for text
         | 
|  | |
| 48 | 
             
                        self.extra_modeling = False
         | 
| 49 |  | 
| 50 | 
             
                def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
         | 
|  | |
| 51 | 
             
                    text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
         | 
| 52 | 
             
                    text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
         | 
| 53 | 
            +
                    batch, text_len = text.shape[0], text.shape[1]
         | 
| 54 | 
             
                    text = F.pad(text, (0, seq_len - text_len), value=0)
         | 
| 55 |  | 
| 56 | 
             
                    if drop_text:  # cfg for text
         | 
    	
        model/utils_infer.py
    CHANGED
    
    | @@ -31,12 +31,13 @@ target_sample_rate = 24000 | |
| 31 | 
             
            n_mel_channels = 100
         | 
| 32 | 
             
            hop_length = 256
         | 
| 33 | 
             
            target_rms = 0.1
         | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
            #  | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
|  | |
| 40 |  | 
| 41 | 
             
            # -----------------------------------------
         | 
| 42 |  | 
| @@ -107,7 +108,7 @@ def initialize_asr_pipeline(device=device): | |
| 107 | 
             
            # load model for inference
         | 
| 108 |  | 
| 109 |  | 
| 110 | 
            -
            def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method= | 
| 111 | 
             
                if vocab_file == "":
         | 
| 112 | 
             
                    vocab_file = "Emilia_ZH_EN"
         | 
| 113 | 
             
                    tokenizer = "pinyin"
         | 
| @@ -192,14 +193,15 @@ def infer_process( | |
| 192 | 
             
                ref_text,
         | 
| 193 | 
             
                gen_text,
         | 
| 194 | 
             
                model_obj,
         | 
| 195 | 
            -
                cross_fade_duration=0.15,
         | 
| 196 | 
            -
                speed=1.0,
         | 
| 197 | 
             
                show_info=print,
         | 
| 198 | 
             
                progress=tqdm,
         | 
| 199 | 
            -
                 | 
| 200 | 
            -
                 | 
| 201 | 
            -
                 | 
| 202 | 
            -
                 | 
|  | |
|  | |
|  | |
| 203 | 
             
            ):
         | 
| 204 | 
             
                # Split the input text into batches
         | 
| 205 | 
             
                audio, sr = torchaudio.load(ref_audio)
         | 
| @@ -214,13 +216,14 @@ def infer_process( | |
| 214 | 
             
                    ref_text,
         | 
| 215 | 
             
                    gen_text_batches,
         | 
| 216 | 
             
                    model_obj,
         | 
| 217 | 
            -
                     | 
| 218 | 
            -
                     | 
| 219 | 
            -
                     | 
| 220 | 
            -
                    nfe_step,
         | 
| 221 | 
            -
                    cfg_strength,
         | 
| 222 | 
            -
                    sway_sampling_coef,
         | 
| 223 | 
            -
                     | 
|  | |
| 224 | 
             
                )
         | 
| 225 |  | 
| 226 |  | 
| @@ -232,12 +235,13 @@ def infer_batch_process( | |
| 232 | 
             
                ref_text,
         | 
| 233 | 
             
                gen_text_batches,
         | 
| 234 | 
             
                model_obj,
         | 
| 235 | 
            -
                cross_fade_duration=0.15,
         | 
| 236 | 
            -
                speed=1,
         | 
| 237 | 
             
                progress=tqdm,
         | 
|  | |
|  | |
| 238 | 
             
                nfe_step=32,
         | 
| 239 | 
             
                cfg_strength=2.0,
         | 
| 240 | 
             
                sway_sampling_coef=-1,
         | 
|  | |
| 241 | 
             
                fix_duration=None,
         | 
| 242 | 
             
            ):
         | 
| 243 | 
             
                audio, sr = ref_audio
         | 
| @@ -262,11 +266,11 @@ def infer_batch_process( | |
| 262 | 
             
                    text_list = [ref_text + gen_text]
         | 
| 263 | 
             
                    final_text_list = convert_char_to_pinyin(text_list)
         | 
| 264 |  | 
|  | |
| 265 | 
             
                    if fix_duration is not None:
         | 
| 266 | 
             
                        duration = int(fix_duration * target_sample_rate / hop_length)
         | 
| 267 | 
             
                    else:
         | 
| 268 | 
             
                        # Calculate duration
         | 
| 269 | 
            -
                        ref_audio_len = audio.shape[-1] // hop_length
         | 
| 270 | 
             
                        ref_text_len = len(ref_text.encode("utf-8"))
         | 
| 271 | 
             
                        gen_text_len = len(gen_text.encode("utf-8"))
         | 
| 272 | 
             
                        duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
         | 
|  | |
| 31 | 
             
            n_mel_channels = 100
         | 
| 32 | 
             
            hop_length = 256
         | 
| 33 | 
             
            target_rms = 0.1
         | 
| 34 | 
            +
            cross_fade_duration = 0.15
         | 
| 35 | 
            +
            ode_method = "euler"
         | 
| 36 | 
            +
            nfe_step = 32  # 16, 32
         | 
| 37 | 
            +
            cfg_strength = 2.0
         | 
| 38 | 
            +
            sway_sampling_coef = -1.0
         | 
| 39 | 
            +
            speed = 1.0
         | 
| 40 | 
            +
            fix_duration = None
         | 
| 41 |  | 
| 42 | 
             
            # -----------------------------------------
         | 
| 43 |  | 
|  | |
| 108 | 
             
            # load model for inference
         | 
| 109 |  | 
| 110 |  | 
| 111 | 
            +
            def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
         | 
| 112 | 
             
                if vocab_file == "":
         | 
| 113 | 
             
                    vocab_file = "Emilia_ZH_EN"
         | 
| 114 | 
             
                    tokenizer = "pinyin"
         | 
|  | |
| 193 | 
             
                ref_text,
         | 
| 194 | 
             
                gen_text,
         | 
| 195 | 
             
                model_obj,
         | 
|  | |
|  | |
| 196 | 
             
                show_info=print,
         | 
| 197 | 
             
                progress=tqdm,
         | 
| 198 | 
            +
                target_rms=target_rms,
         | 
| 199 | 
            +
                cross_fade_duration=cross_fade_duration,
         | 
| 200 | 
            +
                nfe_step=nfe_step,
         | 
| 201 | 
            +
                cfg_strength=cfg_strength,
         | 
| 202 | 
            +
                sway_sampling_coef=sway_sampling_coef,
         | 
| 203 | 
            +
                speed=speed,
         | 
| 204 | 
            +
                fix_duration=fix_duration,
         | 
| 205 | 
             
            ):
         | 
| 206 | 
             
                # Split the input text into batches
         | 
| 207 | 
             
                audio, sr = torchaudio.load(ref_audio)
         | 
|  | |
| 216 | 
             
                    ref_text,
         | 
| 217 | 
             
                    gen_text_batches,
         | 
| 218 | 
             
                    model_obj,
         | 
| 219 | 
            +
                    progress=progress,
         | 
| 220 | 
            +
                    target_rms=target_rms,
         | 
| 221 | 
            +
                    cross_fade_duration=cross_fade_duration,
         | 
| 222 | 
            +
                    nfe_step=nfe_step,
         | 
| 223 | 
            +
                    cfg_strength=cfg_strength,
         | 
| 224 | 
            +
                    sway_sampling_coef=sway_sampling_coef,
         | 
| 225 | 
            +
                    speed=speed,
         | 
| 226 | 
            +
                    fix_duration=fix_duration,
         | 
| 227 | 
             
                )
         | 
| 228 |  | 
| 229 |  | 
|  | |
| 235 | 
             
                ref_text,
         | 
| 236 | 
             
                gen_text_batches,
         | 
| 237 | 
             
                model_obj,
         | 
|  | |
|  | |
| 238 | 
             
                progress=tqdm,
         | 
| 239 | 
            +
                target_rms=0.1,
         | 
| 240 | 
            +
                cross_fade_duration=0.15,
         | 
| 241 | 
             
                nfe_step=32,
         | 
| 242 | 
             
                cfg_strength=2.0,
         | 
| 243 | 
             
                sway_sampling_coef=-1,
         | 
| 244 | 
            +
                speed=1,
         | 
| 245 | 
             
                fix_duration=None,
         | 
| 246 | 
             
            ):
         | 
| 247 | 
             
                audio, sr = ref_audio
         | 
|  | |
| 266 | 
             
                    text_list = [ref_text + gen_text]
         | 
| 267 | 
             
                    final_text_list = convert_char_to_pinyin(text_list)
         | 
| 268 |  | 
| 269 | 
            +
                    ref_audio_len = audio.shape[-1] // hop_length
         | 
| 270 | 
             
                    if fix_duration is not None:
         | 
| 271 | 
             
                        duration = int(fix_duration * target_sample_rate / hop_length)
         | 
| 272 | 
             
                    else:
         | 
| 273 | 
             
                        # Calculate duration
         | 
|  | |
| 274 | 
             
                        ref_text_len = len(ref_text.encode("utf-8"))
         | 
| 275 | 
             
                        gen_text_len = len(gen_text.encode("utf-8"))
         | 
| 276 | 
             
                        duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
         | 
