Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- src/f5_tts/api.py +2 -2
 - src/f5_tts/infer/SHARED.md +3 -1
 - src/f5_tts/infer/utils_infer.py +23 -18
 
    	
        src/f5_tts/api.py
    CHANGED
    
    | 
         @@ -49,10 +49,10 @@ class F5TTS: 
     | 
|
| 49 | 
         
             
                    self.load_vocoder_model(vocoder_name, local_path=local_path)
         
     | 
| 50 | 
         
             
                    self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, local_path=local_path)
         
     | 
| 51 | 
         | 
| 52 | 
         
            -
                def load_vocoder_model(self, vocoder_name, local_path):
         
     | 
| 53 | 
         
             
                    self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
         
     | 
| 54 | 
         | 
| 55 | 
         
            -
                def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, local_path):
         
     | 
| 56 | 
         
             
                    if model_type == "F5-TTS":
         
     | 
| 57 | 
         
             
                        if not ckpt_file:
         
     | 
| 58 | 
         
             
                            if mel_spec_type == "vocos":
         
     | 
| 
         | 
|
| 49 | 
         
             
                    self.load_vocoder_model(vocoder_name, local_path=local_path)
         
     | 
| 50 | 
         
             
                    self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, local_path=local_path)
         
     | 
| 51 | 
         | 
| 52 | 
         
            +
                def load_vocoder_model(self, vocoder_name, local_path=None):
         
     | 
| 53 | 
         
             
                    self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
         
     | 
| 54 | 
         | 
| 55 | 
         
            +
                def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, local_path=None):
         
     | 
| 56 | 
         
             
                    if model_type == "F5-TTS":
         
     | 
| 57 | 
         
             
                        if not ckpt_file:
         
     | 
| 58 | 
         
             
                            if mel_spec_type == "vocos":
         
     | 
    	
        src/f5_tts/infer/SHARED.md
    CHANGED
    
    | 
         @@ -18,6 +18,8 @@ 
     | 
