Spaces: Running on Zero
fixed TTS class
app.py CHANGED
@@ -45,13 +45,14 @@ print(f"--- Initializing on device: {DEVICE} ---") # This will run when the Spac
 # --- (Start of your model definitions - make sure this is complete from your previous code) ---
 class Hyperparams:
     seed = 42
+    # We won't use these dataset paths, but keep them for hp object integrity
     csv_path = "path/to/metadata.csv"
     wav_path = "path/to/wavs"
     symbols = [
-        'EOS', ' ', '!', ',', '-', '.', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f',
-        'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
-        't', 'u', 'v', 'w', 'x', 'y', 'z', 'à', 'â', 'è', 'é', 'ê', 'ü',
-        '’', '“', '”'
+        'EOS', ' ', '!', ',', '-', '.', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f',
+        'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
+        't', 'u', 'v', 'w', 'x', 'y', 'z', 'à', 'â', 'è', 'é', 'ê', 'ü',
+        '’', '“', '”'
     ]
     sr = 22050
     n_fft = 2048
@@ -61,9 +62,9 @@ class Hyperparams:
     mel_freq = 128
     max_mel_time = 1024
     power = 2.0
-    text_num_embeddings = 2*len(symbols)
+    text_num_embeddings = 2*len(symbols)
     embedding_size = 256
-    encoder_embedding_size = 512
+    encoder_embedding_size = 512
     dim_feedforward = 1024
     postnet_embedding_size = 1024
     encoder_kernel_size = 3
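Note: the audio transforms in the hunks below also read hp.n_stft, hp.win_length, hp.hop_length, hp.scale_db, hp.ampl_ref and hp.ampl_power, which sit outside the visible diff. For torchaudio's MelScale/InverseMelScale to line up with the Spectrogram defined below, n_stft must equal n_fft // 2 + 1, i.e. 1025 for n_fft = 2048.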
@@ -78,10 +79,10 @@ class Hyperparams:
 
 hp = Hyperparams()
 
+# Text to Sequence
 symbol_to_id = {s: i for i, s in enumerate(hp.symbols)}
 def text_to_seq(text):
     text = text.lower()
-    text = unidecode(text)
     seq = []
     for s in text:
         _id = symbol_to_id.get(s, None)
@@ -90,8 +91,10 @@ def text_to_seq(text):
     seq.append(symbol_to_id["EOS"])
     return torch.IntTensor(seq)
 
+# Audio Processing
 spec_transform = torchaudio.transforms.Spectrogram(n_fft=hp.n_fft, win_length=hp.win_length, hop_length=hp.hop_length, power=hp.power)
 mel_scale_transform = torchaudio.transforms.MelScale(n_mels=hp.mel_freq, sample_rate=hp.sr, n_stft=hp.n_stft)
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 mel_inverse_transform = torchaudio.transforms.InverseMelScale(n_mels=hp.mel_freq, sample_rate=hp.sr, n_stft=hp.n_stft).to(DEVICE)
 griffnlim_transform = torchaudio.transforms.GriffinLim(n_fft=hp.n_fft, win_length=hp.win_length, hop_length=hp.hop_length).to(DEVICE)
 
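Note: with text = unidecode(text) gone, accented characters are no longer ASCII-folded, so they survive only if they appear in hp.symbols (which now carries 'à', 'â', 'è', 'é', 'ê', 'ü'). A minimal sanity check, hypothetical and not part of the commit:

    # Every character of "Très bien!" is in hp.symbols after .lower(),
    # and text_to_seq always appends EOS (symbols[0], so id 0) at the end.
    seq = text_to_seq("Très bien!")
    assert seq[-1].item() == symbol_to_id["EOS"]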
@@ -102,7 +105,7 @@ def pow_to_db_mel_spec(mel_spec):
 
 def db_to_power_mel_spec(mel_spec):
     mel_spec = mel_spec*hp.scale_db
-    mel_spec = torchaudio.functional.DB_to_amplitude(mel_spec, ref=hp.ampl_ref, power=hp.ampl_power)
+    mel_spec = torchaudio.functional.DB_to_amplitude(mel_spec, ref=hp.ampl_ref, power=hp.ampl_power)
     return mel_spec
 
 def inverse_mel_spec_to_wav(mel_spec):
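Note: the body of inverse_mel_spec_to_wav is outside the changed hunks. Given the transforms defined above, the usual Griffin-Lim chain would look roughly like this sketch (assumed, not copied from app.py):

    def inverse_mel_spec_to_wav_sketch(mel_spec):
        power_mel = db_to_power_mel_spec(mel_spec)  # dB -> power mel spectrogram
        spec = mel_inverse_transform(power_mel)     # mel bins -> linear STFT magnitudes
        return griffnlim_transform(spec)            # iterative phase recovery -> waveform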
@@ -115,8 +118,9 @@ def mask_from_seq_lengths(sequence_lengths: torch.Tensor, max_length: int) -> to
     ones = sequence_lengths.new_ones(sequence_lengths.size(0), max_length)
     range_tensor = ones.cumsum(dim=1)
     return sequence_lengths.unsqueeze(1) >= range_tensor
-
-
+
+# --- TransformerTTS Model Architecture (Copied from notebook)
+class EncoderBlock(nn.Module):
     def __init__(self):
         super(EncoderBlock, self).__init__()
         self.norm_1 = nn.LayerNorm(normalized_shape=hp.embedding_size)
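Note: mask_from_seq_lengths compares a per-position running count against each sequence length, so True marks real frames and False marks padding. For example:

    # mask_from_seq_lengths(torch.tensor([1, 3]), max_length=3)
    # -> tensor([[ True, False, False],
    #            [ True,  True,  True]])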
@@ -131,8 +135,8 @@ class EncoderBlock(nn.Module): # Your EncoderBlock definition
         x_out = self.norm_1(x)
         x_out, _ = self.attn(query=x_out, key=x_out, value=x_out, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
         x_out = self.dropout_1(x_out)
-        x = x + x_out
-        x_out = self.norm_2(x)
+        x = x + x_out
+        x_out = self.norm_2(x)
         x_out = self.linear_1(x_out)
         x_out = F.relu(x_out)
         x_out = self.dropout_2(x_out)
@@ -141,14 +145,14 @@ class EncoderBlock(nn.Module): # Your EncoderBlock definition
         x = x + x_out
         return x
 
-class DecoderBlock(nn.Module):
+class DecoderBlock(nn.Module):
     def __init__(self):
         super(DecoderBlock, self).__init__()
         self.norm_1 = nn.LayerNorm(normalized_shape=hp.embedding_size)
         self.self_attn = torch.nn.MultiheadAttention(embed_dim=hp.embedding_size, num_heads=4, dropout=0.1, batch_first=True)
         self.dropout_1 = torch.nn.Dropout(0.1)
         self.norm_2 = nn.LayerNorm(normalized_shape=hp.embedding_size)
-        self.attn = torch.nn.MultiheadAttention(embed_dim=hp.embedding_size, num_heads=4, dropout=0.1, batch_first=True)
+        self.attn = torch.nn.MultiheadAttention(embed_dim=hp.embedding_size, num_heads=4, dropout=0.1, batch_first=True)
         self.dropout_2 = torch.nn.Dropout(0.1)
         self.norm_3 = nn.LayerNorm(normalized_shape=hp.embedding_size)
         self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
@@ -170,7 +174,7 @@ class DecoderBlock(nn.Module): # Your DecoderBlock definition
         x = self.norm_3(x + x_out)
         return x
 
-class EncoderPreNet(nn.Module):
+class EncoderPreNet(nn.Module):
     def __init__(self):
         super(EncoderPreNet, self).__init__()
         self.embedding = nn.Embedding(num_embeddings=hp.text_num_embeddings, embedding_dim=hp.encoder_embedding_size)
@@ -184,22 +188,28 @@ class EncoderPreNet(nn.Module): # Your EncoderPreNet definition
         self.dropout_2 = torch.nn.Dropout(0.5)
         self.conv_3 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, kernel_size=hp.encoder_kernel_size, stride=1, padding=int((hp.encoder_kernel_size - 1) / 2), dilation=1)
         self.bn_3 = nn.BatchNorm1d(hp.encoder_embedding_size)
-        self.dropout_3 = torch.nn.Dropout(0.5)
+        self.dropout_3 = torch.nn.Dropout(0.5)
     def forward(self, text):
         x = self.embedding(text)
         x = self.linear_1(x)
         x = x.transpose(2, 1)
         x = self.conv_1(x)
-        x = self.bn_1(x)
+        x = self.bn_1(x)
+        x = F.relu(x)
+        x = self.dropout_1(x)
         x = self.conv_2(x)
-        x = self.bn_2(x)
+        x = self.bn_2(x)
+        x = F.relu(x)
+        x = self.dropout_2(x)
         x = self.conv_3(x)
-        x = self.bn_3(x)
+        x = self.bn_3(x)
+        x = F.relu(x)
+        x = self.dropout_3(x)
         x = x.transpose(1, 2)
         x = self.linear_2(x)
         return x
 
-class PostNet(nn.Module):
+class PostNet(nn.Module):
     def __init__(self):
         super(PostNet, self).__init__()
         self.conv_1 = nn.Conv1d(hp.mel_freq, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
@@ -221,18 +231,23 @@ class PostNet(nn.Module): # Your PostNet definition
         self.bn_6 = nn.BatchNorm1d(hp.mel_freq)
         self.dropout_6 = torch.nn.Dropout(0.5)
     def forward(self, x):
-        x_orig = x # Store original for residual connection if postnet predicts residual
         x = x.transpose(2, 1)
-        x = self.conv_1(x)
-        x = self.
-        x = self.
-        x = self.
-        x = self.
-        x = self.
+        x = self.conv_1(x)
+        x = self.bn_1(x); x = torch.tanh(x); x = self.dropout_1(x)
+        x = self.conv_2(x)
+        x = self.bn_2(x); x = torch.tanh(x); x = self.dropout_2(x)
+        x = self.conv_3(x)
+        x = self.bn_3(x); x = torch.tanh(x); x = self.dropout_3(x)
+        x = self.conv_4(x)
+        x = self.bn_4(x); x = torch.tanh(x); x = self.dropout_4(x)
+        x = self.conv_5(x)
+        x = self.bn_5(x); x = torch.tanh(x); x = self.dropout_5(x)
+        x = self.conv_6(x)
+        x = self.bn_6(x); x = self.dropout_6(x)
         x = x.transpose(1, 2)
-        return x
+        return x
 
-class DecoderPreNet(nn.Module):
+class DecoderPreNet(nn.Module):
     def __init__(self):
         super(DecoderPreNet, self).__init__()
         self.linear_1 = nn.Linear(hp.mel_freq, hp.embedding_size)
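Note: the restored PostNet forward runs conv -> batch norm -> tanh -> dropout five times and drops the tanh on the final block, the usual Tacotron 2-style postnet layout. The same pattern could be folded into ModuleLists; a hypothetical refactor sketch (postnet_kernel_size assumed to be 5), not part of the commit:

    import torch
    import torch.nn as nn

    class PostNetLoop(nn.Module):
        def __init__(self, mel_freq=128, emb=1024, k=5):
            super().__init__()
            chans = [mel_freq] + [emb] * 5 + [mel_freq]
            self.convs = nn.ModuleList(
                nn.Conv1d(chans[i], chans[i + 1], kernel_size=k, padding=(k - 1) // 2)
                for i in range(6))
            self.bns = nn.ModuleList(nn.BatchNorm1d(c) for c in chans[1:])
            self.drops = nn.ModuleList(nn.Dropout(0.5) for _ in range(6))

        def forward(self, x):
            x = x.transpose(2, 1)      # [N, T, mel] -> [N, mel, T]
            for i in range(6):
                x = self.bns[i](self.convs[i](x))
                if i < 5:              # tanh on every block except the last
                    x = torch.tanh(x)
                x = self.drops[i](x)
            return x.transpose(1, 2)   # back to [N, T, mel]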
@@ -240,13 +255,13 @@ class DecoderPreNet(nn.Module): # Your DecoderPreNet definition
     def forward(self, x):
         x = self.linear_1(x)
         x = F.relu(x)
-        x = F.dropout(x, p=0.5, training=
+        x = F.dropout(x, p=0.5, training=True)
         x = self.linear_2(x)
-        x = F.relu(x)
-        x = F.dropout(x, p=0.5, training=
-        return x
+        x = F.relu(x)
+        x = F.dropout(x, p=0.5, training=True)
+        return x
 
-class TransformerTTS(nn.Module):
+class TransformerTTS(nn.Module):
     def __init__(self, device=DEVICE):
         super(TransformerTTS, self).__init__()
         self.encoder_prenet = EncoderPreNet()
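Note: training=True is not a typo here: it forces F.dropout to stay active even under model.eval(), so the decoder prenet keeps dropping units at inference time. Keeping prenet dropout on during synthesis is a deliberate choice in Tacotron-style autoregressive TTS models.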
@@ -259,132 +274,56 @@ class TransformerTTS(nn.Module): # Your TransformerTTS definition
         self.decoder_block_1 = DecoderBlock()
         self.decoder_block_2 = DecoderBlock()
         self.decoder_block_3 = DecoderBlock()
-        self.linear_1 = nn.Linear(hp.embedding_size, hp.mel_freq)
-        self.linear_2 = nn.Linear(hp.embedding_size, 1)
+        self.linear_1 = nn.Linear(hp.embedding_size, hp.mel_freq)
+        self.linear_2 = nn.Linear(hp.embedding_size, 1)
         self.norm_memory = nn.LayerNorm(normalized_shape=hp.embedding_size)
-
-
-    def forward(self, text, text_len, mel, mel_len): # For training/teacher-forcing
-        # ... (Your detailed forward pass for training, with all masks)
+    def forward(self, text, text_len, mel, mel_len):
         N = text.shape[0]; S = text.shape[1]; TIME = mel.shape[1]
-
-
-
-
-
-
-
-
-
-        text_x = self.
-
-
-
-
-
-
-        text_x = self.encoder_block_3(text_x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
-        memory = self.norm_memory(text_x)
-
-        mel_x = self.decoder_prenet(mel)
-        mel_time_dim = mel_x.shape[1]
-        mel_x = mel_x + pos_codes[:mel_time_dim]
-
-        mel_x = self.decoder_block_1(x=mel_x, memory=memory, x_attn_mask=tgt_mask, x_key_padding_mask=tgt_key_padding_mask, memory_attn_mask=memory_mask, memory_key_padding_mask=src_key_padding_mask)
-        mel_x = self.decoder_block_2(x=mel_x, memory=memory, x_attn_mask=tgt_mask, x_key_padding_mask=tgt_key_padding_mask, memory_attn_mask=memory_mask, memory_key_padding_mask=src_key_padding_mask)
-        mel_x = self.decoder_block_3(x=mel_x, memory=memory, x_attn_mask=tgt_mask, x_key_padding_mask=tgt_key_padding_mask, memory_attn_mask=memory_mask, memory_key_padding_mask=src_key_padding_mask)
-
+        self.src_key_padding_mask = torch.zeros((N, S), device=text.device).masked_fill(~mask_from_seq_lengths(text_len, max_length=S), float("-inf"))
+        self.src_mask = torch.zeros((S, S), device=text.device).masked_fill(torch.triu(torch.full((S, S), True, dtype=torch.bool), diagonal=1).to(text.device), float("-inf"))
+        self.tgt_key_padding_mask = torch.zeros((N, TIME), device=mel.device).masked_fill(~mask_from_seq_lengths(mel_len, max_length=TIME), float("-inf"))
+        self.tgt_mask = torch.zeros((TIME, TIME), device=mel.device).masked_fill(torch.triu(torch.full((TIME, TIME), True, device=mel.device, dtype=torch.bool), diagonal=1), float("-inf"))
+        self.memory_mask = torch.zeros((TIME, S), device=mel.device).masked_fill(torch.triu(torch.full((TIME, S), True, device=mel.device, dtype=torch.bool), diagonal=1), float("-inf"))
+        text_x = self.encoder_prenet(text)
+        pos_codes = self.pos_encoding(torch.arange(hp.max_mel_time).to(mel.device))
+        S = text_x.shape[1]; text_x = text_x + pos_codes[:S]
+        text_x = self.encoder_block_1(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
+        text_x = self.encoder_block_2(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
+        text_x = self.encoder_block_3(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
+        text_x = self.norm_memory(text_x)
+        mel_x = self.decoder_prenet(mel); mel_x = mel_x + pos_codes[:TIME]
+        mel_x = self.decoder_block_1(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
+        mel_x = self.decoder_block_2(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
+        mel_x = self.decoder_block_3(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
         mel_linear = self.linear_1(mel_x)
-
-        mel_postnet = mel_linear +
-
-
-
-
-
-
-        mel_postnet = mel_postnet.masked_fill(bool_mel_mask, 0.0)
-        # Ensure stop_token is [N, TIME]
-        stop_token = stop_token.masked_fill(tgt_key_padding_mask.unsqueeze(-1) if stop_token.dim() == 3 else tgt_key_padding_mask, 1e3)
-        if stop_token.dim() == 3 and stop_token.shape[2] == 1:
-            stop_token = stop_token.squeeze(-1)
-
-
-        return mel_postnet, mel_linear, stop_token
-
+        mel_postnet = self.postnet(mel_linear)
+        mel_postnet = mel_linear + mel_postnet
+        stop_token = self.linear_2(mel_x)
+        bool_mel_mask = self.tgt_key_padding_mask.ne(0).unsqueeze(-1).repeat(1, 1, hp.mel_freq)
+        mel_linear = mel_linear.masked_fill(bool_mel_mask, 0)
+        mel_postnet = mel_postnet.masked_fill(bool_mel_mask, 0)
+        stop_token = stop_token.masked_fill(bool_mel_mask[:, :, 0].unsqueeze(-1), 1e3).squeeze(2)
+        return mel_postnet, mel_linear, stop_token
 
     @torch.no_grad()
-    def inference(self, text, max_length=800, stop_token_threshold=0.5):
-        self.eval()
-
-
-
-
-
-
-
-
-
-
-
-
-        encoder_output = self.encoder_prenet(text)
-        pos_codes = self.pos_encoding(torch.arange(hp.max_mel_time, device=current_device))
-        text_s_dim = encoder_output.shape[1]
-        encoder_output = encoder_output + pos_codes[:text_s_dim]
-
-        encoder_output = self.encoder_block_1(encoder_output, key_padding_mask=src_key_padding_mask_inf)
-        encoder_output = self.encoder_block_2(encoder_output, key_padding_mask=src_key_padding_mask_inf)
-        encoder_output = self.encoder_block_3(encoder_output, key_padding_mask=src_key_padding_mask_inf)
-        memory = self.norm_memory(encoder_output)
-
-        # Decoder pass (iterative)
-        mel_input = torch.zeros((N, 1, hp.mel_freq), device=current_device) # SOS frame
-        generated_mel_frames = []
-
-        for i in range(max_length):
-            mel_lengths_inf = torch.tensor([mel_input.shape[1]], device=current_device)
-            # For decoder self-attention, causal mask is needed
-            tgt_mask_inf = torch.zeros((mel_input.shape[1], mel_input.shape[1]), device=current_device).masked_fill(
-                torch.triu(torch.full((mel_input.shape[1], mel_input.shape[1]), True, device=current_device, dtype=torch.bool), diagonal=1), float("-inf")
-            )
-            # Decoder input padding mask (all False as we build it frame by frame, no padding yet)
-            tgt_key_padding_mask_inf = torch.zeros((N, mel_input.shape[1]), device=current_device, dtype=torch.bool)
-
-
-            mel_x = self.decoder_prenet(mel_input)
-            mel_time_dim = mel_input.shape[1]
-            mel_x = mel_x + pos_codes[:mel_time_dim] # Positional encoding for current mel sequence
-
-            mel_x = self.decoder_block_1(x=mel_x, memory=memory, x_attn_mask=tgt_mask_inf, x_key_padding_mask=tgt_key_padding_mask_inf, memory_key_padding_mask=src_key_padding_mask_inf)
-            mel_x = self.decoder_block_2(x=mel_x, memory=memory, x_attn_mask=tgt_mask_inf, x_key_padding_mask=tgt_key_padding_mask_inf, memory_key_padding_mask=src_key_padding_mask_inf)
-            mel_x = self.decoder_block_3(x=mel_x, memory=memory, x_attn_mask=tgt_mask_inf, x_key_padding_mask=tgt_key_padding_mask_inf, memory_key_padding_mask=src_key_padding_mask_inf)
-
-            mel_linear_step = self.linear_1(mel_x[:, -1:, :]) # Predict only for the last frame
-            mel_postnet_residual_step = self.postnet(mel_linear_step)
-            current_mel_frame = mel_linear_step + mel_postnet_residual_step
-
-            generated_mel_frames.append(current_mel_frame)
-            mel_input = torch.cat([mel_input, current_mel_frame], dim=1) # Append to input for next step
-
-            # Stop token prediction (based on the last decoder output before linear to mel)
-            stop_token_logit = self.linear_2(mel_x[:, -1:, :]) # Stop token from last frame's decoder hidden state
-            stop_token_prob = torch.sigmoid(stop_token_logit.squeeze())
-
-            if stop_token_prob > stop_token_threshold:
-                # print(f"Stop token threshold reached at step {i+1}")
-                break
-            if mel_input.shape[1] > hp.max_mel_time -1: # Safety break based on max_mel_time
-                # print(f"Max mel time {hp.max_mel_time} almost reached.")
+    def inference(self, text, max_length=800, stop_token_threshold=0.5, with_tqdm=True):
+        self.eval(); self.train(False)
+        text_lengths = torch.tensor(text.shape[1]).unsqueeze(0).to(DEVICE)
+        N = 1
+        SOS = torch.zeros((N, 1, hp.mel_freq), device=DEVICE)
+        mel_padded = SOS
+        mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
+        stop_token_outputs = torch.FloatTensor([]).to(text.device)
+        iters = range(max_length)
+        for _ in iters:
+            mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
+            mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
+            if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
                 break
-
-
-
-
-        return torch.zeros((N, 0, hp.mel_freq), device=current_device) # Return empty tensor
-
-        final_mel_output = torch.cat(generated_mel_frames, dim=1)
-        return final_mel_output # Removed stop_token_outputs as it's not used by caller
+            else:
+                stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
+            mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
+        return mel_postnet, stop_token_outputs
 # --- (End of your model definitions) ---
 
 # --- Part 2: Model Loading ---
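Note: with the class fixed, the whole pipeline runs roughly as below. This is an illustrative sketch assuming a trained checkpoint has already been loaded (that happens in Part 2); the input string and the [mel_freq, TIME] layout expected by inverse_mel_spec_to_wav are assumptions, not taken from the commit:

    model = TransformerTTS().to(DEVICE)
    # model.load_state_dict(...)  # checkpoint loading happens in Part 2 below

    text = text_to_seq("hello world").unsqueeze(0).to(DEVICE)  # [1, S] symbol ids
    mel_postnet, stop_tokens = model.inference(text, max_length=800, stop_token_threshold=0.5)
    wav = inverse_mel_spec_to_wav(mel_postnet.squeeze(0).transpose(0, 1))  # assumes [mel_freq, TIME] input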
@@ -395,6 +334,7 @@ MARIAN_HUB_ID = "MoHamdyy/marian-ar-en-translation"
 
 
 
+
 # Wrap model loading in a function to clearly see when it happens or to potentially delay it.
 # For Spaces, global loading is fine and preferred as it happens once.
 print("--- Starting Model Loading ---")