SWivid commited on
Commit
37cdbe8
·
1 Parent(s): 6bb2043

formatting, sorting

Browse files
Files changed (40) hide show
  1. .pre-commit-config.yaml +5 -2
  2. ruff.toml +1 -1
  3. src/f5_tts/api.py +2 -2
  4. src/f5_tts/eval/ecapa_tdnn.py +1 -0
  5. src/f5_tts/eval/eval_infer_batch.py +2 -0
  6. src/f5_tts/eval/eval_librispeech_test_clean.py +4 -5
  7. src/f5_tts/eval/eval_seedtts_testset.py +4 -5
  8. src/f5_tts/infer/infer_cli.py +7 -7
  9. src/f5_tts/infer/infer_gradio.py +4 -3
  10. src/f5_tts/infer/speech_edit.py +3 -1
  11. src/f5_tts/infer/utils_infer.py +4 -4
  12. src/f5_tts/model/__init__.py +2 -4
  13. src/f5_tts/model/backbones/dit.py +4 -5
  14. src/f5_tts/model/backbones/mmdit.py +3 -4
  15. src/f5_tts/model/backbones/unett.py +6 -6
  16. src/f5_tts/model/trainer.py +1 -0
  17. src/f5_tts/model/utils.py +2 -3
  18. src/f5_tts/runtime/triton_trtllm/benchmark.py +12 -11
  19. src/f5_tts/runtime/triton_trtllm/client_grpc.py +0 -1
  20. src/f5_tts/runtime/triton_trtllm/client_http.py +3 -2
  21. src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py +6 -7
  22. src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py +6 -5
  23. src/f5_tts/runtime/triton_trtllm/patch/__init__.py +3 -2
  24. src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py +9 -12
  25. src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py +14 -12
  26. src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py +1 -0
  27. src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py +0 -1
  28. src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py +4 -3
  29. src/f5_tts/scripts/count_params_gflops.py +5 -4
  30. src/f5_tts/socket_client.py +5 -3
  31. src/f5_tts/socket_server.py +5 -4
  32. src/f5_tts/train/datasets/prepare_csv_wavs.py +7 -8
  33. src/f5_tts/train/datasets/prepare_emilia.py +3 -5
  34. src/f5_tts/train/datasets/prepare_emilia_v2.py +6 -6
  35. src/f5_tts/train/datasets/prepare_libritts.py +3 -1
  36. src/f5_tts/train/datasets/prepare_ljspeech.py +3 -1
  37. src/f5_tts/train/datasets/prepare_wenetspeech4tts.py +2 -1
  38. src/f5_tts/train/finetune_cli.py +2 -2
  39. src/f5_tts/train/finetune_gradio.py +5 -5
  40. src/f5_tts/train/train.py +1 -0
.pre-commit-config.yaml CHANGED
@@ -3,11 +3,14 @@ repos:
3
  # Ruff version.
4
  rev: v0.11.2
5
  hooks:
6
- # Run the linter.
7
  - id: ruff
 
8
  args: [--fix]
9
- # Run the formatter.
10
  - id: ruff-format
 
 
 
 
11
  - repo: https://github.com/pre-commit/pre-commit-hooks
12
  rev: v5.0.0
13
  hooks:
 
3
  # Ruff version.
4
  rev: v0.11.2
5
  hooks:
 
6
  - id: ruff
7
+ name: ruff linter
8
  args: [--fix]
 
9
  - id: ruff-format
10
+ name: ruff formatter
11
+ - id: ruff
12
+ name: ruff sorter
13
+ args: [--select, I, --fix]
14
  - repo: https://github.com/pre-commit/pre-commit-hooks
15
  rev: v5.0.0
16
  hooks:
ruff.toml CHANGED
@@ -6,5 +6,5 @@ target-version = "py310"
6
  dummy-variable-rgx = "^_.*$"
7
 
8
  [lint.isort]
9
- force-single-line = true
10
  lines-after-imports = 2
 
6
  dummy-variable-rgx = "^_.*$"
7
 
8
  [lint.isort]
9
+ force-single-line = false
10
  lines-after-imports = 2
src/f5_tts/api.py CHANGED
@@ -9,13 +9,13 @@ from hydra.utils import get_class
9
  from omegaconf import OmegaConf
10
 
11
  from f5_tts.infer.utils_infer import (
 
12
  load_model,
13
  load_vocoder,
14
- transcribe,
15
  preprocess_ref_audio_text,
16
- infer_process,
17
  remove_silence_for_generated_wav,
18
  save_spectrogram,
 
19
  )
20
  from f5_tts.model.utils import seed_everything
21
 
 
9
  from omegaconf import OmegaConf
10
 
11
  from f5_tts.infer.utils_infer import (
12
+ infer_process,
13
  load_model,
14
  load_vocoder,
 
15
  preprocess_ref_audio_text,
 
16
  remove_silence_for_generated_wav,
17
  save_spectrogram,
18
+ transcribe,
19
  )
20
  from f5_tts.model.utils import seed_everything
21
 
src/f5_tts/eval/ecapa_tdnn.py CHANGED
@@ -4,6 +4,7 @@
4
  # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
5
 
6
  import os
 
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
 
4
  # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
5
 
6
  import os
7
+
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
src/f5_tts/eval/eval_infer_batch.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import sys
3
 
 
4
  sys.path.append(os.getcwd())
5
 
6
  import argparse
@@ -23,6 +24,7 @@ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
23
  from f5_tts.model import CFM
