SWivid committed
Commit a621c22 · 1 Parent(s): 39ce201

add speech edit test script

Files changed (3):
  1. README.md +5 -1
  2. model/cfm.py +3 -0
  3. test_infer_single_edit.py +185 -0
README.md CHANGED
@@ -30,13 +30,16 @@ accelerate launch test_train.py
 ## Inference
 Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS
 ```bash
-# single test inference
+# test single inference
 # modify the config up to your need,
 # e.g. fix_duration (the total length of prompt + to_generate, currently support up to 30s)
 # nfe_step (larger takes more time to do more precise inference ode)
 # ode_method (switch to 'midpoint' for better compatibility with small nfe_step, )
 # ( though 'midpoint' is 2nd-order ode solver, slower compared to 1st-order 'Euler')
 python test_infer_single.py
+
+# test speech edit
+python test_infer_single_edit.py
 ```

@@ -77,3 +80,4 @@ python scripts/eval_librispeech_test_clean.py
 - <a href="https://arxiv.org/abs/2403.03206">SD3</a> & <a href="https://github.com/huggingface/diffusers">Huggingface diffusers</a> DiT and MMDiT code structure
 - <a href="https://github.com/modelscope/FunASR">FunASR</a>, <a href="https://github.com/SYSTRAN/faster-whisper">faster-whisper</a> & <a href="https://github.com/microsoft/UniSpeech">UniSpeech</a> for evaluation tools
 - <a href="https://github.com/rtqichen/torchdiffeq">torchdiffeq</a> as ODE solver, <a href="https://huggingface.co/charactr/vocos-mel-24khz">Vocos</a> as vocoder
+- <a href="https://github.com/MahmoudAshraf97/ctc-forced-aligner">ctc-forced-aligner</a> for speech edit test
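
The knobs the README comments describe are plain module-level settings near the top of the test scripts. A minimal sketch of that settings block (values copied from test_infer_single_edit.py added in this commit; test_infer_single.py is assumed to use the same variable names):

```python
# The knobs the README comments refer to, as they appear in the test scripts
# (values below are from test_infer_single_edit.py in this commit).
nfe_step = 32            # more ODE steps: slower but more precise inference
ode_method = 'euler'     # or 'midpoint' (2nd-order, steadier at small nfe_step)

# fix_duration: in test_infer_single.py it is the total prompt + generated
# length in seconds (currently up to ~30 s); in test_infer_single_edit.py it
# is instead a per-span list of target durations, e.g.:
fix_duration = [1.2, 1]  # seconds for each edited span
```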
model/cfm.py CHANGED
@@ -95,6 +95,7 @@ class CFM(nn.Module):
         no_ref_audio = False,
         duplicate_test = False,
         t_inter = 0.1,
+        edit_mask = None,
     ):
         self.eval()

@@ -125,6 +126,8 @@ class CFM(nn.Module):
         # duration

         cond_mask = lens_to_mask(lens)
+        if edit_mask is not None:
+            cond_mask = cond_mask & edit_mask

         if isinstance(duration, int):
             duration = torch.full((batch,), duration, device = device, dtype = torch.long)
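
In CFM.sample, cond_mask appears to mark the mel frames whose original content is kept as conditioning; the new edit_mask lets the caller clear exactly the spans to be re-synthesized (test_infer_single_edit.py below builds it that way: True = keep, False = regenerate). A minimal sketch of the interaction, with a simplified stand-in for the lens_to_mask helper from model/utils.py:

```python
import torch

def lens_to_mask(lens, length=None):
    # simplified stand-in for model.utils.lens_to_mask (assumed to return
    # True for frames within each sequence's length, False for padding)
    length = int(lens.max()) if length is None else length
    return torch.arange(length)[None, :] < lens[:, None]

lens = torch.tensor([6])                      # reference audio spans 6 mel frames
cond_mask = lens_to_mask(lens)                # [[True, True, True, True, True, True]]

# True = keep the original frame as condition, False = frame to regenerate
edit_mask = torch.tensor([[True, True, False, False, True, True]])
cond_mask = cond_mask & edit_mask             # the line added in CFM.sample
print(cond_mask)                              # frames 2-3 drop out of the conditioning
```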
test_infer_single_edit.py ADDED
@@ -0,0 +1,185 @@
+import os
+
+import torch
+import torch.nn.functional as F
+import torchaudio
+from einops import rearrange
+from ema_pytorch import EMA
+from vocos import Vocos
+
+from model import CFM, UNetT, DiT, MMDiT
+from model.utils import (
+    get_tokenizer,
+    convert_char_to_pinyin,
+    save_spectrogram,
+)
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+# --------------------- Dataset Settings -------------------- #
+
+target_sample_rate = 24000
+n_mel_channels = 100
+hop_length = 256
+target_rms = 0.1
+
+tokenizer = "pinyin"
+dataset_name = "Emilia_ZH_EN"
+
+
+# ---------------------- infer setting ---------------------- #
+
+seed = None  # int | None
+
+exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base
+ckpt_step = 1200000
+
+nfe_step = 32  # 16, 32
+cfg_strength = 2.
+ode_method = 'euler'  # euler | midpoint
+sway_sampling_coef = -1.
+speed = 1.
+
+if exp_name == "F5TTS_Base":
+    model_cls = DiT
+    model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
+
+elif exp_name == "E2TTS_Base":
+    model_cls = UNetT
+    model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
+
+checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
+output_dir = "tests"
+
+# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
+# pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
+# [write the origin_text into a file, e.g. tests/test_edit.txt]
+# ctc-forced-aligner --audio_path "tests/ref_audio/test_en_1_ref_short.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
+# [result will be saved at same path of audio file]
+# [--language "zho" for Chinese, "eng" for English]
+# [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
+
+audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
+origin_text = "Some call me nature, others call me mother nature."
+target_text = "Some call me optimist, others call me realist."
+parts_to_edit = [[1.42, 2.44], [4.04, 4.9], ]  # start_ends of "nature" & "mother nature", in seconds
+fix_duration = [1.2, 1, ]  # fix duration for "optimist" & "realist", in seconds
+
+# audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
+# origin_text = "对,这就是我,万人敬仰的太乙真人。"
+# target_text = "对,那就是你,万人敬仰的太白金星。"
+# parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
+# fix_duration = None  # use origin text duration
+
+
+# -------------------------------------------------#
+
+use_ema = True
+
+if not os.path.exists(output_dir):
+    os.makedirs(output_dir)
+
+# Vocoder model
+local = False
+if local:
+    vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
+    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
+    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
+    vocos.load_state_dict(state_dict)
+    vocos.eval()
+else:
+    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+
+# Tokenizer
+vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
+
+# Model
+model = CFM(
+    transformer = model_cls(
+        **model_cfg,
+        text_num_embeds = vocab_size,
+        mel_dim = n_mel_channels
+    ),
+    mel_spec_kwargs = dict(
+        target_sample_rate = target_sample_rate,
+        n_mel_channels = n_mel_channels,
+        hop_length = hop_length,
+    ),
+    odeint_kwargs = dict(
+        method = ode_method,
+    ),
+    vocab_char_map = vocab_char_map,
+).to(device)
+
+if use_ema == True:
+    ema_model = EMA(model, include_online_model = False).to(device)
+    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
+    ema_model.copy_params_from_ema_to_model()
+else:
+    model.load_state_dict(checkpoint['model_state_dict'])
+
+# Audio
+audio, sr = torchaudio.load(audio_to_edit)
+rms = torch.sqrt(torch.mean(torch.square(audio)))
+if rms < target_rms:
+    audio = audio * target_rms / rms
+if sr != target_sample_rate:
+    resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
+    audio = resampler(audio)
+offset = 0
+audio_ = torch.zeros(1, 0)
+edit_mask = torch.zeros(1, 0, dtype=torch.bool)
+for part in parts_to_edit:
+    start, end = part
+    part_dur = end - start if fix_duration is None else fix_duration.pop(0)
+    part_dur = part_dur * target_sample_rate
+    start = start * target_sample_rate
+    audio_ = torch.cat((audio_, audio[:, round(offset):round(start)], torch.zeros(1, round(part_dur))), dim = -1)
+    edit_mask = torch.cat((edit_mask,
+                           torch.ones(1, round((start - offset) / hop_length), dtype = torch.bool),
+                           torch.zeros(1, round(part_dur / hop_length), dtype = torch.bool)
+                           ), dim = -1)
+    offset = end * target_sample_rate
+# audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
+edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value = True)
+audio = audio.to(device)
+edit_mask = edit_mask.to(device)
+
+# Text
+text_list = [target_text]
+if tokenizer == "pinyin":
+    final_text_list = convert_char_to_pinyin(text_list)
+else:
+    final_text_list = [text_list]
+print(f"text : {text_list}")
+print(f"pinyin: {final_text_list}")
+
+# Duration
+ref_audio_len = 0
+duration = audio.shape[-1] // hop_length
+
+# Inference
+with torch.inference_mode():
+    generated, trajectory = model.sample(
+        cond = audio,
+        text = final_text_list,
+        duration = duration,
+        steps = nfe_step,
+        cfg_strength = cfg_strength,
+        sway_sampling_coef = sway_sampling_coef,
+        seed = seed,
+        edit_mask = edit_mask,
+    )
+print(f"Generated mel: {generated.shape}")
+
+# Final result
+generated = generated[:, ref_audio_len:, :]
+generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
+generated_wave = vocos.decode(generated_mel_spec.cpu())
+if rms < target_rms:
+    generated_wave = generated_wave * rms / target_rms
+
+save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single_edit.png")
+torchaudio.save(f"{output_dir}/test_single_edit.wav", generated_wave, target_sample_rate)
+print(f"Generated wav: {generated_wave.shape}")
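
A small worked example of the bookkeeping the script does above: parts_to_edit is given in seconds, while the conditioning audio and edit_mask are handled at mel-frame resolution (hop_length samples per frame), so each span is converted to a frame count before masking. This sketch only reproduces that conversion for the English example's values:

```python
# Seconds-to-frames conversion used when building edit_mask above
# (values copied from the English example in the script).
target_sample_rate, hop_length = 24000, 256   # one mel frame ~= 10.7 ms

parts_to_edit = [[1.42, 2.44], [4.04, 4.9]]   # spans to replace, in seconds
fix_duration = [1.2, 1]                       # durations to generate instead, in seconds

for (start, end), new_dur in zip(parts_to_edit, fix_duration):
    old_frames = round((end - start) * target_sample_rate / hop_length)
    new_frames = round(new_dur * target_sample_rate / hop_length)
    print(f"{start:.2f}-{end:.2f}s: ~{old_frames} original frames -> ~{new_frames} regenerated frames")
```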