SWivid committed
Commit 9ee0510 · 1 Parent(s): ef6b576

fix inference-cli; clean-up

README.md CHANGED
@@ -58,38 +58,28 @@ Once your datasets are prepared, you can start the training process.
  # setup accelerate config, e.g. use multi-gpu ddp, fp16
  # will be to: ~/.cache/huggingface/accelerate/default_config.yaml
  accelerate config
- accelerate launch test_train.py
+ accelerate launch train.py
  ```
  An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57).

  ## Inference

- To run inference with pretrained models, download the checkpoints from [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS).
+ To run inference with pretrained models, download the checkpoints from [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), or let `inference-cli` and `gradio_app` download them automatically.

- Currently support up to 30s generation, which is the **TOTAL** length of prompt audio and the generated. Batch inference with chunks is supported by Gradio APP now.
+ Currently supports up to 30s for a single generation, which is the **TOTAL** length of the prompt audio and the generated speech. Batch inference with chunks is supported by `inference-cli` and `gradio_app`.
  - To avoid possible inference failures, make sure you have read through the following instructions.
- - A longer prompt audio allows shorter generated output. The part longer than 30s cannot be generated properly. Consider split your text and do several separate inferences or leverage the local Gradio APP which enables a batch inference with chunks.
+ - A longer prompt audio allows a shorter generated output. The part beyond 30s cannot be generated properly. Consider using a prompt audio shorter than 15s.
  - Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
  - Add some spaces (blank: " ") or punctuation (e.g. "," ".") to explicitly introduce some pauses. If the first few words are skipped in code-switched generation (because different languages are spoken at different speeds), this might help.

- ### Single Inference
+ ### CLI Inference

- You can test single inference using the following command. Before running the command, modify the config up to your need.
+ You can either specify everything in `inference-cli.toml` or override it with flags. Leaving `--ref_text ""` will have an ASR model transcribe the reference audio automatically (this uses extra GPU memory). If you encounter a network error, consider using a local checkpoint by setting `ckpt_path` in `inference-cli.py`.

  ```bash
- # modify the config up to your need,
- # e.g. fix_duration (the total length of prompt + to_generate, currently support up to 30s)
- # nfe_step (larger takes more time to do more precise inference ode)
- # ode_method (switch to 'midpoint' for better compatibility with small nfe_step, )
- # ( though 'midpoint' is 2nd-order ode solver, slower compared to 1st-order 'Euler')
- python test_infer_single.py
- ```
- ### Speech Editing
-
- To test speech editing capabilities, use the following command.
+ python inference-cli.py --model "F5-TTS" --ref_audio "tests/ref_audio/test_en_1_ref_short.wav" --ref_text "Some call me nature, others call me mother nature." --gen_text "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."

- ```bash
- python test_infer_single_edit.py
+ python inference-cli.py --model "E2-TTS" --ref_audio "tests/ref_audio/test_zh_1_ref_short.wav" --ref_text "对,这就是我,万人敬仰的太乙真人。" --gen_text "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\""
  ```

  ### Gradio App
@@ -102,7 +92,7 @@ First, make sure you have the dependencies installed (`pip install -r requiremen
  pip install -r requirements_gradio.txt
  ```

- After installing the dependencies, launch the app (will load ckpt from Huggingface, you may set `ckpt_path` to local file in `gradio_app.py`):
+ After installing the dependencies, launch the app (it will load the ckpt from Hugging Face; you may set `ckpt_path` to a local file in `gradio_app.py`). It currently loads the ASR model, F5-TTS, and E2 TTS all at once, and thus uses more GPU memory than `inference-cli`.

  ```bash
  python gradio_app.py
@@ -120,6 +110,14 @@ Or launch a share link:
  python gradio_app.py --share
  ```

+ ### Speech Editing
+
+ To test speech editing capabilities, use the following command.
+
+ ```bash
+ python speech_edit.py
+ ```
+
  ## Evaluation

  ### Prepare Test Datasets
@@ -127,7 +125,7 @@ python gradio_app.py --share
  1. Seed-TTS test set: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
  2. LibriSpeech test-clean: Download from [OpenSLR](http://www.openslr.org/12/).
  3. Unzip the downloaded datasets and place them in the data/ directory.
- 4. Update the path for the test-clean data in `test_infer_batch.py`
+ 4. Update the path for the test-clean data in `scripts/eval_infer_batch.py`
  5. Our filtered LibriSpeech-PC 4-10s subset is already under data/ in this repo

  ### Batch Inference for Test Set
@@ -137,7 +135,7 @@ To run batch inference for evaluations, execute the following commands:
  ```bash
  # batch inference for evaluations
  accelerate config # if not set before
- bash test_infer_batch.sh
+ bash scripts/eval_infer_batch.sh
  ```

  ### Download Evaluation Model Checkpoints
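As background for the 30-second total-length note in the README's Inference section above: the linear duration estimate used in this repo (it appears verbatim in the deleted `test_infer_single.py` further down this diff) budgets the prompt and the generated speech together, which is why a shorter prompt leaves more room for generated audio. A minimal sketch, with the formula lifted from that script and wrapped in a helper named here only for illustration:

```python
import re

# Frame constants used across the repo's inference scripts (24 kHz mel, hop 256).
target_sample_rate = 24000
hop_length = 256

def estimate_total_frames(ref_audio_len, ref_text, gen_text, speed=1.0):
    """Linear duration estimate, as in the deleted test_infer_single.py."""
    zh_pause_punc = r"。,、;:?!"
    ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
    gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
    # Generated frames are scaled from the prompt's frames-per-byte rate, so the
    # ~30 s cap (30 * 24000 / 256 ≈ 2812 frames) is shared by prompt and output.
    return ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
```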
inference-cli.py CHANGED
@@ -1,4 +1,3 @@
- import os
  import re
  import torch
  import torchaudio
@@ -16,10 +15,8 @@ from model.utils import (
      save_spectrogram,
  )
  from transformers import pipeline
- import librosa
- import click
  import soundfile as sf
- import tomllib
+ import tomli
  import argparse
  import tqdm
  from pathlib import Path
@@ -42,19 +39,19 @@ parser.add_argument(
  )
  parser.add_argument(
      "-r",
-     "--reference",
+     "--ref_audio",
      type=str,
      help="Reference audio file < 15 seconds."
  )
  parser.add_argument(
      "-s",
-     "--subtitle",
+     "--ref_text",
      type=str,
      help="Subtitle for the reference audio."
  )
  parser.add_argument(
      "-t",
-     "--text",
+     "--gen_text",
      type=str,
      help="Text to generate.",
  )
@@ -70,11 +67,11 @@ parser.add_argument(
  )
  args = parser.parse_args()

- config = tomllib.load(open(args.config, "rb"))
+ config = tomli.load(open(args.config, "rb"))

- ref_audio = args.reference if args.reference else config["reference"]
- ref_text = args.subtitle if args.subtitle else config["subtitle"]
- gen_text = args.text if args.text else config["text"]
+ ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
+ ref_text = args.ref_text if args.ref_text else config["ref_text"]
+ gen_text = args.gen_text if args.gen_text else config["gen_text"]
  output_dir = args.output_dir if args.output_dir else config["output_dir"]
  exp_name = args.model if args.model else config["model"]
  remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
@@ -100,13 +97,6 @@ device = (

  print(f"Using {device} device")

- pipe = pipeline(
-     "automatic-speech-recognition",
-     model="openai/whisper-large-v3-turbo",
-     torch_dtype=torch.float16,
-     device=device,
- )
-
  # --------------------- Settings -------------------- #

  target_sample_rate = 24000
@@ -151,13 +141,6 @@ F5TTS_model_cfg = dict(
  )
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)

- F5TTS_ema_model = load_model(
-     "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
- )
- E2TTS_ema_model = load_model(
-     "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
- )
-
  def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
      if len(text.encode('utf-8')) <= max_chars:
          return [text]
@@ -256,9 +239,9 @@ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):

  def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence):
      if exp_name == "F5-TTS":
-         ema_model = F5TTS_ema_model
+         ema_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
      elif exp_name == "E2-TTS":
-         ema_model = E2TTS_ema_model
+         ema_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)

      audio, sr = torchaudio.load(ref_audio)
      if audio.shape[0] > 1:
@@ -363,6 +346,12 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s

      if not ref_text.strip():
          print("No reference text provided, transcribing reference audio...")
+         pipe = pipeline(
+             "automatic-speech-recognition",
+             model="openai/whisper-large-v3-turbo",
+             torch_dtype=torch.float16,
+             device=device,
+         )
          ref_text = pipe(
              ref_audio,
              chunk_length_s=30,
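The net effect of the `inference-cli.py` changes above is lazy loading: the Whisper ASR pipeline is now built only when no reference text is given, and only the selected TTS model is loaded inside `infer_batch`, instead of everything being loaded at import time. A minimal sketch of the transcription branch, assuming the same `openai/whisper-large-v3-turbo` checkpoint; the helper name and the simplified pipeline call are illustrative, not the script's exact code:

```python
import torch
from transformers import pipeline

def maybe_transcribe(ref_audio: str, ref_text: str, device) -> str:
    """Return ref_text unchanged, or transcribe ref_audio when it is empty."""
    if ref_text.strip():
        return ref_text  # user-supplied text: Whisper is never loaded
    # The ASR pipeline is constructed only on this branch, so its weights
    # occupy GPU memory only when auto-transcription is actually needed.
    pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3-turbo",
        torch_dtype=torch.float16,
        device=device,
    )
    return pipe(ref_audio, chunk_length_s=30)["text"].strip()
```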
inference-cli.toml CHANGED
@@ -1,8 +1,8 @@
  # F5-TTS | E2-TTS
  model = "F5-TTS"
- reference = "tests/ref_audio/test_en_1_ref_short.wav"
+ ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
  # If an empty "", transcribes the reference audio automatically.
- subtitle = "Some call me nature, others call me mother nature."
- text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
+ ref_text = "Some call me nature, others call me mother nature."
+ gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
  remove_silence = true
  output_dir = "tests"
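The renamed TOML keys above mirror the new CLI flags, and `inference-cli.py` now reads them with the third-party `tomli` package (the stdlib `tomllib` it replaces only ships with Python 3.11+, so this presumably restores compatibility with older interpreters). A minimal sketch of the config-with-flag-override pattern used by the script; the `--config` default shown here is an assumption, since that argument's definition is not part of this diff:

```python
import argparse
import tomli  # TOML reader for Python < 3.11; stdlib tomllib exists only from 3.11

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", default="inference-cli.toml")  # assumed default
parser.add_argument("-r", "--ref_audio", type=str)
parser.add_argument("-s", "--ref_text", type=str)
parser.add_argument("-t", "--gen_text", type=str)
args = parser.parse_args()

with open(args.config, "rb") as f:  # tomli expects a binary file handle
    config = tomli.load(f)

# Flags win when provided; otherwise the TOML values act as defaults.
ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
ref_text = args.ref_text if args.ref_text else config["ref_text"]
gen_text = args.gen_text if args.gen_text else config["gen_text"]
```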
model/dataset.py CHANGED
@@ -188,7 +188,7 @@ def load_dataset(
      dataset_type: str = "CustomDataset",
      audio_type: str = "raw",
      mel_spec_kwargs: dict = dict()
- ) -> CustomDataset | HFDataset:
+ ) -> CustomDataset:

      print("Loading dataset ...")
 
requirements.txt CHANGED
@@ -1,16 +1,21 @@
  accelerate>=0.33.0
+ cached_path
+ click
  datasets
  einops>=0.8.0
  einx>=0.3.0
  ema_pytorch>=0.5.2
  faster_whisper
  funasr
+ gradio
  jieba
  jiwer
  librosa
  matplotlib
+ pydub
  pypinyin
  safetensors
+ soundfile
  # torch>=2.0
  # torchaudio>=2.3.0
  torchdiffeq
@@ -20,6 +25,4 @@ vocos
  wandb
  x_transformers>=1.31.14
  zhconv
- zhon
- pydub
- cached_path
+ zhon
requirements_gradio.txt DELETED
@@ -1,5 +0,0 @@
- cached_path
- click
- gradio
- pydub
- soundfile
test_infer_batch.py → scripts/eval_infer_batch.py RENAMED
@@ -1,4 +1,6 @@
- import os
+ import sys, os
+ sys.path.append(os.getcwd())
+
  import time
  import random
  from tqdm import tqdm
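The `sys.path.append(os.getcwd())` added above is what lets the relocated script keep importing the repo's top-level `model` package when it is launched from the repository root (e.g. via `scripts/eval_infer_batch.sh`); otherwise Python puts only the script's own `scripts/` directory on the import path. A minimal sketch of the idea; the imported symbol is illustrative (it exists in `model.utils`, but whether this particular script uses it is not shown in the diff):

```python
import os
import sys

# Running `accelerate launch scripts/eval_infer_batch.py ...` from the repo root
# puts scripts/ (not the root) on sys.path, so add the current working directory
# before importing repo-level packages.
sys.path.append(os.getcwd())

from model.utils import get_tokenizer  # resolvable only after the append above
```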
scripts/eval_infer_batch.sh ADDED
@@ -0,0 +1,13 @@
+ #!/bin/bash
+
+ # e.g. F5-TTS, 16 NFE
+ accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
+ accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
+ accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
+
+ # e.g. Vanilla E2 TTS, 32 NFE
+ accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
+ accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
+ accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
+
+ # etc.
test_infer_single_edit.py → speech_edit.py RENAMED
File without changes
test_infer_batch.sh DELETED
@@ -1,13 +0,0 @@
- #!/bin/bash
-
- # e.g. F5-TTS, 16 NFE
- accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
- accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
- accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
-
- # e.g. Vanilla E2 TTS, 32 NFE
- accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
- accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
- accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
-
- # etc.
test_infer_single.py DELETED
@@ -1,161 +0,0 @@
- import os
- import re
-
- import torch
- import torchaudio
- from einops import rearrange
- from vocos import Vocos
-
- from model import CFM, UNetT, DiT, MMDiT
- from model.utils import (
-     load_checkpoint,
-     get_tokenizer,
-     convert_char_to_pinyin,
-     save_spectrogram,
- )
-
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
-
-
- # --------------------- Dataset Settings -------------------- #
-
- target_sample_rate = 24000
- n_mel_channels = 100
- hop_length = 256
- target_rms = 0.1
-
- tokenizer = "pinyin"
- dataset_name = "Emilia_ZH_EN"
-
-
- # ---------------------- infer setting ---------------------- #
-
- seed = None  # int | None
-
- exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base
- ckpt_step = 1200000
-
- nfe_step = 32  # 16, 32
- cfg_strength = 2.
- ode_method = 'euler'  # euler | midpoint
- sway_sampling_coef = -1.
- speed = 1.
- fix_duration = 27  # None (will linear estimate. if code-switched, consider fix) | float (total in seconds, include ref audio)
-
- if exp_name == "F5TTS_Base":
-     model_cls = DiT
-     model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
-
- elif exp_name == "E2TTS_Base":
-     model_cls = UNetT
-     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
-
- ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
- output_dir = "tests"
-
- ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
- ref_text = "Some call me nature, others call me mother nature."
- gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
-
- # ref_audio = "tests/ref_audio/test_zh_1_ref_short.wav"
- # ref_text = "对,这就是我,万人敬仰的太乙真人。"
- # gen_text = "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\""
-
-
- # -------------------------------------------------#
-
- use_ema = True
-
- if not os.path.exists(output_dir):
-     os.makedirs(output_dir)
-
- # Vocoder model
- local = False
- if local:
-     vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
-     vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
-     state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
-     vocos.load_state_dict(state_dict)
-     vocos.eval()
- else:
-     vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-
- # Tokenizer
- vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
-
- # Model
- model = CFM(
-     transformer = model_cls(
-         **model_cfg,
-         text_num_embeds = vocab_size,
-         mel_dim = n_mel_channels
-     ),
-     mel_spec_kwargs = dict(
-         target_sample_rate = target_sample_rate,
-         n_mel_channels = n_mel_channels,
-         hop_length = hop_length,
-     ),
-     odeint_kwargs = dict(
-         method = ode_method,
-     ),
-     vocab_char_map = vocab_char_map,
- ).to(device)
-
- model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
-
- # Audio
- audio, sr = torchaudio.load(ref_audio)
- if audio.shape[0] > 1:
-     audio = torch.mean(audio, dim=0, keepdim=True)
- rms = torch.sqrt(torch.mean(torch.square(audio)))
- if rms < target_rms:
-     audio = audio * target_rms / rms
- if sr != target_sample_rate:
-     resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
-     audio = resampler(audio)
- audio = audio.to(device)
-
- # Text
- if len(ref_text[-1].encode('utf-8')) == 1:
-     ref_text = ref_text + " "
- text_list = [ref_text + gen_text]
- if tokenizer == "pinyin":
-     final_text_list = convert_char_to_pinyin(text_list)
- else:
-     final_text_list = [text_list]
- print(f"text : {text_list}")
- print(f"pinyin: {final_text_list}")
-
- # Duration
- ref_audio_len = audio.shape[-1] // hop_length
- if fix_duration is not None:
-     duration = int(fix_duration * target_sample_rate / hop_length)
- else:  # simple linear scale calcul
-     zh_pause_punc = r"。,、;:?!"
-     ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
-     gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
-     duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
-
- # Inference
- with torch.inference_mode():
-     generated, trajectory = model.sample(
-         cond = audio,
-         text = final_text_list,
-         duration = duration,
-         steps = nfe_step,
-         cfg_strength = cfg_strength,
-         sway_sampling_coef = sway_sampling_coef,
-         seed = seed,
-     )
- print(f"Generated mel: {generated.shape}")
-
- # Final result
- generated = generated[:, ref_audio_len:, :]
- generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
- generated_wave = vocos.decode(generated_mel_spec.cpu())
- if rms < target_rms:
-     generated_wave = generated_wave * rms / target_rms
-
- save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single.png")
- torchaudio.save(f"{output_dir}/test_single.wav", generated_wave, target_sample_rate)
- print(f"Generated wav: {generated_wave.shape}")
test_train.py → train.py RENAMED
File without changes