Yushen CHEN rsxdalv Adam Kessel Roberts Slisans committed on
Commit dadf554 · 1 Parent(s): f4772fe

convert to pkg, reorganize repo (#228)


* group files in f5_tts directory

* add setup.py

* use global imports

* simplify demo

* add install directions for library mode

* fix old huggingface_hub version constraint

* move finetune to package

* change imports to f5_tts.model

* bump version

* fix bad merge

* Update inference-cli.py

* fix HF space

* reformat

* fix utils.py vocab.txt import

* fix format

* adapt README for f5_tts package structure

* simplify app.py

* add gradio.Dockerfile and workflow

* refactored for pyproject.toml

* refactored for pyproject.toml

* added in reference to packaged files

* use fork for testing docker image

* added in reference to packaged files

* minor tweaks

* fixed inference-cli.toml path

* fixed inference-cli.toml path

* fixed inference-cli.toml path

* fixed inference-cli.toml path

* refactor eval_infer_batch.py

* fix typo

* added eval_infer_batch to scripts

---------

Co-authored-by: Roberts Slisans <[email protected]>
Co-authored-by: Adam Kessel <[email protected]>
Co-authored-by: Roberts Slisans <[email protected]>

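In practice, the bullets "use global imports" and "change imports to f5_tts.model" boil down to the migration below (a minimal before/after sketch, assuming the package is installed, e.g. with `pip install git+https://github.com/SWivid/F5-TTS.git`):

```python
# Before: cwd-relative imports that only worked from the repo root
# from model import CFM, UNetT, DiT
# from model.utils import get_tokenizer

# After: absolute imports resolved against the installed f5_tts package
from f5_tts.model import CFM, UNetT, DiT
from f5_tts.model.utils import get_tokenizer
```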
Files changed (38)
  1. .github/workflows/publish-docker-image.yaml +61 -0
  2. README.md +42 -6
  3. app.py +3 -0
  4. gradio.Dockerfile +27 -0
  5. model/__init__.py +0 -10
  6. pyproject.toml +52 -0
  7. scripts/eval_infer_batch.py +0 -198
  8. api.py → src/f5_tts/api.py +4 -4
  9. {data → src/f5_tts/data}/Emilia_ZH_EN_pinyin/vocab.txt +0 -0
  10. inference-cli.toml → src/f5_tts/data/inference-cli.toml +0 -0
  11. {data → src/f5_tts/data}/librispeech_pc_test_clean_cross_sentence.lst +0 -0
  12. finetune-cli.py → src/f5_tts/finetune_cli.py +3 -3
  13. finetune_gradio.py → src/f5_tts/finetune_gradio.py +2 -2
  14. gradio_app.py → src/f5_tts/gradio_app.py +3 -3
  15. inference-cli.py → src/f5_tts/inference_cli.py +10 -5
  16. src/f5_tts/model/__init__.py +10 -0
  17. {model → src/f5_tts/model}/backbones/README.md +0 -0
  18. {model → src/f5_tts/model}/backbones/dit.py +1 -1
  19. {model → src/f5_tts/model}/backbones/mmdit.py +1 -1
  20. {model → src/f5_tts/model}/backbones/unett.py +1 -1
  21. {model → src/f5_tts/model}/cfm.py +2 -2
  22. {model → src/f5_tts/model}/dataset.py +2 -2
  23. {model → src/f5_tts/model}/ecapa_tdnn.py +0 -0
  24. {model → src/f5_tts/model}/modules.py +0 -0
  25. {model → src/f5_tts/model}/trainer.py +3 -3
  26. {model → src/f5_tts/model}/utils.py +5 -3
  27. {model → src/f5_tts/model}/utils_infer.py +2 -2
  28. {scripts → src/f5_tts/scripts}/count_max_epoch.py +0 -0
  29. {scripts → src/f5_tts/scripts}/count_params_gflops.py +1 -1
  30. src/f5_tts/scripts/eval_infer_batch.py +204 -0
  31. {scripts → src/f5_tts/scripts}/eval_infer_batch.sh +0 -0
  32. {scripts → src/f5_tts/scripts}/eval_librispeech_test_clean.py +1 -1
  33. {scripts → src/f5_tts/scripts}/eval_seedtts_testset.py +1 -1
  34. {scripts → src/f5_tts/scripts}/prepare_csv_wavs.py +1 -1
  35. {scripts → src/f5_tts/scripts}/prepare_emilia.py +1 -1
  36. {scripts → src/f5_tts/scripts}/prepare_wenetspeech4tts.py +1 -1
  37. speech_edit.py → src/f5_tts/speech_edit.py +2 -2
  38. train.py → src/f5_tts/train.py +3 -3
.github/workflows/publish-docker-image.yaml ADDED
@@ -0,0 +1,61 @@
+ name: Create and publish a Docker image
+
+ # Configures this workflow to run every time a change is pushed to the branch called `release`.
+ on:
+   push:
+     branches: ['main']
+
+ # Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds.
+ env:
+   REGISTRY: ghcr.io
+   IMAGE_NAME: ${{ github.repository }}
+
+ # There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu.
+ jobs:
+   build-and-push-image:
+     runs-on: ubuntu-latest
+     # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
+     permissions:
+       contents: read
+       packages: write
+     #
+     steps:
+       - name: Checkout repository
+         uses: actions/checkout@v4
+       - name: Free Up GitHub Actions Ubuntu Runner Disk Space 🔧
+         uses: jlumbroso/free-disk-space@main
+         with:
+           # This might remove tools that are actually needed, if set to "true" but frees about 6 GB
+           tool-cache: false
+
+           # All of these default to true, but feel free to set to "false" if necessary for your workflow
+           android: true
+           dotnet: true
+           haskell: true
+           large-packages: false
+           swap-storage: false
+           docker-images: false
+       # Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
+       - name: Log in to the Container registry
+         uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
+         with:
+           registry: ${{ env.REGISTRY }}
+           username: ${{ github.actor }}
+           password: ${{ secrets.GITHUB_TOKEN }}
+       # This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels.
+       - name: Extract metadata (tags, labels) for Docker
+         id: meta
+         uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
+         with:
+           images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
+       # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages.
+       # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository.
+       # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step.
+       - name: Build and push Docker image
+         uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
+         with:
+           context: .
+           file: ./gradio.Dockerfile
+           push: true
+           tags: ${{ steps.meta.outputs.tags }}
+           labels: ${{ steps.meta.outputs.labels }}
README.md CHANGED
@@ -63,11 +63,35 @@ pre-commit run --all-files
  Note: Some model components have linting exceptions for E722 to accommodate tensor notation
 
 
+ ### As a pip package
+
+ ```bash
+ pip install git+https://github.com/SWivid/F5-TTS.git
+ ```
+
+ ```python
+ import gradio as gr
+ from f5_tts.gradio_app import app
+
+ with gr.Blocks() as main_app:
+     gr.Markdown("# This is an example of using F5-TTS within a bigger Gradio app")
+
+     # ... other Gradio components
+
+     app.render()
+
+ main_app.launch()
+
+ ```
+
  ## Prepare Dataset
 
- Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `model/dataset.py`.
+ Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `f5_tts/model/dataset.py`.
 
  ```bash
+ # switch to the main directory
+ cd f5_tts
+
  # prepare custom dataset up to your need
  # download corresponding dataset first, and fill in the path in scripts
 
@@ -83,6 +107,9 @@ python scripts/prepare_wenetspeech4tts.py
  Once your datasets are prepared, you can start the training process.
 
  ```bash
+ # switch to the main directory
+ cd f5_tts
+
  # setup accelerate config, e.g. use multi-gpu ddp, fp16
  # will be to: ~/.cache/huggingface/accelerate/default_config.yaml
  accelerate config
@@ -90,7 +117,7 @@ accelerate launch train.py
  ```
  An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57).
 
- Gradio UI finetuning with `finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
+ Gradio UI finetuning with `f5_tts/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
 
  ### Wandb Logging
 
@@ -136,6 +163,9 @@ for change model use `--ckpt_file` to specify the model you want to load,
  for change vocab.txt use `--vocab_file` to provide your vocab.txt file.
 
  ```bash
+ # switch to the main directory
+ cd f5_tts
+
  python inference-cli.py \
  --model "F5-TTS" \
  --ref_audio "tests/ref_audio/test_en_1_ref_short.wav" \
@@ -161,19 +191,19 @@ Currently supported features:
  You can launch a Gradio app (web interface) to launch a GUI for inference (will load ckpt from Huggingface, you may also use local file in `gradio_app.py`). Currently load ASR model, F5-TTS and E2 TTS all in once, thus use more GPU memory than `inference-cli`.
 
  ```bash
- python gradio_app.py
+ python f5_tts/gradio_app.py
  ```
 
  You can specify the port/host:
 
  ```bash
- python gradio_app.py --port 7860 --host 0.0.0.0
+ python f5_tts/gradio_app.py --port 7860 --host 0.0.0.0
  ```
 
  Or launch a share link:
 
  ```bash
- python gradio_app.py --share
+ python f5_tts/gradio_app.py --share
  ```
 
  ### Speech Editing
@@ -181,7 +211,7 @@ python gradio_app.py --share
  To test speech editing capabilities, use the following command.
 
  ```bash
- python speech_edit.py
+ python f5_tts/speech_edit.py
  ```
 
  ## Evaluation
@@ -199,6 +229,9 @@ python speech_edit.py
  To run batch inference for evaluations, execute the following commands:
 
  ```bash
+ # switch to the main directory
+ cd f5_tts
+
  # batch inference for evaluations
  accelerate config # if not set before
  bash scripts/eval_infer_batch.sh
@@ -234,6 +267,9 @@ pip install faster-whisper==0.10.1
 
  Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
  ```bash
+ # switch to the main directory
+ cd f5_tts
+
  # Evaluation for Seed-TTS test set
  python scripts/eval_seedtts_testset.py
 
app.py ADDED
@@ -0,0 +1,3 @@
+ from f5_tts.gradio_app import app
+
+ app.queue().launch()
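Since `app` is a plain Gradio Blocks object, the launch call can take gradio's standard options; a minimal variant of app.py (the host/port values here are illustrative, not part of the commit):

```python
from f5_tts.gradio_app import app

# server_name, server_port and share are standard gradio launch() parameters
app.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)
```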
gradio.Dockerfile ADDED
@@ -0,0 +1,27 @@
+ FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel
+
+ USER root
+
+ ARG DEBIAN_FRONTEND=noninteractive
+
+ LABEL github_repo="https://github.com/rsxdalv/F5-TTS"
+
+ RUN set -x \
+     && apt-get update \
+     && apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \
+     && apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \
+     && rm -rf /var/lib/apt/lists/* \
+     && apt-get clean
+
+ WORKDIR /workspace
+
+ RUN git clone https://github.com/rsxdalv/F5-TTS.git \
+     && cd F5-TTS \
+     && pip install --no-cache-dir -r requirements.txt
+
+ ENV SHELL=/bin/bash
+
+ WORKDIR /workspace/F5-TTS/f5_tts
+
+ EXPOSE 7860
+ CMD python gradio_app.py
model/__init__.py DELETED
@@ -1,10 +0,0 @@
- from model.cfm import CFM
-
- from model.backbones.unett import UNetT
- from model.backbones.dit import DiT
- from model.backbones.mmdit import MMDiT
-
- from model.trainer import Trainer
-
-
- __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
pyproject.toml ADDED
@@ -0,0 +1,52 @@
+ [build-system]
+ requires = ["setuptools >= 61.0", "setuptools-scm>=8.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "f5-tts"
+ dynamic = ["version"]
+ description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
+ readme = "README.md"
+ classifiers = [
+     "License :: OSI Approved :: MIT License",
+     "Operating System :: OS Independent",
+     "Programming Language :: Python :: 3",
+ ]
+ dependencies = [
+     "accelerate>=0.33.0",
+     "cached_path @ git+https://github.com/rsxdalv/cached_path@main",
+     "click",
+     "datasets",
+     "einops>=0.8.0",
+     "einx>=0.3.0",
+     "ema_pytorch>=0.5.2",
+     "gradio",
+     "jieba",
+     "librosa",
+     "matplotlib",
+     "numpy<=1.26.4",
+     "pydub",
+     "pypinyin",
+     "safetensors",
+     "soundfile",
+     "tomli",
+     "torch>=2.0.0",
+     "torchaudio>=2.0.0",
+     "torchdiffeq",
+     "tqdm>=4.65.0",
+     "transformers",
+     "vocos",
+     "wandb",
+     "x_transformers>=1.31.14",
+ ]
+
+ [[project.authors]]
+ name = "Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen"
+
+ [project.urls]
+ Homepage = "https://github.com/SWivid/F5-TTS"
+
+ [project.scripts]
+ "finetune-cli" = "f5_tts.finetune_cli:main"
+ "inference-cli" = "f5_tts.inference_cli:main"
+ "eval_infer_batch" = "f5_tts.scripts.eval_infer_batch:main"
scripts/eval_infer_batch.py DELETED
@@ -1,198 +0,0 @@
- import sys
- import os
-
- sys.path.append(os.getcwd())
-
- import time
- import random
- from tqdm import tqdm
- import argparse
-
- import torch
- import torchaudio
- from accelerate import Accelerator
- from vocos import Vocos
-
- from model import CFM, UNetT, DiT
- from model.utils import (
-     load_checkpoint,
-     get_tokenizer,
-     get_seedtts_testset_metainfo,
-     get_librispeech_test_clean_metainfo,
-     get_inference_prompt,
- )
-
- accelerator = Accelerator()
- device = f"cuda:{accelerator.process_index}"
-
-
- # --------------------- Dataset Settings -------------------- #
-
- target_sample_rate = 24000
- n_mel_channels = 100
- hop_length = 256
- target_rms = 0.1
-
- tokenizer = "pinyin"
-
-
- # ---------------------- infer setting ---------------------- #
-
- parser = argparse.ArgumentParser(description="batch inference")
-
- parser.add_argument("-s", "--seed", default=None, type=int)
- parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
- parser.add_argument("-n", "--expname", required=True)
- parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
-
- parser.add_argument("-nfe", "--nfestep", default=32, type=int)
- parser.add_argument("-o", "--odemethod", default="euler")
- parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
-
- parser.add_argument("-t", "--testset", required=True)
-
- args = parser.parse_args()
-
-
- seed = args.seed
- dataset_name = args.dataset
- exp_name = args.expname
- ckpt_step = args.ckptstep
- ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
-
- nfe_step = args.nfestep
- ode_method = args.odemethod
- sway_sampling_coef = args.swaysampling
-
- testset = args.testset
-
-
- infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
- cfg_strength = 2.0
- speed = 1.0
- use_truth_duration = False
- no_ref_audio = False
-
-
- if exp_name == "F5TTS_Base":
-     model_cls = DiT
-     model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-
- elif exp_name == "E2TTS_Base":
-     model_cls = UNetT
-     model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-
-
- if testset == "ls_pc_test_clean":
-     metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
-     librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
-     metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
-
- elif testset == "seedtts_test_zh":
-     metalst = "data/seedtts_testset/zh/meta.lst"
-     metainfo = get_seedtts_testset_metainfo(metalst)
-
- elif testset == "seedtts_test_en":
-     metalst = "data/seedtts_testset/en/meta.lst"
-     metainfo = get_seedtts_testset_metainfo(metalst)
-
-
- # path to save genereted wavs
- if seed is None:
-     seed = random.randint(-10000, 10000)
- output_dir = (
-     f"results/{exp_name}_{ckpt_step}/{testset}/"
-     f"seed{seed}_{ode_method}_nfe{nfe_step}"
-     f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
-     f"_cfg{cfg_strength}_speed{speed}"
-     f"{'_gt-dur' if use_truth_duration else ''}"
-     f"{'_no-ref-audio' if no_ref_audio else ''}"
- )
-
-
- # -------------------------------------------------#
-
- use_ema = True
-
- prompts_all = get_inference_prompt(
-     metainfo,
-     speed=speed,
-     tokenizer=tokenizer,
-     target_sample_rate=target_sample_rate,
-     n_mel_channels=n_mel_channels,
-     hop_length=hop_length,
-     target_rms=target_rms,
-     use_truth_duration=use_truth_duration,
-     infer_batch_size=infer_batch_size,
- )
-
- # Vocoder model
- local = False
- if local:
-     vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
-     vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
-     state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
-     vocos.load_state_dict(state_dict)
-     vocos.eval()
- else:
-     vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-
- # Tokenizer
- vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
-
- # Model
- model = CFM(
-     transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
-     mel_spec_kwargs=dict(
-         target_sample_rate=target_sample_rate,
-         n_mel_channels=n_mel_channels,
-         hop_length=hop_length,
-     ),
-     odeint_kwargs=dict(
-         method=ode_method,
-     ),
-     vocab_char_map=vocab_char_map,
- ).to(device)
-
- model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
-
- if not os.path.exists(output_dir) and accelerator.is_main_process:
-     os.makedirs(output_dir)
-
- # start batch inference
- accelerator.wait_for_everyone()
- start = time.time()
-
- with accelerator.split_between_processes(prompts_all) as prompts:
-     for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
-         utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
-         ref_mels = ref_mels.to(device)
-         ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
-         total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
-
-         # Inference
-         with torch.inference_mode():
-             generated, _ = model.sample(
-                 cond=ref_mels,
-                 text=final_text_list,
-                 duration=total_mel_lens,
-                 lens=ref_mel_lens,
-                 steps=nfe_step,
-                 cfg_strength=cfg_strength,
-                 sway_sampling_coef=sway_sampling_coef,
-                 no_ref_audio=no_ref_audio,
-                 seed=seed,
-             )
-         # Final result
-         for i, gen in enumerate(generated):
-             gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
-             gen_mel_spec = gen.permute(0, 2, 1)
-             generated_wave = vocos.decode(gen_mel_spec.cpu())
-             if ref_rms_list[i] < target_rms:
-                 generated_wave = generated_wave * ref_rms_list[i] / target_rms
-             torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
-
- accelerator.wait_for_everyone()
- if accelerator.is_main_process:
-     timediff = time.time() - start
-     print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
api.py → src/f5_tts/api.py RENAMED
@@ -3,11 +3,11 @@ import torch
  import tqdm
  from cached_path import cached_path
 
- from model import DiT, UNetT
- from model.utils import save_spectrogram
+ from f5_tts.model import DiT, UNetT
+ from f5_tts.model.utils import save_spectrogram
 
- from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
- from model.utils import seed_everything
+ from f5_tts.model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
+ from f5_tts.model.utils import seed_everything
  import random
  import sys
 
{data → src/f5_tts/data}/Emilia_ZH_EN_pinyin/vocab.txt RENAMED
File without changes
inference-cli.toml → src/f5_tts/data/inference-cli.toml RENAMED
File without changes
{data → src/f5_tts/data}/librispeech_pc_test_clean_cross_sentence.lst RENAMED
File without changes
finetune-cli.py → src/f5_tts/finetune_cli.py RENAMED
@@ -1,7 +1,7 @@
  import argparse
- from model import CFM, UNetT, DiT, Trainer
- from model.utils import get_tokenizer
- from model.dataset import load_dataset
+ from f5_tts.model import CFM, UNetT, DiT, Trainer
+ from f5_tts.model.utils import get_tokenizer
+ from f5_tts.model.dataset import load_dataset
  from cached_path import cached_path
  import shutil
  import os
finetune_gradio.py → src/f5_tts/finetune_gradio.py RENAMED
@@ -17,14 +17,14 @@ import shutil
  import time
 
  import json
- from model.utils import convert_char_to_pinyin
+ from f5_tts.model.utils import convert_char_to_pinyin
  import signal
  import psutil
  import platform
  import subprocess
  from datasets.arrow_writer import ArrowWriter
  from datasets import Dataset as Dataset_
- from api import F5TTS
+ from f5_tts.api import F5TTS
 
 
  training_process = None
gradio_app.py → src/f5_tts/gradio_app.py RENAMED
@@ -27,11 +27,11 @@ def gpu_decorator(func):
      return func
 
 
- from model import DiT, UNetT
- from model.utils import (
+ from f5_tts.model import DiT, UNetT
+ from f5_tts.model.utils import (
      save_spectrogram,
  )
- from model.utils_infer import (
+ from f5_tts.model.utils_infer import (
      load_vocoder,
      load_model,
      preprocess_ref_audio_text,
inference-cli.py → src/f5_tts/inference_cli.py RENAMED
@@ -1,15 +1,17 @@
  import argparse
  import codecs
  import re
+ import os
  from pathlib import Path
+ from importlib.resources import files
 
  import numpy as np
  import soundfile as sf
  import tomli
  from cached_path import cached_path
 
- from model import DiT, UNetT
- from model.utils_infer import (
+ from f5_tts.model import DiT, UNetT
+ from f5_tts.model.utils_infer import (
      load_vocoder,
      load_model,
      preprocess_ref_audio_text,
@@ -26,8 +28,8 @@ parser = argparse.ArgumentParser(
  parser.add_argument(
      "-c",
      "--config",
-     help="Configuration file. Default=cli-config.toml",
-     default="inference-cli.toml",
+     help="Configuration file. Default=inference-cli.toml",
+     default=os.path.join(files('f5_tts').joinpath('data'), 'inference-cli.toml')
  )
  parser.add_argument(
      "-m",
@@ -166,5 +168,8 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
          remove_silence_for_generated_wav(f.name)
          print(f.name)
 
-
- main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence)
+ def main():
+     main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence)
+
+ if __name__ == "__main__":
+     main()
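The new `--config` default resolves the packaged `inference-cli.toml` via `importlib.resources`, so the CLI no longer depends on being run from the repo root. A minimal check of what the default evaluates to (assuming an installed f5_tts):

```python
import os
from importlib.resources import files

# Same expression the CLI uses for its --config default; yields a path
# inside the installed package's data/ directory
default_config = os.path.join(files("f5_tts").joinpath("data"), "inference-cli.toml")
print(default_config)
```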
src/f5_tts/model/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from f5_tts.model.cfm import CFM
+
+ from f5_tts.model.backbones.unett import UNetT
+ from f5_tts.model.backbones.dit import DiT
+ from f5_tts.model.backbones.mmdit import MMDiT
+
+ from f5_tts.model.trainer import Trainer
+
+
+ __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
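With the package-level re-exports in place, downstream code can pull the public classes from any working directory once f5_tts is installed; a minimal smoke check:

```python
# Works from anywhere after installation, unlike the old cwd-relative `from model import ...`
from f5_tts.model import CFM, DiT, MMDiT, Trainer, UNetT

print([cls.__name__ for cls in (CFM, UNetT, DiT, MMDiT, Trainer)])
```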
{model → src/f5_tts/model}/backbones/README.md RENAMED
File without changes
{model → src/f5_tts/model}/backbones/dit.py RENAMED
@@ -15,7 +15,7 @@ import torch.nn.functional as F
 
  from x_transformers.x_transformers import RotaryEmbedding
 
- from model.modules import (
+ from f5_tts.model.modules import (
      TimestepEmbedding,
      ConvNeXtV2Block,
      ConvPositionEmbedding,
{model → src/f5_tts/model}/backbones/mmdit.py RENAMED
@@ -14,7 +14,7 @@ from torch import nn
 
  from x_transformers.x_transformers import RotaryEmbedding
 
- from model.modules import (
+ from f5_tts.model.modules import (
      TimestepEmbedding,
      ConvPositionEmbedding,
      MMDiTBlock,
{model → src/f5_tts/model}/backbones/unett.py RENAMED
@@ -17,7 +17,7 @@ import torch.nn.functional as F
  from x_transformers import RMSNorm
  from x_transformers.x_transformers import RotaryEmbedding
 
- from model.modules import (
+ from f5_tts.model.modules import (
      TimestepEmbedding,
      ConvNeXtV2Block,
      ConvPositionEmbedding,
{model → src/f5_tts/model}/cfm.py RENAMED
@@ -18,8 +18,8 @@ from torch.nn.utils.rnn import pad_sequence
 
  from torchdiffeq import odeint
 
- from model.modules import MelSpec
- from model.utils import (
+ from f5_tts.model.modules import MelSpec
+ from f5_tts.model.utils import (
      default,
      exists,
      list_str_to_idx,
{model → src/f5_tts/model}/dataset.py RENAMED
@@ -10,8 +10,8 @@ from datasets import load_from_disk
  from datasets import Dataset as Dataset_
  from torch import nn
 
- from model.modules import MelSpec
- from model.utils import default
+ from f5_tts.model.modules import MelSpec
+ from f5_tts.model.utils import default
 
 
  class HFDataset(Dataset):
{model → src/f5_tts/model}/ecapa_tdnn.py RENAMED
File without changes
{model → src/f5_tts/model}/modules.py RENAMED
File without changes
{model → src/f5_tts/model}/trainer.py RENAMED
@@ -15,9 +15,9 @@ from accelerate.utils import DistributedDataParallelKwargs
 
  from ema_pytorch import EMA
 
- from model import CFM
- from model.utils import exists, default
- from model.dataset import DynamicBatchSampler, collate_fn
+ from f5_tts.model import CFM
+ from f5_tts.model.utils import exists, default
+ from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
 
 
  # trainer
{model → src/f5_tts/model}/utils.py RENAMED
@@ -4,6 +4,7 @@ import os
  import math
  import random
  import string
+ from importlib.resources import files
  from tqdm import tqdm
  from collections import defaultdict
 
@@ -20,8 +21,8 @@ import torchaudio
  import jieba
  from pypinyin import lazy_pinyin, Style
 
- from model.ecapa_tdnn import ECAPA_TDNN_SMALL
- from model.modules import MelSpec
+ from f5_tts.model.ecapa_tdnn import ECAPA_TDNN_SMALL
+ from f5_tts.model.modules import MelSpec
 
 
  # seed everything
@@ -121,7 +122,8 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
      - if use "byte", set to 256 (unicode byte range)
      """
      if tokenizer in ["pinyin", "char"]:
-         with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
+         tokenizer_path = os.path.join(files('f5_tts').joinpath('data'), f"{dataset_name}_{tokenizer}/vocab.txt")
+         with open(tokenizer_path, "r", encoding="utf-8") as f:
              vocab_char_map = {}
              for i, char in enumerate(f):
                  vocab_char_map[char[:-1]] = i
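Anchoring the vocab lookup to the package's `data/` directory makes `get_tokenizer` location-independent; a minimal usage sketch, assuming the Emilia_ZH_EN pinyin vocab shipped in `src/f5_tts/data`:

```python
from f5_tts.model.utils import get_tokenizer

# Reads Emilia_ZH_EN_pinyin/vocab.txt from the installed package, not the cwd
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
print(vocab_size)
```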
{model → src/f5_tts/model}/utils_infer.py RENAMED
@@ -12,8 +12,8 @@ from pydub import AudioSegment, silence
  from transformers import pipeline
  from vocos import Vocos
 
- from model import CFM
- from model.utils import (
+ from f5_tts.model import CFM
+ from f5_tts.model.utils import (
      load_checkpoint,
      get_tokenizer,
      convert_char_to_pinyin,
{scripts → src/f5_tts/scripts}/count_max_epoch.py RENAMED
File without changes
{scripts → src/f5_tts/scripts}/count_params_gflops.py RENAMED
@@ -3,7 +3,7 @@ import os
 
  sys.path.append(os.getcwd())
 
- from model import M2_TTS, DiT
+ from f5_tts.model import M2_TTS, DiT
 
  import torch
  import thop
src/f5_tts/scripts/eval_infer_batch.py ADDED
@@ -0,0 +1,204 @@
+ import sys
+ import os
+
+ sys.path.append(os.getcwd())
+
+ import time
+ import random
+ from tqdm import tqdm
+ import argparse
+ from importlib.resources import files
+
+ import torch
+ import torchaudio
+ from accelerate import Accelerator
+ from vocos import Vocos
+
+ from f5_tts.model import CFM, UNetT, DiT
+ from f5_tts.model.utils import (
+     load_checkpoint,
+     get_tokenizer,
+     get_seedtts_testset_metainfo,
+     get_librispeech_test_clean_metainfo,
+     get_inference_prompt,
+ )
+
+ accelerator = Accelerator()
+ device = f"cuda:{accelerator.process_index}"
+
+
+ # --------------------- Dataset Settings -------------------- #
+
+ target_sample_rate = 24000
+ n_mel_channels = 100
+ hop_length = 256
+ target_rms = 0.1
+
+ tokenizer = "pinyin"
+
+ def main():
+     # ---------------------- infer setting ---------------------- #
+
+     parser = argparse.ArgumentParser(description="batch inference")
+
+     parser.add_argument("-s", "--seed", default=None, type=int)
+     parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
+     parser.add_argument("-n", "--expname", required=True)
+     parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
+
+     parser.add_argument("-nfe", "--nfestep", default=32, type=int)
+     parser.add_argument("-o", "--odemethod", default="euler")
+     parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
+
+     parser.add_argument("-t", "--testset", required=True)
+
+     args = parser.parse_args()
+
+
+     seed = args.seed
+     dataset_name = args.dataset
+     exp_name = args.expname
+     ckpt_step = args.ckptstep
+     ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
+
+     nfe_step = args.nfestep
+     ode_method = args.odemethod
+     sway_sampling_coef = args.swaysampling
+
+     testset = args.testset
+
+
+     infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
+     cfg_strength = 2.0
+     speed = 1.0
+     use_truth_duration = False
+     no_ref_audio = False
+
+
+     if exp_name == "F5TTS_Base":
+         model_cls = DiT
+         model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+
+     elif exp_name == "E2TTS_Base":
+         model_cls = UNetT
+         model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+
+
+     datapath = files('f5_tts').joinpath('data')
+
+     if testset == "ls_pc_test_clean":
+         metalst = os.path.join(datapath, "librispeech_pc_test_clean_cross_sentence.lst")
+         librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
+         metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
+
+     elif testset == "seedtts_test_zh":
+         metalst = os.path.join(datapath, "seedtts_testset/zh/meta.lst")
+         metainfo = get_seedtts_testset_metainfo(metalst)
+
+     elif testset == "seedtts_test_en":
+         metalst = os.path.join(datapath, "seedtts_testset/en/meta.lst")
+         metainfo = get_seedtts_testset_metainfo(metalst)
+
+
+     # path to save genereted wavs
+     if seed is None:
+         seed = random.randint(-10000, 10000)
+     output_dir = (
+         f"results/{exp_name}_{ckpt_step}/{testset}/"
+         f"seed{seed}_{ode_method}_nfe{nfe_step}"
+         f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
+         f"_cfg{cfg_strength}_speed{speed}"
+         f"{'_gt-dur' if use_truth_duration else ''}"
+         f"{'_no-ref-audio' if no_ref_audio else ''}"
+     )
+
+
+     # -------------------------------------------------#
+
+     use_ema = True
+
+     prompts_all = get_inference_prompt(
+         metainfo,
+         speed=speed,
+         tokenizer=tokenizer,
+         target_sample_rate=target_sample_rate,
+         n_mel_channels=n_mel_channels,
+         hop_length=hop_length,
+         target_rms=target_rms,
+         use_truth_duration=use_truth_duration,
+         infer_batch_size=infer_batch_size,
+     )
+
+     # Vocoder model
+     local = False
+     if local:
+         vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
+         vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
+         state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
+         vocos.load_state_dict(state_dict)
+         vocos.eval()
+     else:
+         vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+
+     # Tokenizer
+     vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
+
+     # Model
+     model = CFM(
+         transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
+         mel_spec_kwargs=dict(
+             target_sample_rate=target_sample_rate,
+             n_mel_channels=n_mel_channels,
+             hop_length=hop_length,
+         ),
+         odeint_kwargs=dict(
+             method=ode_method,
+         ),
+         vocab_char_map=vocab_char_map,
+     ).to(device)
+
+     model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
+
+     if not os.path.exists(output_dir) and accelerator.is_main_process:
+         os.makedirs(output_dir)
+
+     # start batch inference
+     accelerator.wait_for_everyone()
+     start = time.time()
+
+     with accelerator.split_between_processes(prompts_all) as prompts:
+         for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
+             utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
+             ref_mels = ref_mels.to(device)
+             ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
+             total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
+
+             # Inference
+             with torch.inference_mode():
+                 generated, _ = model.sample(
+                     cond=ref_mels,
+                     text=final_text_list,
+                     duration=total_mel_lens,
+                     lens=ref_mel_lens,
+                     steps=nfe_step,
+                     cfg_strength=cfg_strength,
+                     sway_sampling_coef=sway_sampling_coef,
+                     no_ref_audio=no_ref_audio,
+                     seed=seed,
+                 )
+             # Final result
+             for i, gen in enumerate(generated):
+                 gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
+                 gen_mel_spec = gen.permute(0, 2, 1)
+                 generated_wave = vocos.decode(gen_mel_spec.cpu())
+                 if ref_rms_list[i] < target_rms:
+                     generated_wave = generated_wave * ref_rms_list[i] / target_rms
+                 torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
+
+     accelerator.wait_for_everyone()
+     if accelerator.is_main_process:
+         timediff = time.time() - start
+         print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
+
+ if __name__ == "__main__":
+     main()
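Wrapping the script body in `main()` lets it run both as the `eval_infer_batch` console command and programmatically; a sketch of the latter (flags taken from the argparse setup above; the referenced checkpoint and test set must exist locally):

```python
import sys
from f5_tts.scripts.eval_infer_batch import main

# main() parses sys.argv itself, so set the required flags before calling
sys.argv = ["eval_infer_batch", "-n", "F5TTS_Base", "-t", "seedtts_test_en"]
main()
```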
{scripts → src/f5_tts/scripts}/eval_infer_batch.sh RENAMED
File without changes
{scripts → src/f5_tts/scripts}/eval_librispeech_test_clean.py RENAMED
@@ -8,7 +8,7 @@ sys.path.append(os.getcwd())
  import multiprocessing as mp
  import numpy as np
 
- from model.utils import (
+ from f5_tts.model.utils import (
      get_librispeech_test,
      run_asr_wer,
      run_sim,
{scripts → src/f5_tts/scripts}/eval_seedtts_testset.py RENAMED
@@ -8,7 +8,7 @@ sys.path.append(os.getcwd())
  import multiprocessing as mp
  import numpy as np
 
- from model.utils import (
+ from f5_tts.model.utils import (
      get_seed_tts_test,
      run_asr_wer,
      run_sim,
{scripts → src/f5_tts/scripts}/prepare_csv_wavs.py RENAMED
@@ -13,7 +13,7 @@ import torchaudio
  from tqdm import tqdm
  from datasets.arrow_writer import ArrowWriter
 
- from model.utils import (
+ from f5_tts.model.utils import (
      convert_char_to_pinyin,
  )
 
{scripts → src/f5_tts/scripts}/prepare_emilia.py RENAMED
@@ -16,7 +16,7 @@ from concurrent.futures import ProcessPoolExecutor
 
  from datasets.arrow_writer import ArrowWriter
 
- from model.utils import (
+ from f5_tts.model.utils import (
      repetition_found,
      convert_char_to_pinyin,
  )
{scripts → src/f5_tts/scripts}/prepare_wenetspeech4tts.py RENAMED
@@ -13,7 +13,7 @@ from concurrent.futures import ProcessPoolExecutor
  import torchaudio
  from datasets import Dataset
 
- from model.utils import convert_char_to_pinyin
+ from f5_tts.model.utils import convert_char_to_pinyin
 
 
  def deal_with_sub_path_files(dataset_path, sub_path):
speech_edit.py → src/f5_tts/speech_edit.py RENAMED
@@ -5,8 +5,8 @@ import torch.nn.functional as F
  import torchaudio
  from vocos import Vocos
 
- from model import CFM, UNetT, DiT
- from model.utils import (
+ from f5_tts.model import CFM, UNetT, DiT
+ from f5_tts.model.utils import (
      load_checkpoint,
      get_tokenizer,
      convert_char_to_pinyin,
train.py → src/f5_tts/train.py RENAMED
@@ -1,6 +1,6 @@
- from model import CFM, UNetT, DiT, Trainer
- from model.utils import get_tokenizer
- from model.dataset import load_dataset
+ from f5_tts.model import CFM, UNetT, DiT, Trainer
+ from f5_tts.model.utils import get_tokenizer
+ from f5_tts.model.dataset import load_dataset
 
 
  # -------------------------- Dataset Settings --------------------------- #