24
  from f5_tts.model.utils import get_tokenizer
25
 
 
26
  accelerator = Accelerator()
27
  device = f"cuda:{accelerator.process_index}"
28
 
 
1
  import os
2
  import sys
3
 
4
+
5
  sys.path.append(os.getcwd())
6
 
7
  import argparse
 
24
  from f5_tts.model import CFM
25
  from f5_tts.model.utils import get_tokenizer
26
 
27
+
28
  accelerator = Accelerator()
29
  device = f"cuda:{accelerator.process_index}"
30
 
src/f5_tts/eval/eval_librispeech_test_clean.py CHANGED
@@ -5,17 +5,16 @@ import json
5
  import os
6
  import sys
7
 
 
8
  sys.path.append(os.getcwd())
9
 
10
  import multiprocessing as mp
11
  from importlib.resources import files
12
 
13
  import numpy as np
14
- from f5_tts.eval.utils_eval import (
15
- get_librispeech_test,
16
- run_asr_wer,
17
- run_sim,
18
- )
19
 
20
  rel_path = str(files("f5_tts").joinpath("../../"))
21
 
 
5
  import os
6
  import sys
7
 
8
+
9
  sys.path.append(os.getcwd())
10
 
11
  import multiprocessing as mp
12
  from importlib.resources import files
13
 
14
  import numpy as np
15
+
16
+ from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim
17
+
 
 
18
 
19
  rel_path = str(files("f5_tts").joinpath("../../"))
20
 
src/f5_tts/eval/eval_seedtts_testset.py CHANGED
@@ -5,17 +5,16 @@ import json
5
  import os
6
  import sys
7
 
 
8
  sys.path.append(os.getcwd())
9
 
10
  import multiprocessing as mp
11
  from importlib.resources import files
12
 
13
  import numpy as np
14
- from f5_tts.eval.utils_eval import (
15
- get_seed_tts_test,
16
- run_asr_wer,
17
- run_sim,
18
- )
19
 
20
  rel_path = str(files("f5_tts").joinpath("../../"))
21
 
 
5
  import os
6
  import sys
7
 
8
+
9
  sys.path.append(os.getcwd())
10
 
11
  import multiprocessing as mp
12
  from importlib.resources import files
13
 
14
  import numpy as np
15
+
16
+ from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim
17
+
 
 
18
 
19
  rel_path = str(files("f5_tts").joinpath("../../"))
20
 
src/f5_tts/infer/infer_cli.py CHANGED
@@ -14,20 +14,20 @@ from hydra.utils import get_class
14
  from omegaconf import OmegaConf
15
 
16
  from f5_tts.infer.utils_infer import (
17
- mel_spec_type,
18
- target_rms,
19
- cross_fade_duration,
20
- nfe_step,
21
  cfg_strength,
22
- sway_sampling_coef,
23
- speed,
24
- fix_duration,
25
  device,
 
26
  infer_process,
27
  load_model,
28
  load_vocoder,
 
 
29
  preprocess_ref_audio_text,
30
  remove_silence_for_generated_wav,
 
 
 
31
  )
32
 
33
 
 
14
  from omegaconf import OmegaConf
15
 
16
  from f5_tts.infer.utils_infer import (
 
 
 
 
17
  cfg_strength,
18
+ cross_fade_duration,
 
 
19
  device,
20
+ fix_duration,
21
  infer_process,
22
  load_model,
23
  load_vocoder,
24
+ mel_spec_type,
25
+ nfe_step,
26
  preprocess_ref_audio_text,
27
  remove_silence_for_generated_wav,
28
+ speed,
29
+ sway_sampling_coef,
30
+ target_rms,
31
  )
32
 
33
 
src/f5_tts/infer/infer_gradio.py CHANGED
@@ -18,6 +18,7 @@ import torchaudio
18
  from cached_path import cached_path
19
  from transformers import AutoModelForCausalLM, AutoTokenizer
20
 
 
21
  try:
22
  import spaces
23
 
@@ -33,15 +34,15 @@ def gpu_decorator(func):
33
  return func
34
 
35
 
36
- from f5_tts.model import DiT, UNetT
37
  from f5_tts.infer.utils_infer import (
38
- load_vocoder,
39
  load_model,
 
40
  preprocess_ref_audio_text,
41
- infer_process,
42
  remove_silence_for_generated_wav,
43
  save_spectrogram,
44
  )
 
45
 
46
 
47
  DEFAULT_TTS_MODEL = "F5-TTS_v1"
 
18
  from cached_path import cached_path
19
  from transformers import AutoModelForCausalLM, AutoTokenizer
20
 
21
+
22
  try:
23
  import spaces
24
 
 
34
  return func
35
 
36
 
 
37
  from f5_tts.infer.utils_infer import (
38
+ infer_process,
39
  load_model,
40
+ load_vocoder,
41
  preprocess_ref_audio_text,
 
42
  remove_silence_for_generated_wav,
43
  save_spectrogram,
44
  )
45
+ from f5_tts.model import DiT, UNetT
46
 
47
 
48
  DEFAULT_TTS_MODEL = "F5-TTS_v1"
src/f5_tts/infer/speech_edit.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
 
 
3
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
4
 
5
  from importlib.resources import files
@@ -7,14 +8,15 @@ from importlib.resources import files
7
  import torch
8
  import torch.nn.functional as F
9
  import torchaudio
 
10
  from hydra.utils import get_class
11
  from omegaconf import OmegaConf
12
- from cached_path import cached_path
13
 