|
| 18 | 
         
             
            - [Multilingual](#multilingual)
         
     | 
| 19 | 
         
             
                - [F5-TTS Base @ pretrain @ zh \& en](#f5-tts-base--pretrain--zh--en)
         
     | 
| 20 | 
         
             
            - [Mandarin](#mandarin)
         
     | 
| 
         | 
|
| 
         | 
|
| 21 | 
         
             
            - [English](#english)
         
     | 
| 22 | 
         
             
            - [French](#french)
         
     | 
| 23 | 
         
             
                - [French LibriVox @ finetune @ fr](#french-librivox--finetune--fr)
         
     | 
| 
         @@ -67,6 +69,6 @@ MODEL_CKPT: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.p 
     | 
|
| 67 | 
         
             
            VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
         
     | 
| 68 | 
         
             
            ```
         
     | 
| 69 | 
         | 
| 70 | 
         
            -
            - Online Inference with  
     | 
| 71 | 
         
             
            - [Tutorial video to train a new language model](https://www.youtube.com/watch?v=UO4usaOojys).
         
     | 
| 72 | 
         
             
            - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
         
     | 
| 
         | 
|
| 18 | 
         
             
            - [Multilingual](#multilingual)
         
     | 
| 19 | 
         
             
                - [F5-TTS Base @ pretrain @ zh \& en](#f5-tts-base--pretrain--zh--en)
         
     | 
| 20 | 
         
             
            - [Mandarin](#mandarin)
         
     | 
| 21 | 
         
            +
            - [Japanese](#japanese)
         
     | 
| 22 | 
         
            +
                - [F5-TTS Base @ pretrain/finetune @ ja](#f5-tts-base--pretrainfinetune--ja)
         
     | 
| 23 | 
         
             
            - [English](#english)
         
     | 
| 24 | 
         
             
            - [French](#french)
         
     | 
| 25 | 
         
             
                - [French LibriVox @ finetune @ fr](#french-librivox--finetune--fr)
         
     | 
| 
         | 
|
| 69 | 
         
             
            VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
         
     | 
| 70 | 
         
             
            ```
         
     | 
| 71 | 
         | 
| 72 | 
         
            +
            - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
         
     | 
| 73 | 
         
             
            - [Tutorial video to train a new language model](https://www.youtube.com/watch?v=UO4usaOojys).
         
     | 
| 74 | 
         
             
            - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
         
     | 
    	
        src/f5_tts/infer/utils_infer.py
    CHANGED
    
    | 
         @@ -90,36 +90,41 @@ def chunk_text(text, max_chars=135): 
     | 
|
| 90 | 
         | 
| 91 | 
         | 
| 92 | 
         
             
            # load vocoder
         
     | 
| 93 | 
         
            -
            def load_vocoder(vocoder_name="vocos", is_local=False, local_path= 
     | 
| 94 | 
         
             
                if vocoder_name == "vocos":
         
     | 
| 95 | 
         
            -
                     
     | 
| 
         | 
|
| 96 | 
         
             
                        print(f"Load vocos from local path {local_path}")
         
     | 
| 97 | 
         
            -
                         
     | 
| 98 | 
         
            -
                         
     | 
| 99 | 
         
            -
                        config_path = hf_hub_download(
         
     | 
| 100 | 
         
            -
                            repo_id=repo_id, cache_dir=local_path, filename="config.yaml", revision=revision
         
     | 
| 101 | 
         
            -
                        )
         
     | 
| 102 | 
         
            -
                        model_path = hf_hub_download(
         
     | 
| 103 | 
         
            -
                            repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin", revision=revision
         
     | 
| 104 | 
         
            -
                        )
         
     | 
| 105 | 
         
            -
                        vocoder = Vocos.from_hparams(config_path=config_path)
         
     | 
| 106 | 
         
            -
                        state_dict = torch.load(model_path, map_location="cpu")
         
     | 
| 107 | 
         
            -
                        vocoder.load_state_dict(state_dict)
         
     | 
| 108 | 
         
            -
                        vocoder = vocoder.eval().to(device)
         
     | 
| 109 | 
         
             
                    else:
         
     | 
| 110 | 
         
             
                        print("Download Vocos from huggingface charactr/vocos-mel-24khz")
         
     | 
| 111 | 
         
            -
                         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 112 | 
         
             
                elif vocoder_name == "bigvgan":
         
     | 
| 113 | 
         
             
                    try:
         
     | 
| 114 | 
         
             
                        from third_party.BigVGAN import bigvgan
         
     | 
| 115 | 
         
             
                    except ImportError:
         
     | 
| 116 | 
         
             
                        print("You need to follow the README to init submodule and change the BigVGAN source code.")
         
     | 
| 117 | 
         
            -
                    if is_local:
         
     | 
| 118 | 
         
             
                        """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
         
     | 
| 119 | 
         
            -
                        local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=local_path)
         
     | 
| 120 | 
         
             
                        vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
         
     | 
| 121 | 
         
             
                    else:
         
     | 
| 122 | 
         
            -
                         
     | 
| 
         | 
|
| 123 | 
         | 
| 124 | 
         
             
                    vocoder.remove_weight_norm()
         
     | 
| 125 | 
         
             
                    vocoder = vocoder.eval().to(device)
         
     | 
| 
         | 
|
| 90 | 
         | 
| 91 | 
         | 
| 92 | 
         
             
            # load vocoder
         
     | 
| 93 | 
         
            +
            def load_vocoder(vocoder_name="vocos", is_local=False, local_path=None, device=device):
         
     | 
| 94 | 
         
             
                if vocoder_name == "vocos":
         
     | 
| 95 | 
         
            +
                    # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
         
     | 
| 96 | 
         
            +
                    if is_local and local_path is not None:
         
     | 
| 97 | 
         
             
                        print(f"Load vocos from local path {local_path}")
         
     | 
| 98 | 
         
            +
                        config_path = f"{local_path}/config.yaml"
         
     | 
| 99 | 
         
            +
                        model_path = f"{local_path}/pytorch_model.bin"
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 100 | 
         
             
                    else:
         
     | 
| 101 | 
         
             
                        print("Download Vocos from huggingface charactr/vocos-mel-24khz")
         
     | 
| 102 | 
         
            +
                        repo_id = "charactr/vocos-mel-24khz"
         
     | 
| 103 | 
         
            +
                        config_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="config.yaml")
         
     | 
| 104 | 
         
            +
                        model_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin")
         
     | 
| 105 | 
         
            +
                    vocoder = Vocos.from_hparams(config_path)
         
     | 
| 106 | 
         
            +
                    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
         
     | 
| 107 | 
         
            +
                    from vocos.feature_extractors import EncodecFeatures
         
     | 
| 108 | 
         
            +
             
     | 
| 109 | 
         
            +
                    if isinstance(vocoder.feature_extractor, EncodecFeatures):
         
     | 
| 110 | 
         
            +
                        encodec_parameters = {
         
     | 
| 111 | 
         
            +
                            "feature_extractor.encodec." + key: value
         
     | 
| 112 | 
         
            +
                            for key, value in vocoder.feature_extractor.encodec.state_dict().items()
         
     | 
| 113 | 
         
            +
                        }
         
     | 
| 114 | 
         
            +
                        state_dict.update(encodec_parameters)
         
     | 
| 115 | 
         
            +
                    vocoder.load_state_dict(state_dict)
         
     | 
| 116 | 
         
            +
                    vocoder = vocoder.eval().to(device)
         
     | 
| 117 | 
         
             
                elif vocoder_name == "bigvgan":
         
     | 
| 118 | 
         
             
                    try:
         
     | 
| 119 | 
         
             
                        from third_party.BigVGAN import bigvgan
         
     | 
| 120 | 
         
             
                    except ImportError:
         
     | 
| 121 | 
         
             
                        print("You need to follow the README to init submodule and change the BigVGAN source code.")
         
     | 
| 122 | 
         
            +
                    if is_local and local_path is not None:
         
     | 
| 123 | 
         
             
                        """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
         
     | 
| 
         | 
|
| 124 | 
         
             
                        vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
         
     | 
| 125 | 
         
             
                    else:
         
     | 
| 126 | 
         
            +
                        local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=local_path)
         
     | 
| 127 | 
         
            +
                        vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
         
     | 
| 128 | 
         | 
| 129 | 
         
             
                    vocoder.remove_weight_norm()
         
     | 
| 130 | 
         
             
                    vocoder = vocoder.eval().to(device)
         
     |