convert to pkg, reorganize repo (#228)
* group files in f5_tts directory
* add setup.py
* use global imports
* simplify demo
* add install directions for library mode
* fix old huggingface_hub version constraint
* move finetune to package
* change imports to f5_tts.model
* bump version
* fix bad merge
* Update inference-cli.py
* fix HF space
* reformat
* fix utils.py vocab.txt import
* fix format
* adapt README for f5_tts package structure
* simplify app.py
* add gradio.Dockerfile and workflow
* refactored for pyproject.toml
* refactored for pyproject.toml
* added in reference to packaged files
* use fork for testing docker image
* added in reference to packaged files
* minor tweaks
* fixed inference-cli.toml path
* fixed inference-cli.toml path
* fixed inference-cli.toml path
* fixed inference-cli.toml path
* refactor eval_infer_batch.py
* fix typo
* added eval_infer_batch to scripts
---------
Co-authored-by: Roberts Slisans <[email protected]>
Co-authored-by: Adam Kessel <[email protected]>
Co-authored-by: Roberts Slisans <[email protected]>
- .github/workflows/publish-docker-image.yaml +61 -0
- README.md +42 -6
- app.py +3 -0
- gradio.Dockerfile +27 -0
- model/__init__.py +0 -10
- pyproject.toml +52 -0
- scripts/eval_infer_batch.py +0 -198
- api.py → src/f5_tts/api.py +4 -4
- {data → src/f5_tts/data}/Emilia_ZH_EN_pinyin/vocab.txt +0 -0
- inference-cli.toml → src/f5_tts/data/inference-cli.toml +0 -0
- {data → src/f5_tts/data}/librispeech_pc_test_clean_cross_sentence.lst +0 -0
- finetune-cli.py → src/f5_tts/finetune_cli.py +3 -3
- finetune_gradio.py → src/f5_tts/finetune_gradio.py +2 -2
- gradio_app.py → src/f5_tts/gradio_app.py +3 -3
- inference-cli.py → src/f5_tts/inference_cli.py +10 -5
- src/f5_tts/model/__init__.py +10 -0
- {model → src/f5_tts/model}/backbones/README.md +0 -0
- {model → src/f5_tts/model}/backbones/dit.py +1 -1
- {model → src/f5_tts/model}/backbones/mmdit.py +1 -1
- {model → src/f5_tts/model}/backbones/unett.py +1 -1
- {model → src/f5_tts/model}/cfm.py +2 -2
- {model → src/f5_tts/model}/dataset.py +2 -2
- {model → src/f5_tts/model}/ecapa_tdnn.py +0 -0
- {model → src/f5_tts/model}/modules.py +0 -0
- {model → src/f5_tts/model}/trainer.py +3 -3
- {model → src/f5_tts/model}/utils.py +5 -3
- {model → src/f5_tts/model}/utils_infer.py +2 -2
- {scripts → src/f5_tts/scripts}/count_max_epoch.py +0 -0
- {scripts → src/f5_tts/scripts}/count_params_gflops.py +1 -1
- src/f5_tts/scripts/eval_infer_batch.py +204 -0
- {scripts → src/f5_tts/scripts}/eval_infer_batch.sh +0 -0
- {scripts → src/f5_tts/scripts}/eval_librispeech_test_clean.py +1 -1
- {scripts → src/f5_tts/scripts}/eval_seedtts_testset.py +1 -1
- {scripts → src/f5_tts/scripts}/prepare_csv_wavs.py +1 -1
- {scripts → src/f5_tts/scripts}/prepare_emilia.py +1 -1
- {scripts → src/f5_tts/scripts}/prepare_wenetspeech4tts.py +1 -1
- speech_edit.py → src/f5_tts/speech_edit.py +2 -2
- train.py → src/f5_tts/train.py +3 -3
.github/workflows/publish-docker-image.yaml (new file)
@@ -0,0 +1,61 @@
+name: Create and publish a Docker image
+
+# Configures this workflow to run every time a change is pushed to the branch called `release`.
+on:
+  push:
+    branches: ['main']
+
+# Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds.
+env:
+  REGISTRY: ghcr.io
+  IMAGE_NAME: ${{ github.repository }}
+
+# There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu.
+jobs:
+  build-and-push-image:
+    runs-on: ubuntu-latest
+    # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
+    permissions:
+      contents: read
+      packages: write
+      #
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+      - name: Free Up GitHub Actions Ubuntu Runner Disk Space
+        uses: jlumbroso/free-disk-space@main
+        with:
+          # This might remove tools that are actually needed, if set to "true" but frees about 6 GB
+          tool-cache: false
+
+          # All of these default to true, but feel free to set to "false" if necessary for your workflow
+          android: true
+          dotnet: true
+          haskell: true
+          large-packages: false
+          swap-storage: false
+          docker-images: false
+      # Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
+      - name: Log in to the Container registry
+        uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
+        with:
+          registry: ${{ env.REGISTRY }}
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+      # This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels.
+      - name: Extract metadata (tags, labels) for Docker
+        id: meta
+        uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
+        with:
+          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
+      # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages.
+      # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository.
+      # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step.
+      - name: Build and push Docker image
+        uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
+        with:
+          context: .
+          file: ./gradio.Dockerfile
+          push: true
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
README.md
@@ -63,11 +63,35 @@ pre-commit run --all-files
 Note: Some model components have linting exceptions for E722 to accommodate tensor notation
 
 
+### As a pip package
+
+```bash
+pip install git+https://github.com/SWivid/F5-TTS.git
+```
+
+```python
+import gradio as gr
+from f5_tts.gradio_app import app
+
+with gr.Blocks() as main_app:
+    gr.Markdown("# This is an example of using F5-TTS within a bigger Gradio app")
+
+    # ... other Gradio components
+
+    app.render()
+
+main_app.launch()
+
+```
+
 ## Prepare Dataset
 
-Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `model/dataset.py`.
+Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `f5_tts/model/dataset.py`.
 
 ```bash
+# switch to the main directory
+cd f5_tts
+
 # prepare custom dataset up to your need
 # download corresponding dataset first, and fill in the path in scripts
 
@@ -83,6 +107,9 @@ python scripts/prepare_wenetspeech4tts.py
 Once your datasets are prepared, you can start the training process.
 
 ```bash
+# switch to the main directory
+cd f5_tts
+
 # setup accelerate config, e.g. use multi-gpu ddp, fp16
 # will be to: ~/.cache/huggingface/accelerate/default_config.yaml
 accelerate config
@@ -90,7 +117,7 @@ accelerate launch train.py
 ```
 An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57).
 
-Gradio UI finetuning with `finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
+Gradio UI finetuning with `f5_tts/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
 
 ### Wandb Logging
 
@@ -136,6 +163,9 @@ for change model use `--ckpt_file` to specify the model you want to load,
 for change vocab.txt use `--vocab_file` to provide your vocab.txt file.
 
 ```bash
+# switch to the main directory
+cd f5_tts
+
 python inference-cli.py \
 --model "F5-TTS" \
 --ref_audio "tests/ref_audio/test_en_1_ref_short.wav" \
@@ -161,19 +191,19 @@ Currently supported features:
 You can launch a Gradio app (web interface) to launch a GUI for inference (will load ckpt from Huggingface, you may also use local file in `gradio_app.py`). Currently load ASR model, F5-TTS and E2 TTS all in once, thus use more GPU memory than `inference-cli`.
 
 ```bash
-python gradio_app.py
+python f5_tts/gradio_app.py
 ```
 
 You can specify the port/host:
 
 ```bash
-python gradio_app.py --port 7860 --host 0.0.0.0
+python f5_tts/gradio_app.py --port 7860 --host 0.0.0.0
 ```
 
 Or launch a share link:
 
 ```bash
-python gradio_app.py --share
+python f5_tts/gradio_app.py --share
 ```
 
 ### Speech Editing
@@ -181,7 +211,7 @@ python gradio_app.py --share
 To test speech editing capabilities, use the following command.
 
 ```bash
-python speech_edit.py
+python f5_tts/speech_edit.py
 ```
 
 ## Evaluation
@@ -199,6 +229,9 @@ python speech_edit.py
 To run batch inference for evaluations, execute the following commands:
 
 ```bash
+# switch to the main directory
+cd f5_tts
+
 # batch inference for evaluations
 accelerate config # if not set before
 bash scripts/eval_infer_batch.sh
@@ -234,6 +267,9 @@ pip install faster-whisper==0.10.1
 
 Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
 ```bash
+# switch to the main directory
+cd f5_tts
+
 # Evaluation for Seed-TTS test set
 python scripts/eval_seedtts_testset.py
 
app.py (new file)
@@ -0,0 +1,3 @@
+from f5_tts.gradio_app import app
+
+app.queue().launch()
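The new `app.py` is now just a thin wrapper around the packaged Gradio UI. A minimal sketch of customizing the launch follows; it only uses standard `gradio` `Blocks.launch()` options, nothing added by this PR:

```python
# Sketch: launch the packaged Gradio app with explicit server options.
# server_name, server_port, and share are standard gradio launch() arguments.
from f5_tts.gradio_app import app

app.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)
```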
gradio.Dockerfile (new file)
@@ -0,0 +1,27 @@
+FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel
+
+USER root
+
+ARG DEBIAN_FRONTEND=noninteractive
+
+LABEL github_repo="https://github.com/rsxdalv/F5-TTS"
+
+RUN set -x \
+    && apt-get update \
+    && apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \
+    && apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \
+    && rm -rf /var/lib/apt/lists/* \
+    && apt-get clean
+
+WORKDIR /workspace
+
+RUN git clone https://github.com/rsxdalv/F5-TTS.git \
+    && cd F5-TTS \
+    && pip install --no-cache-dir -r requirements.txt
+
+ENV SHELL=/bin/bash
+
+WORKDIR /workspace/F5-TTS/f5_tts
+
+EXPOSE 7860
+CMD python gradio_app.py
model/__init__.py (deleted)
@@ -1,10 +0,0 @@
-from model.cfm import CFM
-
-from model.backbones.unett import UNetT
-from model.backbones.dit import DiT
-from model.backbones.mmdit import MMDiT
-
-from model.trainer import Trainer
-
-
-__all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
pyproject.toml (new file)
@@ -0,0 +1,52 @@
+[build-system]
+requires = ["setuptools >= 61.0", "setuptools-scm>=8.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "f5-tts"
+dynamic = ["version"]
+description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
+readme = "README.md"
+classifiers = [
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent",
+    "Programming Language :: Python :: 3",
+]
+dependencies = [
+    "accelerate>=0.33.0",
+    "cached_path @ git+https://github.com/rsxdalv/cached_path@main",
+    "click",
+    "datasets",
+    "einops>=0.8.0",
+    "einx>=0.3.0",
+    "ema_pytorch>=0.5.2",
+    "gradio",
+    "jieba",
+    "librosa",
+    "matplotlib",
+    "numpy<=1.26.4",
+    "pydub",
+    "pypinyin",
+    "safetensors",
+    "soundfile",
+    "tomli",
+    "torch>=2.0.0",
+    "torchaudio>=2.0.0",
+    "torchdiffeq",
+    "tqdm>=4.65.0",
+    "transformers",
+    "vocos",
+    "wandb",
+    "x_transformers>=1.31.14",
+]
+
+[[project.authors]]
+name = "Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen"
+
+[project.urls]
+Homepage = "https://github.com/SWivid/F5-TTS"
+
+[project.scripts]
+"finetune-cli" = "f5_tts.finetune_cli:main"
+"inference-cli" = "f5_tts.inference_cli:main"
+"eval_infer_batch" = "f5_tts.scripts.eval_infer_batch:main"
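The `[project.scripts]` table is what turns the repo scripts into console commands on install: each entry maps a command name to a `module:function` callable. A rough sketch of what the generated wrapper does, assuming the package has been installed; the entry-point spec string is taken from the table above:

```python
# Sketch of how a console-script entry point resolves: the installed
# `inference-cli` command effectively imports f5_tts.inference_cli and calls main().
from importlib import import_module

def run_entry_point(spec: str) -> None:
    module_name, func_name = spec.split(":")
    getattr(import_module(module_name), func_name)()

# Equivalent to running `inference-cli` from a shell (requires `pip install .` first):
# run_entry_point("f5_tts.inference_cli:main")
```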
scripts/eval_infer_batch.py (deleted; re-added under src/f5_tts/scripts/ below)
@@ -1,198 +0,0 @@
-import sys
-import os
-
-sys.path.append(os.getcwd())
-
-import time
-import random
-from tqdm import tqdm
-import argparse
-
-import torch
-import torchaudio
-from accelerate import Accelerator
-from vocos import Vocos
-
-from model import CFM, UNetT, DiT
-from model.utils import (
-    load_checkpoint,
-    get_tokenizer,
-    get_seedtts_testset_metainfo,
-    get_librispeech_test_clean_metainfo,
-    get_inference_prompt,
-)
-
-accelerator = Accelerator()
-device = f"cuda:{accelerator.process_index}"
-
-
-# --------------------- Dataset Settings -------------------- #
-
-target_sample_rate = 24000
-n_mel_channels = 100
-hop_length = 256
-target_rms = 0.1
-
-tokenizer = "pinyin"
-
-
-# ---------------------- infer setting ---------------------- #
-
-parser = argparse.ArgumentParser(description="batch inference")
-
-parser.add_argument("-s", "--seed", default=None, type=int)
-parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
-parser.add_argument("-n", "--expname", required=True)
-parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
-
-parser.add_argument("-nfe", "--nfestep", default=32, type=int)
-parser.add_argument("-o", "--odemethod", default="euler")
-parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
-
-parser.add_argument("-t", "--testset", required=True)
-
-args = parser.parse_args()
-
-
-seed = args.seed
-dataset_name = args.dataset
-exp_name = args.expname
-ckpt_step = args.ckptstep
-ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
-
-nfe_step = args.nfestep
-ode_method = args.odemethod
-sway_sampling_coef = args.swaysampling
-
-testset = args.testset
-
-
-infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
-cfg_strength = 2.0
-speed = 1.0
-use_truth_duration = False
-no_ref_audio = False
-
-
-if exp_name == "F5TTS_Base":
-    model_cls = DiT
-    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-
-elif exp_name == "E2TTS_Base":
-    model_cls = UNetT
-    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-
-
-if testset == "ls_pc_test_clean":
-    metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
-    librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
-    metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
-
-elif testset == "seedtts_test_zh":
-    metalst = "data/seedtts_testset/zh/meta.lst"
-    metainfo = get_seedtts_testset_metainfo(metalst)
-
-elif testset == "seedtts_test_en":
-    metalst = "data/seedtts_testset/en/meta.lst"
-    metainfo = get_seedtts_testset_metainfo(metalst)
-
-
-# path to save genereted wavs
-if seed is None:
-    seed = random.randint(-10000, 10000)
-output_dir = (
-    f"results/{exp_name}_{ckpt_step}/{testset}/"
-    f"seed{seed}_{ode_method}_nfe{nfe_step}"
-    f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
-    f"_cfg{cfg_strength}_speed{speed}"
-    f"{'_gt-dur' if use_truth_duration else ''}"
-    f"{'_no-ref-audio' if no_ref_audio else ''}"
-)
-
-
-# -------------------------------------------------#
-
-use_ema = True
-
-prompts_all = get_inference_prompt(
-    metainfo,
-    speed=speed,
-    tokenizer=tokenizer,
-    target_sample_rate=target_sample_rate,
-    n_mel_channels=n_mel_channels,
-    hop_length=hop_length,
-    target_rms=target_rms,
-    use_truth_duration=use_truth_duration,
-    infer_batch_size=infer_batch_size,
-)
-
-# Vocoder model
-local = False
-if local:
-    vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
-    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
-    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
-    vocos.load_state_dict(state_dict)
-    vocos.eval()
-else:
-    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-
-# Tokenizer
-vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
-
-# Model
-model = CFM(
-    transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
-    mel_spec_kwargs=dict(
-        target_sample_rate=target_sample_rate,
-        n_mel_channels=n_mel_channels,
-        hop_length=hop_length,
-    ),
-    odeint_kwargs=dict(
-        method=ode_method,
-    ),
-    vocab_char_map=vocab_char_map,
-).to(device)
-
-model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
-
-if not os.path.exists(output_dir) and accelerator.is_main_process:
-    os.makedirs(output_dir)
-
-# start batch inference
-accelerator.wait_for_everyone()
-start = time.time()
-
-with accelerator.split_between_processes(prompts_all) as prompts:
-    for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
-        utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
-        ref_mels = ref_mels.to(device)
-        ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
-        total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
-
-        # Inference
-        with torch.inference_mode():
-            generated, _ = model.sample(
-                cond=ref_mels,
-                text=final_text_list,
-                duration=total_mel_lens,
-                lens=ref_mel_lens,
-                steps=nfe_step,
-                cfg_strength=cfg_strength,
-                sway_sampling_coef=sway_sampling_coef,
-                no_ref_audio=no_ref_audio,
-                seed=seed,
-            )
-        # Final result
-        for i, gen in enumerate(generated):
-            gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
-            gen_mel_spec = gen.permute(0, 2, 1)
-            generated_wave = vocos.decode(gen_mel_spec.cpu())
-            if ref_rms_list[i] < target_rms:
-                generated_wave = generated_wave * ref_rms_list[i] / target_rms
-            torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
-
-accelerator.wait_for_everyone()
-if accelerator.is_main_process:
-    timediff = time.time() - start
-    print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
api.py → src/f5_tts/api.py
@@ -3,11 +3,11 @@ import torch
 import tqdm
 from cached_path import cached_path
 
-from model import DiT, UNetT
-from model.utils import save_spectrogram
+from f5_tts.model import DiT, UNetT
+from f5_tts.model.utils import save_spectrogram
 
-from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
-from model.utils import seed_everything
+from f5_tts.model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
+from f5_tts.model.utils import seed_everything
 import random
 import sys
 
data/Emilia_ZH_EN_pinyin/vocab.txt → src/f5_tts/data/Emilia_ZH_EN_pinyin/vocab.txt: file renamed without changes
inference-cli.toml → src/f5_tts/data/inference-cli.toml: file renamed without changes
data/librispeech_pc_test_clean_cross_sentence.lst → src/f5_tts/data/librispeech_pc_test_clean_cross_sentence.lst: file renamed without changes
finetune-cli.py → src/f5_tts/finetune_cli.py
@@ -1,7 +1,7 @@
 import argparse
-from model import CFM, UNetT, DiT, Trainer
-from model.utils import get_tokenizer
-from model.dataset import load_dataset
+from f5_tts.model import CFM, UNetT, DiT, Trainer
+from f5_tts.model.utils import get_tokenizer
+from f5_tts.model.dataset import load_dataset
 from cached_path import cached_path
 import shutil
 import os
finetune_gradio.py → src/f5_tts/finetune_gradio.py
@@ -17,14 +17,14 @@ import shutil
 import time
 
 import json
-from model.utils import convert_char_to_pinyin
+from f5_tts.model.utils import convert_char_to_pinyin
 import signal
 import psutil
 import platform
 import subprocess
 from datasets.arrow_writer import ArrowWriter
 from datasets import Dataset as Dataset_
-from api import F5TTS
+from f5_tts.api import F5TTS
 
 
 training_process = None
gradio_app.py → src/f5_tts/gradio_app.py
@@ -27,11 +27,11 @@ def gpu_decorator(func):
     return func
 
 
-from model import DiT, UNetT
-from model.utils import (
+from f5_tts.model import DiT, UNetT
+from f5_tts.model.utils import (
     save_spectrogram,
 )
-from model.utils_infer import (
+from f5_tts.model.utils_infer import (
     load_vocoder,
     load_model,
     preprocess_ref_audio_text,
inference-cli.py → src/f5_tts/inference_cli.py
@@ -1,15 +1,17 @@
 import argparse
 import codecs
 import re
+import os
 from pathlib import Path
+from importlib.resources import files
 
 import numpy as np
 import soundfile as sf
 import tomli
 from cached_path import cached_path
 
-from model import DiT, UNetT
-from model.utils_infer import (
+from f5_tts.model import DiT, UNetT
+from f5_tts.model.utils_infer import (
     load_vocoder,
     load_model,
     preprocess_ref_audio_text,
@@ -26,8 +28,8 @@ parser = argparse.ArgumentParser(
 parser.add_argument(
     "-c",
     "--config",
-    help="Configuration file. Default=cli
-    default=
+    help="Configuration file. Default=inference-cli.toml",
+    default=os.path.join(files('f5_tts').joinpath('data'), 'inference-cli.toml')
 )
 parser.add_argument(
     "-m",
@@ -166,5 +168,8 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
         remove_silence_for_generated_wav(f.name)
     print(f.name)
 
-
-
+
+def main():
+    main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence)
+
+if __name__ == "__main__":
+    main()
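The key change here is that the default config no longer assumes a repo-relative path: `importlib.resources.files('f5_tts')` resolves `data/inference-cli.toml` from wherever the package is installed. A small sketch of the same lookup pattern, assuming the package and its bundled data are installed:

```python
# Sketch: resolve and read a data file that ships inside the installed f5_tts package.
import os
from importlib.resources import files

import tomli

config_path = os.path.join(files("f5_tts").joinpath("data"), "inference-cli.toml")
with open(config_path, "rb") as f:   # tomli expects a binary file handle
    config = tomli.load(f)

print(config_path, list(config))
```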
src/f5_tts/model/__init__.py (new file)
@@ -0,0 +1,10 @@
+from f5_tts.model.cfm import CFM
+
+from f5_tts.model.backbones.unett import UNetT
+from f5_tts.model.backbones.dit import DiT
+from f5_tts.model.backbones.mmdit import MMDiT
+
+from f5_tts.model.trainer import Trainer
+
+
+__all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
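With the package-qualified `__init__.py` in place, the model classes are importable from anywhere without `sys.path` tricks. A minimal sketch, reusing the F5TTS_Base configuration that appears in `eval_infer_batch.py` in this PR; the `text_num_embeds` and `mel_dim` values below are placeholders (normally the vocab size and mel channel count):

```python
# Sketch: package-level imports replace the old top-level `model` package.
from f5_tts.model import CFM, DiT, UNetT  # previously: from model import CFM, DiT, UNetT

# DiT config matching the F5TTS_Base setup used elsewhere in this PR.
backbone = DiT(
    dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4,
    text_num_embeds=256, mel_dim=100,  # placeholder values for illustration
)
print(type(backbone).__name__)
```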
model/backbones/README.md → src/f5_tts/model/backbones/README.md: file renamed without changes
model/backbones/dit.py → src/f5_tts/model/backbones/dit.py
@@ -15,7 +15,7 @@ import torch.nn.functional as F
 
 from x_transformers.x_transformers import RotaryEmbedding
 
-from model.modules import (
+from f5_tts.model.modules import (
     TimestepEmbedding,
     ConvNeXtV2Block,
     ConvPositionEmbedding,
model/backbones/mmdit.py → src/f5_tts/model/backbones/mmdit.py
@@ -14,7 +14,7 @@ from torch import nn
 
 from x_transformers.x_transformers import RotaryEmbedding
 
-from model.modules import (
+from f5_tts.model.modules import (
     TimestepEmbedding,
     ConvPositionEmbedding,
     MMDiTBlock,
model/backbones/unett.py → src/f5_tts/model/backbones/unett.py
@@ -17,7 +17,7 @@ import torch.nn.functional as F
 from x_transformers import RMSNorm
 from x_transformers.x_transformers import RotaryEmbedding
 
-from model.modules import (
+from f5_tts.model.modules import (
     TimestepEmbedding,
     ConvNeXtV2Block,
     ConvPositionEmbedding,
model/cfm.py → src/f5_tts/model/cfm.py
@@ -18,8 +18,8 @@ from torch.nn.utils.rnn import pad_sequence
 
 from torchdiffeq import odeint
 
-from model.modules import MelSpec
-from model.utils import (
+from f5_tts.model.modules import MelSpec
+from f5_tts.model.utils import (
     default,
     exists,
     list_str_to_idx,
model/dataset.py → src/f5_tts/model/dataset.py
@@ -10,8 +10,8 @@ from datasets import load_from_disk
 from datasets import Dataset as Dataset_
 from torch import nn
 
-from model.modules import MelSpec
-from model.utils import default
+from f5_tts.model.modules import MelSpec
+from f5_tts.model.utils import default
 
 
 class HFDataset(Dataset):
model/ecapa_tdnn.py → src/f5_tts/model/ecapa_tdnn.py: file renamed without changes
model/modules.py → src/f5_tts/model/modules.py: file renamed without changes
model/trainer.py → src/f5_tts/model/trainer.py
@@ -15,9 +15,9 @@ from accelerate.utils import DistributedDataParallelKwargs
 
 from ema_pytorch import EMA
 
-from model import CFM
-from model.utils import exists, default
-from model.dataset import DynamicBatchSampler, collate_fn
+from f5_tts.model import CFM
+from f5_tts.model.utils import exists, default
+from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
 
 
 # trainer
model/utils.py → src/f5_tts/model/utils.py
@@ -4,6 +4,7 @@ import os
 import math
 import random
 import string
+from importlib.resources import files
 from tqdm import tqdm
 from collections import defaultdict
 
@@ -20,8 +21,8 @@ import torchaudio
 import jieba
 from pypinyin import lazy_pinyin, Style
 
-from model.ecapa_tdnn import ECAPA_TDNN_SMALL
-from model.modules import MelSpec
+from f5_tts.model.ecapa_tdnn import ECAPA_TDNN_SMALL
+from f5_tts.model.modules import MelSpec
 
 
 # seed everything
@@ -121,7 +122,8 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
     - if use "byte", set to 256 (unicode byte range)
     """
     if tokenizer in ["pinyin", "char"]:
-
+        tokenizer_path = os.path.join(files('f5_tts').joinpath('data'), f"{dataset_name}_{tokenizer}/vocab.txt")
+        with open(tokenizer_path, "r", encoding="utf-8") as f:
             vocab_char_map = {}
             for i, char in enumerate(f):
                 vocab_char_map[char[:-1]] = i
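Same pattern for the tokenizer vocab: `get_tokenizer` now resolves `<dataset>_<tokenizer>/vocab.txt` from the packaged `data/` directory instead of the current working directory. A quick sketch of the call, using the Emilia_ZH_EN pinyin vocab that ships with the package (signature as shown in the hunk above):

```python
# Sketch: the vocab file now comes from the installed package, so this works
# from any working directory once f5_tts is installed.
from f5_tts.model.utils import get_tokenizer

vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
print(vocab_size)
```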
model/utils_infer.py → src/f5_tts/model/utils_infer.py
@@ -12,8 +12,8 @@ from pydub import AudioSegment, silence
 from transformers import pipeline
 from vocos import Vocos
 
-from model import CFM
-from model.utils import (
+from f5_tts.model import CFM
+from f5_tts.model.utils import (
     load_checkpoint,
     get_tokenizer,
     convert_char_to_pinyin,
scripts/count_max_epoch.py → src/f5_tts/scripts/count_max_epoch.py: file renamed without changes
scripts/count_params_gflops.py → src/f5_tts/scripts/count_params_gflops.py
@@ -3,7 +3,7 @@ import os
 
 sys.path.append(os.getcwd())
 
-from model import M2_TTS, DiT
+from f5_tts.model import M2_TTS, DiT
 
 import torch
 import thop
src/f5_tts/scripts/eval_infer_batch.py (new file)
@@ -0,0 +1,204 @@
+import sys
+import os
+
+sys.path.append(os.getcwd())
+
+import time
+import random
+from tqdm import tqdm
+import argparse
+from importlib.resources import files
+
+import torch
+import torchaudio
+from accelerate import Accelerator
+from vocos import Vocos
+
+from f5_tts.model import CFM, UNetT, DiT
+from f5_tts.model.utils import (
+    load_checkpoint,
+    get_tokenizer,
+    get_seedtts_testset_metainfo,
+    get_librispeech_test_clean_metainfo,
+    get_inference_prompt,
+)
+
+accelerator = Accelerator()
+device = f"cuda:{accelerator.process_index}"
+
+
+# --------------------- Dataset Settings -------------------- #
+
+target_sample_rate = 24000
+n_mel_channels = 100
+hop_length = 256
+target_rms = 0.1
+
+tokenizer = "pinyin"
+
+def main():
+    # ---------------------- infer setting ---------------------- #
+
+    parser = argparse.ArgumentParser(description="batch inference")
+
+    parser.add_argument("-s", "--seed", default=None, type=int)
+    parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
+    parser.add_argument("-n", "--expname", required=True)
+    parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
+
+    parser.add_argument("-nfe", "--nfestep", default=32, type=int)
+    parser.add_argument("-o", "--odemethod", default="euler")
+    parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
+
+    parser.add_argument("-t", "--testset", required=True)
+
+    args = parser.parse_args()
+
+
+    seed = args.seed
+    dataset_name = args.dataset
+    exp_name = args.expname
+    ckpt_step = args.ckptstep
+    ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
+
+    nfe_step = args.nfestep
+    ode_method = args.odemethod
+    sway_sampling_coef = args.swaysampling
+
+    testset = args.testset
+
+
+    infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
+    cfg_strength = 2.0
+    speed = 1.0
+    use_truth_duration = False
+    no_ref_audio = False
+
+
+    if exp_name == "F5TTS_Base":
+        model_cls = DiT
+        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+
+    elif exp_name == "E2TTS_Base":
+        model_cls = UNetT
+        model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+
+
+    datapath = files('f5_tts').joinpath('data')
+
+    if testset == "ls_pc_test_clean":
+        metalst = os.path.join(datapath, "librispeech_pc_test_clean_cross_sentence.lst")
+        librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
+        metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
+
+    elif testset == "seedtts_test_zh":
+        metalst = os.path.join(datapath, "seedtts_testset/zh/meta.lst")
+        metainfo = get_seedtts_testset_metainfo(metalst)
+
+    elif testset == "seedtts_test_en":
+        metalst = os.path.join(datapath, "seedtts_testset/en/meta.lst")
+        metainfo = get_seedtts_testset_metainfo(metalst)
+
+
+    # path to save genereted wavs
+    if seed is None:
+        seed = random.randint(-10000, 10000)
+    output_dir = (
+        f"results/{exp_name}_{ckpt_step}/{testset}/"
+        f"seed{seed}_{ode_method}_nfe{nfe_step}"
+        f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
+        f"_cfg{cfg_strength}_speed{speed}"
+        f"{'_gt-dur' if use_truth_duration else ''}"
+        f"{'_no-ref-audio' if no_ref_audio else ''}"
+    )
+
+
+    # -------------------------------------------------#
+
+    use_ema = True
+
+    prompts_all = get_inference_prompt(
+        metainfo,
+        speed=speed,
+        tokenizer=tokenizer,
+        target_sample_rate=target_sample_rate,
+        n_mel_channels=n_mel_channels,
+        hop_length=hop_length,
+        target_rms=target_rms,
+        use_truth_duration=use_truth_duration,
+        infer_batch_size=infer_batch_size,
+    )
+
+    # Vocoder model
+    local = False
+    if local:
+        vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
+        vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
+        state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
+        vocos.load_state_dict(state_dict)
+        vocos.eval()
+    else:
+        vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+
+    # Tokenizer
+    vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
+
+    # Model
+    model = CFM(
+        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
+        mel_spec_kwargs=dict(
+            target_sample_rate=target_sample_rate,
+            n_mel_channels=n_mel_channels,
+            hop_length=hop_length,
+        ),
+        odeint_kwargs=dict(
+            method=ode_method,
+        ),
+        vocab_char_map=vocab_char_map,
+    ).to(device)
+
+    model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
+
+    if not os.path.exists(output_dir) and accelerator.is_main_process:
+        os.makedirs(output_dir)
+
+    # start batch inference
+    accelerator.wait_for_everyone()
+    start = time.time()
+
+    with accelerator.split_between_processes(prompts_all) as prompts:
+        for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
+            utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
+            ref_mels = ref_mels.to(device)
+            ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
+            total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
+
+            # Inference
+            with torch.inference_mode():
+                generated, _ = model.sample(
+                    cond=ref_mels,
+                    text=final_text_list,
+                    duration=total_mel_lens,
+                    lens=ref_mel_lens,
+                    steps=nfe_step,
+                    cfg_strength=cfg_strength,
+                    sway_sampling_coef=sway_sampling_coef,
+                    no_ref_audio=no_ref_audio,
+                    seed=seed,
+                )
+            # Final result
+            for i, gen in enumerate(generated):
+                gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
+                gen_mel_spec = gen.permute(0, 2, 1)
+                generated_wave = vocos.decode(gen_mel_spec.cpu())
+                if ref_rms_list[i] < target_rms:
+                    generated_wave = generated_wave * ref_rms_list[i] / target_rms
+                torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
+
+    accelerator.wait_for_everyone()
+    if accelerator.is_main_process:
+        timediff = time.time() - start
+        print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
+
+if __name__ == "__main__":
+    main()
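Because all of the work now lives inside `main()`, the evaluation script can be invoked as the `eval_infer_batch` console script, via `python -m f5_tts.scripts.eval_infer_batch`, or driven programmatically. A sketch of the programmatic route, using argument values that appear in this diff (`F5TTS_Base`, `seedtts_test_zh`); it assumes the matching checkpoint under `ckpts/` and the test-set metadata are available:

```python
# Sketch: call the refactored entry point directly instead of through the CLI.
import sys
from f5_tts.scripts.eval_infer_batch import main

sys.argv = ["eval_infer_batch", "-n", "F5TTS_Base", "-t", "seedtts_test_zh"]
main()
```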
scripts/eval_infer_batch.sh → src/f5_tts/scripts/eval_infer_batch.sh: file renamed without changes
scripts/eval_librispeech_test_clean.py → src/f5_tts/scripts/eval_librispeech_test_clean.py
@@ -8,7 +8,7 @@ sys.path.append(os.getcwd())
 import multiprocessing as mp
 import numpy as np
 
-from model.utils import (
+from f5_tts.model.utils import (
     get_librispeech_test,
     run_asr_wer,
     run_sim,
scripts/eval_seedtts_testset.py → src/f5_tts/scripts/eval_seedtts_testset.py
@@ -8,7 +8,7 @@ sys.path.append(os.getcwd())
 import multiprocessing as mp
 import numpy as np
 
-from model.utils import (
+from f5_tts.model.utils import (
     get_seed_tts_test,
     run_asr_wer,
     run_sim,
scripts/prepare_csv_wavs.py → src/f5_tts/scripts/prepare_csv_wavs.py
@@ -13,7 +13,7 @@ import torchaudio
 from tqdm import tqdm
 from datasets.arrow_writer import ArrowWriter
 
-from model.utils import (
+from f5_tts.model.utils import (
     convert_char_to_pinyin,
 )
 
scripts/prepare_emilia.py → src/f5_tts/scripts/prepare_emilia.py
@@ -16,7 +16,7 @@ from concurrent.futures import ProcessPoolExecutor
 
 from datasets.arrow_writer import ArrowWriter
 
-from model.utils import (
+from f5_tts.model.utils import (
     repetition_found,
     convert_char_to_pinyin,
 )
scripts/prepare_wenetspeech4tts.py → src/f5_tts/scripts/prepare_wenetspeech4tts.py
@@ -13,7 +13,7 @@ from concurrent.futures import ProcessPoolExecutor
 import torchaudio
 from datasets import Dataset
 
-from model.utils import convert_char_to_pinyin
+from f5_tts.model.utils import convert_char_to_pinyin
 
 
 def deal_with_sub_path_files(dataset_path, sub_path):
speech_edit.py → src/f5_tts/speech_edit.py
@@ -5,8 +5,8 @@ import torch.nn.functional as F
 import torchaudio
 from vocos import Vocos
 
-from model import CFM, UNetT, DiT
-from model.utils import (
+from f5_tts.model import CFM, UNetT, DiT
+from f5_tts.model.utils import (
     load_checkpoint,
     get_tokenizer,
     convert_char_to_pinyin,
train.py → src/f5_tts/train.py
@@ -1,6 +1,6 @@
-from model import CFM, UNetT, DiT, Trainer
-from model.utils import get_tokenizer
-from model.dataset import load_dataset
+from f5_tts.model import CFM, UNetT, DiT, Trainer
+from f5_tts.model.utils import get_tokenizer
+from f5_tts.model.dataset import load_dataset
 
 
 # -------------------------- Dataset Settings --------------------------- #