14
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
15
  from f5_tts.model import CFM
16
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
17
 
 
18
  device = (
19
  "cuda"
20
  if torch.cuda.is_available()
 
1
  import os
2
 
3
+
4
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
5
 
6
  from importlib.resources import files
 
8
  import torch
9
  import torch.nn.functional as F
10
  import torchaudio
11
+ from cached_path import cached_path
12
  from hydra.utils import get_class
13
  from omegaconf import OmegaConf
 
14
 
15
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
16
  from f5_tts.model import CFM
17
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
18
 
19
+
20
  device = (
21
  "cuda"
22
  if torch.cuda.is_available()
src/f5_tts/infer/utils_infer.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  import sys
5
  from concurrent.futures import ThreadPoolExecutor
6
 
 
7
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
8
  sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
9
 
@@ -14,6 +15,7 @@ from importlib.resources import files
14
 
15
  import matplotlib
16
 
 
17
  matplotlib.use("Agg")
18
 
19
  import matplotlib.pylab as plt
@@ -27,10 +29,8 @@ from transformers import pipeline
27
  from vocos import Vocos
28
 
29
  from f5_tts.model import CFM
30
- from f5_tts.model.utils import (
31
- get_tokenizer,
32
- convert_char_to_pinyin,
33
- )
34
 
35
  _ref_audio_cache = {}
36
 
 
4
  import sys
5
  from concurrent.futures import ThreadPoolExecutor
6
 
7
+
8
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
9
  sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
10
 
 
15
 
16
  import matplotlib
17
 
18
+
19
  matplotlib.use("Agg")
20
 
21
  import matplotlib.pylab as plt
 
29
  from vocos import Vocos
30
 
31
  from f5_tts.model import CFM
32
+ from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
33
+
 
 
34
 
35
  _ref_audio_cache = {}
36
 
src/f5_tts/model/__init__.py CHANGED
@@ -1,9 +1,7 @@
1
- from f5_tts.model.cfm import CFM
2
-
3
- from f5_tts.model.backbones.unett import UNetT
4
  from f5_tts.model.backbones.dit import DiT
5
  from f5_tts.model.backbones.mmdit import MMDiT
6
-
 
7
  from f5_tts.model.trainer import Trainer
8
 
9
 
 
 
 
 
1
  from f5_tts.model.backbones.dit import DiT
2
  from f5_tts.model.backbones.mmdit import MMDiT
3
+ from f5_tts.model.backbones.unett import UNetT
4
+ from f5_tts.model.cfm import CFM
5
  from f5_tts.model.trainer import Trainer
6
 
7
 
src/f5_tts/model/backbones/dit.py CHANGED
@@ -10,19 +10,18 @@ d - dimension
10
  from __future__ import annotations
11
 
12
  import torch
13
- from torch import nn
14
  import torch.nn.functional as F
15
-
16
  from x_transformers.x_transformers import RotaryEmbedding
17
 
18
  from f5_tts.model.modules import (
19
- TimestepEmbedding,
20
  ConvNeXtV2Block,
21
  ConvPositionEmbedding,
22
  DiTBlock,
23
- AdaLayerNorm_Final,
24
- precompute_freqs_cis,
25
  get_pos_embed_indices,
 
26
  )
27
 
28
 
 
10
  from __future__ import annotations
11
 
12
  import torch
 
13
  import torch.nn.functional as F
14
+ from torch import nn
15
  from x_transformers.x_transformers import RotaryEmbedding
16
 
17
  from f5_tts.model.modules import (
18
+ AdaLayerNorm_Final,
19
  ConvNeXtV2Block,
20
  ConvPositionEmbedding,
21
  DiTBlock,
22
+ TimestepEmbedding,
 
23
  get_pos_embed_indices,
24
+ precompute_freqs_cis,
25
  )
26
 
27
 
src/f5_tts/model/backbones/mmdit.py CHANGED
@@ -11,16 +11,15 @@ from __future__ import annotations
11
 
12
  import torch
13
  from torch import nn
14
-
15
  from x_transformers.x_transformers import RotaryEmbedding
16
 
17
  from f5_tts.model.modules import (
18
- TimestepEmbedding,
19
  ConvPositionEmbedding,
20
  MMDiTBlock,
21
- AdaLayerNorm_Final,
22
- precompute_freqs_cis,
23
  get_pos_embed_indices,
 
24
  )
25
 
26
 
 
11
 
12
  import torch
13
  from torch import nn
 
14
  from x_transformers.x_transformers import RotaryEmbedding
15
 
16
  from f5_tts.model.modules import (
17
+ AdaLayerNorm_Final,
18
  ConvPositionEmbedding,
19
  MMDiTBlock,
20
+ TimestepEmbedding,
 
21
  get_pos_embed_indices,
22
+ precompute_freqs_cis,
23
  )
24
 
25
 
src/f5_tts/model/backbones/unett.py CHANGED
@@ -8,24 +8,24 @@ d - dimension
8
  """
9
 
10
  from __future__ import annotations
 
11
  from typing import Literal
12
 
13
  import torch
14
- from torch import nn
15
  import torch.nn.functional as F
16
-
17
  from x_transformers import RMSNorm
18
  from x_transformers.x_transformers import RotaryEmbedding
19
 
20
  from f5_tts.model.modules import (
21
- TimestepEmbedding,
22
- ConvNeXtV2Block,
23
- ConvPositionEmbedding,
24
  Attention,
25
  AttnProcessor,
 
 
26
  FeedForward,
27
- precompute_freqs_cis,
28
  get_pos_embed_indices,
 
29
  )
30
 
31
 
 
8
  """
9
 
10
  from __future__ import annotations
11
+
12
  from typing import Literal
13
 
14
  import torch
 
15
  import torch.nn.functional as F
16
+ from torch import nn
17
  from x_transformers import RMSNorm
18
  from x_transformers.x_transformers import RotaryEmbedding
19
 
20
  from f5_tts.model.modules import (
 
 
 
21
  Attention,
22
  AttnProcessor,
23
+ ConvNeXtV2Block,
24
+ ConvPositionEmbedding,
25
  FeedForward,
26
+ TimestepEmbedding,
27
  get_pos_embed_indices,
28
+ precompute_freqs_cis,
29
  )
30
 
31
 
src/f5_tts/model/trainer.py CHANGED
@@ -19,6 +19,7 @@ from f5_tts.model import CFM
19
  from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
20
  from f5_tts.model.utils import default, exists
21
 
 
22
  # trainer
23
 
24
 
 
19
  from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
20
  from f5_tts.model.utils import default, exists
21
 
22
+
23
  # trainer
24
 
25
 
src/f5_tts/model/utils.py CHANGED
@@ -5,12 +5,11 @@ import random
5
  from collections import defaultdict
6
  from importlib.resources import files
7
 
 
8
  import torch
 
9
  from torch.nn.utils.rnn import pad_sequence
10
 
11
- import jieba
12
- from pypinyin import lazy_pinyin, Style
13
-
14
 
15
  # seed everything
16
 
 
5
  from collections import defaultdict
6
  from importlib.resources import files
7
 
8
+ import jieba
9
  import torch
10
+ from pypinyin import Style, lazy_pinyin
11
  from torch.nn.utils.rnn import pad_sequence
12
 
 
 
 
13
 
14
  # seed everything
15
 
src/f5_tts/runtime/triton_trtllm/benchmark.py CHANGED
@@ -30,26 +30,27 @@ import argparse
30
  import json
31
  import os
32
  import time
33
- from typing import List, Dict, Union
34
 
 
 
 
35
  import torch
36
  import torch.distributed as dist
37
  import torch.nn.functional as F
38
- from torch.nn.utils.rnn import pad_sequence
39
  import torchaudio
40
- import jieba
41
- from pypinyin import Style, lazy_pinyin
42
  from datasets import load_dataset
43
- import datasets
44
  from huggingface_hub import hf_hub_download
 
 
 
 
 
45
  from torch.utils.data import DataLoader, DistributedSampler
46
  from tqdm import tqdm
47
  from vocos import Vocos
48
- from f5_tts_trtllm import F5TTS
49
- import tensorrt as trt
50
- from tensorrt_llm.runtime.session import Session, TensorInfo
51
- from tensorrt_llm.logger import logger
52
- from tensorrt_llm._utils import trt_dtype_to_torch
53
 
54
  torch.manual_seed(0)
55
 
@@ -381,8 +382,8 @@ def main():
381
  import sys
382
 
383
  sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
384
- from f5_tts.model import DiT
385
  from f5_tts.infer.utils_infer import load_model
 
386
 
387
  F5TTS_model_cfg = dict(
388
  dim=1024,
 
30
  import json
31
  import os
32
  import time
33
+ from typing import Dict, List, Union
34
 
35
+ import datasets
36
+ import jieba
37
+ import tensorrt as trt
38
  import torch
39
  import torch.distributed as dist
40
  import torch.nn.functional as F
 
41
  import torchaudio
 
 
42
  from datasets import load_dataset
43
+ from f5_tts_trtllm import F5TTS
44
  from huggingface_hub import hf_hub_download
45
+ from pypinyin import Style, lazy_pinyin
46
+ from tensorrt_llm._utils import trt_dtype_to_torch
47
+ from tensorrt_llm.logger import logger
48
+ from tensorrt_llm.runtime.session import Session, TensorInfo
49
+ from torch.nn.utils.rnn import pad_sequence
50
  from torch.utils.data import DataLoader, DistributedSampler
51
  from tqdm import tqdm
52
  from vocos import Vocos
53
+
 
 
 
 
54
 
55
  torch.manual_seed(0)
56
 
 
382
  import sys
383
 
384
  sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
 
385
  from f5_tts.infer.utils_infer import load_model
386
+ from f5_tts.model import DiT
387
 
388
  F5TTS_model_cfg = dict(
389
  dim=1024,
src/f5_tts/runtime/triton_trtllm/client_grpc.py CHANGED
@@ -44,7 +44,6 @@ python3 client_grpc.py \
44
  import argparse
45
  import asyncio
46
  import json
47
-
48
  import os
49
  import time
50
  import types
 
44
  import argparse
45
  import asyncio
46
  import json
 
47
  import os
48
  import time
49
  import types
src/f5_tts/runtime/triton_trtllm/client_http.py CHANGED
@@ -23,10 +23,11 @@
23
  # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
  # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
  # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 
 
26
  import requests
27
  import soundfile as sf
28
- import numpy as np
29
- import argparse
30
 
31
 
32
  def get_args():
 
23
  # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
  # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
  # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+ import argparse
27
+
28
+ import numpy as np
29
  import requests
30
  import soundfile as sf
 
 
31
 
32
 
33
  def get_args():
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py CHANGED
@@ -1,18 +1,17 @@
1
- import tensorrt as trt
2
- import os
3
  import math
 
4
  import time
5
- from typing import List, Optional
6
  from functools import wraps
 
7
 
 
8
  import tensorrt_llm
9
- from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
10
- from tensorrt_llm.logger import logger
11
- from tensorrt_llm.runtime.session import Session
12
-
13
  import torch
14
  import torch.nn as nn
15
  import torch.nn.functional as F
 
 
 
16
 
17
 
18
  def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
 
 
 
1
  import math
2
+ import os
3
  import time
 
4
  from functools import wraps
5
+ from typing import List, Optional
6
 
7
+ import tensorrt as trt
8
  import tensorrt_llm
 
 
 
 
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
+ from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
13
+ from tensorrt_llm.logger import logger
14
+ from tensorrt_llm.runtime.session import Session
15
 
16
 
17
  def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py CHANGED
@@ -24,16 +24,17 @@
24
  # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
  # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
  import json
 
 
 
27
  import torch
28
- from torch.nn.utils.rnn import pad_sequence
29
  import torch.nn.functional as F
30
- from torch.utils.dlpack import from_dlpack, to_dlpack
31
  import torchaudio
32
- import jieba
33
  import triton_python_backend_utils as pb_utils
34
- from pypinyin import Style, lazy_pinyin
35
- import os
36
  from f5_tts_trtllm import F5TTS
 
 
 
37
 
38
 
39
  def get_tokenizer(vocab_file_path: str):
 
24
  # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
  # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
  import json
27
+ import os
28
+
29
+ import jieba
30
  import torch
 
31
  import torch.nn.functional as F
 
32
  import torchaudio
 
33
  import triton_python_backend_utils as pb_utils
 
 
34
  from f5_tts_trtllm import F5TTS
35
+ from pypinyin import Style, lazy_pinyin
36
+ from torch.nn.utils.rnn import pad_sequence
37
+ from torch.utils.dlpack import from_dlpack, to_dlpack
38
 
39
 
40
  def get_tokenizer(vocab_file_path: str):
src/f5_tts/runtime/triton_trtllm/patch/__init__.py CHANGED
@@ -34,6 +34,7 @@ from .deepseek_v2.model import DeepseekV2ForCausalLM
34
  from .dit.model import DiT
35
  from .eagle.model import EagleForCausalLM
36
  from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
 
37
  from .falcon.config import FalconConfig
38
  from .falcon.model import FalconForCausalLM, FalconModel
39
  from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
@@ -54,12 +55,12 @@ from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodi
54
  from .mpt.model import MPTForCausalLM, MPTModel
55
  from .nemotron_nas.model import DeciLMForCausalLM
56
  from .opt.model import OPTForCausalLM, OPTModel
57
- from .phi3.model import Phi3ForCausalLM, Phi3Model
58
  from .phi.model import PhiForCausalLM, PhiModel
 
59
  from .qwen.model import QWenForCausalLM
60
  from .recurrentgemma.model import RecurrentGemmaForCausalLM
61
  from .redrafter.model import ReDrafterForCausalLM
62
- from .f5tts.model import F5TTS
63
 
64
  __all__ = [
65
  "BertModel",
 
34
  from .dit.model import DiT
35
  from .eagle.model import EagleForCausalLM
36
  from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
37
+ from .f5tts.model import F5TTS
38
  from .falcon.config import FalconConfig
39
  from .falcon.model import FalconForCausalLM, FalconModel
40
  from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
 
55
  from .mpt.model import MPTForCausalLM, MPTModel
56
  from .nemotron_nas.model import DeciLMForCausalLM
57
  from .opt.model import OPTForCausalLM, OPTModel
 
58
  from .phi.model import PhiForCausalLM, PhiModel
59
+ from .phi3.model import Phi3ForCausalLM, Phi3Model
60
  from .qwen.model import QWenForCausalLM
61
  from .recurrentgemma.model import RecurrentGemmaForCausalLM
62
  from .redrafter.model import ReDrafterForCausalLM
63
+
64
 
65
  __all__ = [
66
  "BertModel",
src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py CHANGED
@@ -1,23 +1,20 @@
1
  from __future__ import annotations
2
- import sys
3
  import os
 
 
4
 
5
  import tensorrt as trt
6
- from collections import OrderedDict
 
7
  from ..._utils import str_dtype_to_trt
8
- from ...plugin import current_all_reduce_helper
9
- from ..modeling_utils import PretrainedConfig, PretrainedModel
10
  from ...functional import Tensor, concat
11
- from ...module import Module, ModuleList
12
- from tensorrt_llm._common import default_net
13
  from ...layers import Linear
 
 
 
 
14
 
15
- from .modules import (
16
- TimestepEmbedding,
17
- ConvPositionEmbedding,
18
- DiTBlock,
19
- AdaLayerNormZero_Final,
20
- )
21
 
22
  current_file_path = os.path.abspath(__file__)
23
  parent_dir = os.path.dirname(current_file_path)
 
1
  from __future__ import annotations
2
+
3
  import os
4
+ import sys
5
+ from collections import OrderedDict
6
 
7
  import tensorrt as trt
8
+ from tensorrt_llm._common import default_net
9
+
10
  from ..._utils import str_dtype_to_trt
 
 
11
  from ...functional import Tensor, concat
 
 
12
  from ...layers import Linear
13
+ from ...module import Module, ModuleList
14
+ from ...plugin import current_all_reduce_helper
15
+ from ..modeling_utils import PretrainedConfig, PretrainedModel
16
+ from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, TimestepEmbedding
17
 
 
 
 
 
 
 
18
 
19
  current_file_path = os.path.abspath(__file__)
20
  parent_dir = os.path.dirname(current_file_path)
src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py CHANGED
@@ -3,33 +3,35 @@ from __future__ import annotations
3
  import math
4
  from typing import Optional
5
 
 
6
  import torch
7
  import torch.nn.functional as F
8
-
9
- import numpy as np
10
  from tensorrt_llm._common import default_net
11
- from ..._utils import trt_dtype_to_np, str_dtype_to_trt
 
12
  from ...functional import (
13
  Tensor,
 
 
14
  chunk,
15
  concat,
16
  constant,
17
  expand,
 
 
 
 
 
 
18
  shape,
19
  silu,
20
  slice,
21
- permute,
22
- expand_mask,
23
- expand_dims_like,
24
- unsqueeze,
25
- matmul,
26
  softmax,
27
  squeeze,
28
- cast,
29
- gelu,
30
  )
31
- from ...functional import expand_dims, view, bert_attention
32
- from ...layers import LayerNorm, Linear, Conv1d, Mish, RowLinear, ColumnLinear
33
  from ...module import Module
34
 
35
 
 
3
  import math
4
  from typing import Optional
5
 
6
+ import numpy as np
7
  import torch
8
  import torch.nn.functional as F
 
 
9
  from tensorrt_llm._common import default_net
10
+
11
+ from ..._utils import str_dtype_to_trt, trt_dtype_to_np
12
  from ...functional import (
13
  Tensor,
14
+ bert_attention,
15
+ cast,
16
  chunk,
17
  concat,
18
  constant,
19
  expand,
20
+ expand_dims,
21
+ expand_dims_like,
22
+ expand_mask,
23
+ gelu,
24
+ matmul,
25
+ permute,
26
  shape,
27
  silu,
28
  slice,
 
 
 
 
 
29
  softmax,
30
  squeeze,
31
+ unsqueeze,
32
+ view,
33
  )
34
+ from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear
 
35
  from ...module import Module
36
 
37
 
src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py CHANGED
@@ -40,6 +40,7 @@ import torch as th
40
  import torch.nn.functional as F
41
  from scipy.signal import check_COLA, get_window
42
 
 
43
  support_clp_op = None
44
  if th.__version__ >= "1.7.0":
45
  from torch.fft import rfft as fft
 
40
  import torch.nn.functional as F
41
  from scipy.signal import check_COLA, get_window
42
 
43
+
44
  support_clp_op = None
45
  if th.__version__ >= "1.7.0":
46
  from torch.fft import rfft as fft
src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py CHANGED
@@ -8,7 +8,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
8
 
9
  import safetensors.torch
10
  import torch
11
-
12
  from tensorrt_llm import str_dtype_to_torch
13
  from tensorrt_llm.mapping import Mapping
14
  from tensorrt_llm.models.convert_utils import split, split_matrix_tp
 
8
 
9
  import safetensors.torch
10
  import torch
 
11
  from tensorrt_llm import str_dtype_to_torch
12
  from tensorrt_llm.mapping import Mapping
13
  from tensorrt_llm.models.convert_utils import split, split_matrix_tp
src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py CHANGED
@@ -12,13 +12,14 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
 
15
  import torch
16
  import torch.nn as nn
17
- from huggingface_hub import hf_hub_download
18
-
19
  from conv_stft import STFT
 
20
  from vocos import Vocos
21
- import argparse
22
 
23
  opset_version = 17
24
 
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ import argparse
16
+
17
  import torch
18
  import torch.nn as nn
 
 
19
  from conv_stft import STFT
20
+ from huggingface_hub import hf_hub_download
21
  from vocos import Vocos
22
+
23
 
24
  opset_version = 17
25
 
src/f5_tts/scripts/count_params_gflops.py CHANGED
@@ -1,12 +1,13 @@
1
- import sys
2
  import os
 
3
 
4
- sys.path.append(os.getcwd())
5
 
6
- from f5_tts.model import CFM, DiT
7
 
8
- import torch
9
  import thop
 
 
 
10
 
11
 
12
  """ ~155M """
 
 
1
  import os
2
+ import sys
3
 
 
4
 
5
+ sys.path.append(os.getcwd())
6
 
 
7
  import thop
8
+ import torch
9
+
10
+ from f5_tts.model import CFM, DiT
11
 
12
 
13
  """ ~155M """
src/f5_tts/socket_client.py CHANGED
@@ -1,10 +1,12 @@
1
- import socket
2
  import asyncio
3
- import pyaudio
4
- import numpy as np
5
  import logging
 
6
  import time
7
 
 
 
 
 
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
 
 
1
  import asyncio
 
 
2
  import logging
3
+ import socket
4
  import time
5
 
6
+ import numpy as np
7
+ import pyaudio
8
+
9
+
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
src/f5_tts/socket_server.py CHANGED
@@ -1,7 +1,6 @@
1
  import argparse
2
  import gc
3
  import logging
4
- import numpy as np
5
  import queue
6
  import socket
7
  import struct
@@ -10,6 +9,7 @@ import traceback
10
  import wave
11
  from importlib.resources import files
12
 
 
13
  import torch
14
  import torchaudio
15
  from huggingface_hub import hf_hub_download
@@ -18,12 +18,13 @@ from omegaconf import OmegaConf
18
 
19
  from f5_tts.infer.utils_infer import (
20
  chunk_text,
21
- preprocess_ref_audio_text,
22
- load_vocoder,
23
- load_model,
24
  infer_batch_process,
 
 
 
25
  )
26
 
 
27
  logging.basicConfig(level=logging.INFO)
28
  logger = logging.getLogger(__name__)
29
 
 
1
  import argparse
2
  import gc
3
  import logging
 
4
  import queue
5
  import socket
6
  import struct
 
9
  import wave
10
  from importlib.resources import files
11
 
12
+ import numpy as np
13
  import torch
14
  import torchaudio
15
  from huggingface_hub import hf_hub_download
 
18
 
19
  from f5_tts.infer.utils_infer import (
20
  chunk_text,
 
 
 
21
  infer_batch_process,
22
+ load_model,
23
+ load_vocoder,
24
+ preprocess_ref_audio_text,
25
  )
26
 
27
+
28
  logging.basicConfig(level=logging.INFO)
29
  logger = logging.getLogger(__name__)
30
 
src/f5_tts/train/datasets/prepare_csv_wavs.py CHANGED
@@ -1,12 +1,13 @@
 
 
1
  import os
2
- import sys
3
  import signal
4
  import subprocess # For invoking ffprobe
5
- import shutil
6
- import concurrent.futures
7
- import multiprocessing
8
  from contextlib import contextmanager
9
 
 
10
  sys.path.append(os.getcwd())
11
 
12
  import argparse
@@ -16,12 +17,10 @@ from importlib.resources import files
16
  from pathlib import Path
17
 
18
  import torchaudio
19
- from tqdm import tqdm
20
  from datasets.arrow_writer import ArrowWriter
 
21
 
22
- from f5_tts.model.utils import (
23
- convert_char_to_pinyin,
24
- )
25
 
26
 
27
  PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
 
1
+ import concurrent.futures
2
+ import multiprocessing
3
  import os
4
+ import shutil
5
  import signal
6
  import subprocess # For invoking ffprobe
7
+ import sys
 
 
8
  from contextlib import contextmanager
9
 
10
+
11
  sys.path.append(os.getcwd())
12
 
13
  import argparse
 
17
  from pathlib import Path
18
 
19
  import torchaudio
 
20
  from datasets.arrow_writer import ArrowWriter
21
+ from tqdm import tqdm
22
 
23
+ from f5_tts.model.utils import convert_char_to_pinyin
 
 
24
 
25
 
26
  PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
src/f5_tts/train/datasets/prepare_emilia.py CHANGED
@@ -7,20 +7,18 @@
7
  import os
8
  import sys
9
 
 
10
  sys.path.append(os.getcwd())
11
 
12
  import json
13
  from concurrent.futures import ProcessPoolExecutor
14
  from importlib.resources import files
15
  from pathlib import Path
16
- from tqdm import tqdm
17
 
18
  from datasets.arrow_writer import ArrowWriter
 
19
 
20
- from f5_tts.model.utils import (
21
- repetition_found,
22
- convert_char_to_pinyin,
23
- )
24
 
25
 
26
  out_zh = {
 
7
  import os
8
  import sys
9
 
10
+
11
  sys.path.append(os.getcwd())
12
 
13
  import json
14
  from concurrent.futures import ProcessPoolExecutor
15
  from importlib.resources import files
16
  from pathlib import Path
 
17
 
18
  from datasets.arrow_writer import ArrowWriter
19
+ from tqdm import tqdm
20
 
21
+ from f5_tts.model.utils import convert_char_to_pinyin, repetition_found
 
 
 
22
 
23
 
24
  out_zh = {
src/f5_tts/train/datasets/prepare_emilia_v2.py CHANGED
@@ -1,17 +1,17 @@
1
  # put in src/f5_tts/train/datasets/prepare_emilia_v2.py
2
  # prepares Emilia dataset with the new format w/ Emilia-YODAS
3
 
4
- import os
5
  import json
 
6
  from concurrent.futures import ProcessPoolExecutor
 
7
  from pathlib import Path
8
- from tqdm import tqdm
9
  from datasets.arrow_writer import ArrowWriter
10
- from importlib.resources import files
 
 
11
 
12
- from f5_tts.model.utils import (
13
- repetition_found,
14
- )
15
 
16
  # Define filters for exclusion
17
  out_en = set()
 
1
  # put in src/f5_tts/train/datasets/prepare_emilia_v2.py
2
  # prepares Emilia dataset with the new format w/ Emilia-YODAS
3
 
 
4
  import json
5
+ import os
6
  from concurrent.futures import ProcessPoolExecutor
7
+ from importlib.resources import files
8
  from pathlib import Path
9
+
10
  from datasets.arrow_writer import ArrowWriter
11
+ from tqdm import tqdm
12
+
13
+ from f5_tts.model.utils import repetition_found
14
 
 
 
 
15
 
16
  # Define filters for exclusion
17
  out_en = set()
src/f5_tts/train/datasets/prepare_libritts.py CHANGED
@@ -1,15 +1,17 @@
1
  import os
2
  import sys
3
 
 
4
  sys.path.append(os.getcwd())
5
 
6
  import json
7
  from concurrent.futures import ProcessPoolExecutor
8
  from importlib.resources import files
9
  from pathlib import Path
10
- from tqdm import tqdm
11
  import soundfile as sf
12
  from datasets.arrow_writer import ArrowWriter
 
13
 
14
 
15
  def deal_with_audio_dir(audio_dir):
 
1
  import os
2
  import sys
3
 
4
+
5
  sys.path.append(os.getcwd())
6
 
7
  import json
8
  from concurrent.futures import ProcessPoolExecutor
9
  from importlib.resources import files
10
  from pathlib import Path
11
+
12
  import soundfile as sf
13
  from datasets.arrow_writer import ArrowWriter
14
+ from tqdm import tqdm
15
 
16
 
17
  def deal_with_audio_dir(audio_dir):
src/f5_tts/train/datasets/prepare_ljspeech.py CHANGED
@@ -1,14 +1,16 @@
1
  import os
2
  import sys
3
 
 
4
  sys.path.append(os.getcwd())
5
 
6
  import json
7
  from importlib.resources import files
8
  from pathlib import Path
9
- from tqdm import tqdm
10
  import soundfile as sf
11
  from datasets.arrow_writer import ArrowWriter
 
12
 
13
 
14
  def main():
 
1
  import os
2
  import sys
3
 
4
+
5
  sys.path.append(os.getcwd())
6
 
7
  import json
8
  from importlib.resources import files
9
  from pathlib import Path
10
+
11
  import soundfile as sf
12
  from datasets.arrow_writer import ArrowWriter
13
+ from tqdm import tqdm
14
 
15
 
16
  def main():
src/f5_tts/train/datasets/prepare_wenetspeech4tts.py CHANGED
@@ -4,15 +4,16 @@
4
  import os
5
  import sys
6
 
 
7
  sys.path.append(os.getcwd())
8
 
9
  import json
10
  from concurrent.futures import ProcessPoolExecutor
11
  from importlib.resources import files
12
- from tqdm import tqdm
13
 
14
  import torchaudio
15
  from datasets import Dataset
 
16
 
17
  from f5_tts.model.utils import convert_char_to_pinyin
18
 
 
4
  import os
5
  import sys
6
 
7
+
8
  sys.path.append(os.getcwd())
9
 
10
  import json
11
  from concurrent.futures import ProcessPoolExecutor
12
  from importlib.resources import files
 
13
 
14
  import torchaudio
15
  from datasets import Dataset
16
+ from tqdm import tqdm
17
 
18
  from f5_tts.model.utils import convert_char_to_pinyin
19
 
src/f5_tts/train/finetune_cli.py CHANGED
@@ -5,9 +5,9 @@ from importlib.resources import files
5
 
6
  from cached_path import cached_path
7
 
8
- from f5_tts.model import CFM, UNetT, DiT, Trainer
9
- from f5_tts.model.utils import get_tokenizer
10
  from f5_tts.model.dataset import load_dataset
 
11
 
12
 
13
  # -------------------------- Dataset Settings --------------------------- #
 
5
 
6
  from cached_path import cached_path
7
 
8
+ from f5_tts.model import CFM, DiT, Trainer, UNetT
 
9
  from f5_tts.model.dataset import load_dataset
10
+ from f5_tts.model.utils import get_tokenizer
11
 
12
 
13
  # -------------------------- Dataset Settings --------------------------- #
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -1,14 +1,12 @@
1
  import gc
2
  import json
3
- import numpy as np
4
  import os
5
  import platform
6
- import psutil
7
  import queue
8
  import random
9
  import re
10
- import signal
11
  import shutil
 
12
  import subprocess
13
  import sys
14
  import tempfile
@@ -16,21 +14,23 @@ import threading
16
  import time
17
  from glob import glob
18
  from importlib.resources import files
19
- from scipy.io import wavfile
20
 
21
  import click
22
  import gradio as gr
23
  import librosa
 
 
24
  import torch
25
  import torchaudio
26
  from cached_path import cached_path
27
  from datasets import Dataset as Dataset_
28
  from datasets.arrow_writer import ArrowWriter
29
  from safetensors.torch import load_file, save_file
 
30
 
31
  from f5_tts.api import F5TTS
32
- from f5_tts.model.utils import convert_char_to_pinyin
33
  from f5_tts.infer.utils_infer import transcribe
 
34
 
35
 
36
  training_process = None
 
1
  import gc
2
  import json
 
3
  import os
4
  import platform
 
5
  import queue
6
  import random
7
  import re
 
8
  import shutil
9
+ import signal
10
  import subprocess
11
  import sys
12
  import tempfile
 
14
  import time
15
  from glob import glob
16
  from importlib.resources import files
 
17
 
18
  import click
19
  import gradio as gr
20
  import librosa
21
+ import numpy as np
22
+ import psutil
23
  import torch
24
  import torchaudio
25
  from cached_path import cached_path
26
  from datasets import Dataset as Dataset_
27
  from datasets.arrow_writer import ArrowWriter
28
  from safetensors.torch import load_file, save_file
29
+ from scipy.io import wavfile
30
 
31
  from f5_tts.api import F5TTS
 
32
  from f5_tts.infer.utils_infer import transcribe
33
+ from f5_tts.model.utils import convert_char_to_pinyin
34
 
35
 
36
  training_process = None
src/f5_tts/train/train.py CHANGED
@@ -10,6 +10,7 @@ from f5_tts.model import CFM, Trainer
10
  from f5_tts.model.dataset import load_dataset
11
  from f5_tts.model.utils import get_tokenizer
12
 
 
13
  os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to root of project (local editable)
14
 
15
 
 
10
  from f5_tts.model.dataset import load_dataset
11
  from f5_tts.model.utils import get_tokenizer
12
 
13
+
14
  os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to root of project (local editable)
15
 
16