MoHamdyy committed
Commit 574f683 · 1 Parent(s): eb8f826

fixed TTS class

Files changed (1): app.py (+94, -154)
app.py CHANGED
@@ -45,13 +45,14 @@ print(f"--- Initializing on device: {DEVICE} ---") # This will run when the Spac
 # --- (Start of your model definitions - make sure this is complete from your previous code) ---
 class Hyperparams:
     seed = 42
+    # We won't use these dataset paths, but keep them for hp object integrity
     csv_path = "path/to/metadata.csv"
     wav_path = "path/to/wavs"
     symbols = [
-        'EOS', ' ', '!', ',', '-', '.', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f',
-        'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
-        't', 'u', 'v', 'w', 'x', 'y', 'z', 'à', 'â', 'è', 'é', 'ê', 'ü',
-        '’', '“', '”'
+        'EOS', ' ', '!', ',', '-', '.', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f',
+        'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
+        't', 'u', 'v', 'w', 'x', 'y', 'z', 'à', 'â', 'è', 'é', 'ê', 'ü',
+        '’', '“', '”'
     ]
     sr = 22050
     n_fft = 2048
@@ -61,9 +62,9 @@ class Hyperparams:
     mel_freq = 128
     max_mel_time = 1024
     power = 2.0
-    text_num_embeddings = 2*len(symbols)
+    text_num_embeddings = 2*len(symbols)
     embedding_size = 256
-    encoder_embedding_size = 512
+    encoder_embedding_size = 512
     dim_feedforward = 1024
     postnet_embedding_size = 1024
     encoder_kernel_size = 3
@@ -78,10 +79,10 @@ class Hyperparams:
 
 hp = Hyperparams()
 
+# Text to Sequence
 symbol_to_id = {s: i for i, s in enumerate(hp.symbols)}
 def text_to_seq(text):
     text = text.lower()
-    text = unidecode(text)
     seq = []
     for s in text:
         _id = symbol_to_id.get(s, None)
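Reviewer note on the hunk above: with `unidecode` removed, accented characters such as 'é' now map directly through `symbol_to_id` rather than being ASCII-folded, and any character absent from `hp.symbols` is presumably skipped (the branch that handles `_id is None` falls between these hunks and is elided). A minimal sanity check, assuming app.py's definitions are in scope:

    # Hypothetical usage; ids depend on the order of hp.symbols ('EOS' is id 0).
    seq = text_to_seq("Salut, bébé!")
    print(seq)        # 1-D torch.IntTensor of symbol ids, ending with the EOS id
    print(seq.shape)  # one id per recognized character, plus EOS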
@@ -90,8 +91,10 @@ def text_to_seq(text):
     seq.append(symbol_to_id["EOS"])
     return torch.IntTensor(seq)
 
+# Audio Processing
 spec_transform = torchaudio.transforms.Spectrogram(n_fft=hp.n_fft, win_length=hp.win_length, hop_length=hp.hop_length, power=hp.power)
 mel_scale_transform = torchaudio.transforms.MelScale(n_mels=hp.mel_freq, sample_rate=hp.sr, n_stft=hp.n_stft)
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 mel_inverse_transform = torchaudio.transforms.InverseMelScale(n_mels=hp.mel_freq, sample_rate=hp.sr, n_stft=hp.n_stft).to(DEVICE)
 griffnlim_transform = torchaudio.transforms.GriffinLim(n_fft=hp.n_fft, win_length=hp.win_length, hop_length=hp.hop_length).to(DEVICE)
 
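The transforms above define both directions of the audio pipeline: `mel_inverse_transform` maps the 128 mel bins back to `n_stft` linear-frequency bins, and `griffnlim_transform` recovers phase. A rough round-trip sketch, assuming these globals are in scope (random input, just to show the shapes; real output quality depends on a valid power mel spectrogram):

    import torch
    fake_power_mel = torch.rand(hp.mel_freq, 200, device=DEVICE)  # [n_mels, time]
    linear_spec = mel_inverse_transform(fake_power_mel)           # [n_stft, time]
    waveform = griffnlim_transform(linear_spec)                   # ~ time * hop_length samples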
@@ -102,7 +105,7 @@ def pow_to_db_mel_spec(mel_spec):
 
 def db_to_power_mel_spec(mel_spec):
     mel_spec = mel_spec*hp.scale_db
-    mel_spec = torchaudio.functional.DB_to_amplitude(mel_spec, ref=hp.ampl_ref, power=hp.ampl_power)
+    mel_spec = torchaudio.functional.DB_to_amplitude(mel_spec, ref=hp.ampl_ref, power=hp.ampl_power)
    return mel_spec
 
 def inverse_mel_spec_to_wav(mel_spec):
@@ -115,8 +118,9 @@ def mask_from_seq_lengths(sequence_lengths: torch.Tensor, max_length: int) -> torch.Tensor:
     ones = sequence_lengths.new_ones(sequence_lengths.size(0), max_length)
     range_tensor = ones.cumsum(dim=1)
     return sequence_lengths.unsqueeze(1) >= range_tensor
-
-class EncoderBlock(nn.Module): # Your EncoderBlock definition
+
+# --- TransformerTTS Model Architecture (Copied from notebook)
+class EncoderBlock(nn.Module):
     def __init__(self):
         super(EncoderBlock, self).__init__()
         self.norm_1 = nn.LayerNorm(normalized_shape=hp.embedding_size)
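For readers skimming the diff: `mask_from_seq_lengths` (context at the top of this hunk) builds a per-batch validity mask from lengths via a cumulative sum, where `True` marks real frames; the model later negates it (`~`) to obtain `key_padding_mask`s. A worked example in plain torch:

    import torch
    lengths = torch.tensor([2, 4])
    ones = lengths.new_ones(2, 4)
    positions = ones.cumsum(dim=1)            # [[1, 2, 3, 4], [1, 2, 3, 4]]
    mask = lengths.unsqueeze(1) >= positions
    # tensor([[ True,  True, False, False],
    #         [ True,  True,  True,  True]])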
@@ -131,8 +135,8 @@ class EncoderBlock(nn.Module): # Your EncoderBlock definition
         x_out = self.norm_1(x)
         x_out, _ = self.attn(query=x_out, key=x_out, value=x_out, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
         x_out = self.dropout_1(x_out)
-        x = x + x_out
-        x_out = self.norm_2(x)
+        x = x + x_out
+        x_out = self.norm_2(x)
         x_out = self.linear_1(x_out)
         x_out = F.relu(x_out)
         x_out = self.dropout_2(x_out)
@@ -141,14 +145,14 @@ class EncoderBlock(nn.Module): # Your EncoderBlock definition
         x = x + x_out
         return x
 
-class DecoderBlock(nn.Module): # Your DecoderBlock definition
+class DecoderBlock(nn.Module):
     def __init__(self):
         super(DecoderBlock, self).__init__()
         self.norm_1 = nn.LayerNorm(normalized_shape=hp.embedding_size)
         self.self_attn = torch.nn.MultiheadAttention(embed_dim=hp.embedding_size, num_heads=4, dropout=0.1, batch_first=True)
         self.dropout_1 = torch.nn.Dropout(0.1)
         self.norm_2 = nn.LayerNorm(normalized_shape=hp.embedding_size)
-        self.attn = torch.nn.MultiheadAttention(embed_dim=hp.embedding_size, num_heads=4, dropout=0.1, batch_first=True)
+        self.attn = torch.nn.MultiheadAttention(embed_dim=hp.embedding_size, num_heads=4, dropout=0.1, batch_first=True)
         self.dropout_2 = torch.nn.Dropout(0.1)
         self.norm_3 = nn.LayerNorm(normalized_shape=hp.embedding_size)
         self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
@@ -170,7 +174,7 @@ class DecoderBlock(nn.Module): # Your DecoderBlock definition
         x = self.norm_3(x + x_out)
         return x
 
-class EncoderPreNet(nn.Module): # Your EncoderPreNet definition
+class EncoderPreNet(nn.Module):
     def __init__(self):
         super(EncoderPreNet, self).__init__()
         self.embedding = nn.Embedding(num_embeddings=hp.text_num_embeddings, embedding_dim=hp.encoder_embedding_size)
@@ -184,22 +188,28 @@ class EncoderPreNet(nn.Module): # Your EncoderPreNet definition
         self.dropout_2 = torch.nn.Dropout(0.5)
         self.conv_3 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, kernel_size=hp.encoder_kernel_size, stride=1, padding=int((hp.encoder_kernel_size - 1) / 2), dilation=1)
         self.bn_3 = nn.BatchNorm1d(hp.encoder_embedding_size)
-        self.dropout_3 = torch.nn.Dropout(0.5)
+        self.dropout_3 = torch.nn.Dropout(0.5)
     def forward(self, text):
         x = self.embedding(text)
         x = self.linear_1(x)
         x = x.transpose(2, 1)
         x = self.conv_1(x)
-        x = self.bn_1(x); x = F.relu(x); x = self.dropout_1(x)
+        x = self.bn_1(x)
+        x = F.relu(x)
+        x = self.dropout_1(x)
         x = self.conv_2(x)
-        x = self.bn_2(x); x = F.relu(x); x = self.dropout_2(x)
+        x = self.bn_2(x)
+        x = F.relu(x)
+        x = self.dropout_2(x)
         x = self.conv_3(x)
-        x = self.bn_3(x); x = F.relu(x); x = self.dropout_3(x)
+        x = self.bn_3(x)
+        x = F.relu(x)
+        x = self.dropout_3(x)
         x = x.transpose(1, 2)
         x = self.linear_2(x)
         return x
 
-class PostNet(nn.Module): # Your PostNet definition
+class PostNet(nn.Module):
     def __init__(self):
         super(PostNet, self).__init__()
         self.conv_1 = nn.Conv1d(hp.mel_freq, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
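Both pre-nets above rely on `padding=int((kernel_size - 1) / 2)` so the 1-D convolutions preserve the time dimension (this holds for the odd kernel sizes used here). A quick standalone shape check:

    import torch
    import torch.nn as nn
    conv = nn.Conv1d(8, 8, kernel_size=3, stride=1, padding=(3 - 1) // 2)
    x = torch.randn(2, 8, 50)    # [batch, channels, time]
    print(conv(x).shape)         # torch.Size([2, 8, 50]) -- time length preserved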
@@ -221,18 +231,23 @@ class PostNet(nn.Module): # Your PostNet definition
         self.bn_6 = nn.BatchNorm1d(hp.mel_freq)
         self.dropout_6 = torch.nn.Dropout(0.5)
     def forward(self, x):
-        x_orig = x # Store original for residual connection if postnet predicts residual
         x = x.transpose(2, 1)
-        x = self.conv_1(x); x = self.bn_1(x); x = torch.tanh(x); x = self.dropout_1(x)
-        x = self.conv_2(x); x = self.bn_2(x); x = torch.tanh(x); x = self.dropout_2(x)
-        x = self.conv_3(x); x = self.bn_3(x); x = torch.tanh(x); x = self.dropout_3(x)
-        x = self.conv_4(x); x = self.bn_4(x); x = torch.tanh(x); x = self.dropout_4(x)
-        x = self.conv_5(x); x = self.bn_5(x); x = torch.tanh(x); x = self.dropout_5(x)
-        x = self.conv_6(x); x = self.bn_6(x); x = self.dropout_6(x) # No Tanh on last layer for mel usually
+        x = self.conv_1(x)
+        x = self.bn_1(x); x = torch.tanh(x); x = self.dropout_1(x)
+        x = self.conv_2(x)
+        x = self.bn_2(x); x = torch.tanh(x); x = self.dropout_2(x)
+        x = self.conv_3(x)
+        x = self.bn_3(x); x = torch.tanh(x); x = self.dropout_3(x)
+        x = self.conv_4(x)
+        x = self.bn_4(x); x = torch.tanh(x); x = self.dropout_4(x)
+        x = self.conv_5(x)
+        x = self.bn_5(x); x = torch.tanh(x); x = self.dropout_5(x)
+        x = self.conv_6(x)
+        x = self.bn_6(x); x = self.dropout_6(x)
         x = x.transpose(1, 2)
-        return x # This is the residual, should be added to original mel_linear
+        return x
 
-class DecoderPreNet(nn.Module): # Your DecoderPreNet definition
+class DecoderPreNet(nn.Module):
     def __init__(self):
         super(DecoderPreNet, self).__init__()
         self.linear_1 = nn.Linear(hp.mel_freq, hp.embedding_size)
@@ -240,13 +255,13 @@ class DecoderPreNet(nn.Module): # Your DecoderPreNet definition
     def forward(self, x):
         x = self.linear_1(x)
         x = F.relu(x)
-        x = F.dropout(x, p=0.5, training=self.training)
+        x = F.dropout(x, p=0.5, training=True)
         x = self.linear_2(x)
-        x = F.relu(x)
-        x = F.dropout(x, p=0.5, training=self.training)
-        return x
+        x = F.relu(x)
+        x = F.dropout(x, p=0.5, training=True)
+        return x
 
-class TransformerTTS(nn.Module): # Your TransformerTTS definition
+class TransformerTTS(nn.Module):
     def __init__(self, device=DEVICE):
         super(TransformerTTS, self).__init__()
         self.encoder_prenet = EncoderPreNet()
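One behavioral change worth flagging in this hunk: `F.dropout(..., training=self.training)` became `training=True`, so the decoder pre-net keeps dropout active even under `model.eval()`. Keeping pre-net dropout on at inference is a deliberate Tacotron-style choice that adds variation to the synthesized output; a small illustration:

    import torch
    import torch.nn.functional as F
    x = torch.ones(4)
    print(F.dropout(x, p=0.5, training=True))   # random zeros, survivors scaled by 2.0
    print(F.dropout(x, p=0.5, training=False))  # identity: tensor([1., 1., 1., 1.])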
@@ -259,132 +274,56 @@ class TransformerTTS(nn.Module): # Your TransformerTTS definition
         self.decoder_block_1 = DecoderBlock()
         self.decoder_block_2 = DecoderBlock()
         self.decoder_block_3 = DecoderBlock()
-        self.linear_1 = nn.Linear(hp.embedding_size, hp.mel_freq)
-        self.linear_2 = nn.Linear(hp.embedding_size, 1) # Stop token
+        self.linear_1 = nn.Linear(hp.embedding_size, hp.mel_freq)
+        self.linear_2 = nn.Linear(hp.embedding_size, 1)
         self.norm_memory = nn.LayerNorm(normalized_shape=hp.embedding_size)
-        self.device = device
-
-    def forward(self, text, text_len, mel, mel_len): # For training/teacher-forcing
-        # ... (Your detailed forward pass for training, with all masks)
+    def forward(self, text, text_len, mel, mel_len):
         N = text.shape[0]; S = text.shape[1]; TIME = mel.shape[1]
-        current_device = text.device
-
-        src_key_padding_mask = torch.zeros((N, S), device=current_device, dtype=torch.bool).masked_fill(~mask_from_seq_lengths(text_len, max_length=S), True)
-        src_mask = None # Typically encoder self-attention doesn't use a causal mask
-
-        tgt_key_padding_mask = torch.zeros((N, TIME), device=current_device, dtype=torch.bool).masked_fill(~mask_from_seq_lengths(mel_len, max_length=TIME), True)
-        tgt_mask = torch.zeros((TIME, TIME), device=current_device).masked_fill(torch.triu(torch.full((TIME, TIME), True, device=current_device, dtype=torch.bool), diagonal=1), float("-inf"))
-        memory_mask = None # Cross-attention mask, typically not needed unless specific structure
-
-        text_x = self.encoder_prenet(text)
-        pos_codes = self.pos_encoding(torch.arange(hp.max_mel_time, device=current_device))
-        text_s_dim = text_x.shape[1]
-        text_x = text_x + pos_codes[:text_s_dim]
-
-        text_x = self.encoder_block_1(text_x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
-        text_x = self.encoder_block_2(text_x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
-        text_x = self.encoder_block_3(text_x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
-        memory = self.norm_memory(text_x)
-
-        mel_x = self.decoder_prenet(mel)
-        mel_time_dim = mel_x.shape[1]
-        mel_x = mel_x + pos_codes[:mel_time_dim]
-
-        mel_x = self.decoder_block_1(x=mel_x, memory=memory, x_attn_mask=tgt_mask, x_key_padding_mask=tgt_key_padding_mask, memory_attn_mask=memory_mask, memory_key_padding_mask=src_key_padding_mask)
-        mel_x = self.decoder_block_2(x=mel_x, memory=memory, x_attn_mask=tgt_mask, x_key_padding_mask=tgt_key_padding_mask, memory_attn_mask=memory_mask, memory_key_padding_mask=src_key_padding_mask)
-        mel_x = self.decoder_block_3(x=mel_x, memory=memory, x_attn_mask=tgt_mask, x_key_padding_mask=tgt_key_padding_mask, memory_attn_mask=memory_mask, memory_key_padding_mask=src_key_padding_mask)
-
+        self.src_key_padding_mask = torch.zeros((N, S), device=text.device).masked_fill(~mask_from_seq_lengths(text_len, max_length=S), float("-inf"))
+        self.src_mask = torch.zeros((S, S), device=text.device).masked_fill(torch.triu(torch.full((S, S), True, dtype=torch.bool), diagonal=1).to(text.device), float("-inf"))
+        self.tgt_key_padding_mask = torch.zeros((N, TIME), device=mel.device).masked_fill(~mask_from_seq_lengths(mel_len, max_length=TIME), float("-inf"))
+        self.tgt_mask = torch.zeros((TIME, TIME), device=mel.device).masked_fill(torch.triu(torch.full((TIME, TIME), True, device=mel.device, dtype=torch.bool), diagonal=1), float("-inf"))
+        self.memory_mask = torch.zeros((TIME, S), device=mel.device).masked_fill(torch.triu(torch.full((TIME, S), True, device=mel.device, dtype=torch.bool), diagonal=1), float("-inf"))
+        text_x = self.encoder_prenet(text)
+        pos_codes = self.pos_encoding(torch.arange(hp.max_mel_time).to(mel.device))
+        S = text_x.shape[1]; text_x = text_x + pos_codes[:S]
+        text_x = self.encoder_block_1(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
+        text_x = self.encoder_block_2(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
+        text_x = self.encoder_block_3(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
+        text_x = self.norm_memory(text_x)
+        mel_x = self.decoder_prenet(mel); mel_x = mel_x + pos_codes[:TIME]
+        mel_x = self.decoder_block_1(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
+        mel_x = self.decoder_block_2(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
+        mel_x = self.decoder_block_3(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
         mel_linear = self.linear_1(mel_x)
-        mel_postnet_residual = self.postnet(mel_linear) # Postnet predicts residual
-        mel_postnet = mel_linear + mel_postnet_residual
-
-        stop_token = self.linear_2(mel_x) # Sigmoid applied later
-
-        # Masking for training outputs
-        bool_mel_mask = tgt_key_padding_mask.unsqueeze(-1).repeat(1, 1, hp.mel_freq)
-        mel_linear = mel_linear.masked_fill(bool_mel_mask, 0.0)
-        mel_postnet = mel_postnet.masked_fill(bool_mel_mask, 0.0)
-        # Ensure stop_token is [N, TIME]
-        stop_token = stop_token.masked_fill(tgt_key_padding_mask.unsqueeze(-1) if stop_token.dim() == 3 else tgt_key_padding_mask, 1e3)
-        if stop_token.dim() == 3 and stop_token.shape[2] == 1:
-            stop_token = stop_token.squeeze(-1)
-
-
-        return mel_postnet, mel_linear, stop_token
-
+        mel_postnet = self.postnet(mel_linear)
+        mel_postnet = mel_linear + mel_postnet
+        stop_token = self.linear_2(mel_x)
+        bool_mel_mask = self.tgt_key_padding_mask.ne(0).unsqueeze(-1).repeat(1, 1, hp.mel_freq)
+        mel_linear = mel_linear.masked_fill(bool_mel_mask, 0)
+        mel_postnet = mel_postnet.masked_fill(bool_mel_mask, 0)
+        stop_token = stop_token.masked_fill(bool_mel_mask[:, :, 0].unsqueeze(-1), 1e3).squeeze(2)
+        return mel_postnet, mel_linear, stop_token
 
     @torch.no_grad()
-    def inference(self, text, max_length=800, stop_token_threshold=0.5): # text: [1, seq_len]
-        self.eval()
-        N = text.shape[0] # Should be 1
-        current_device = text.device
-        text_lengths = torch.tensor([text.shape[1]], device=current_device)
-
-        # Encoder pass (once)
-        src_key_padding_mask_inf = torch.zeros((N, text.shape[1]), device=current_device, dtype=torch.bool) # All False initially
-        # No, src_key_padding_mask should be based on actual text length, even if N=1, S=text.shape[1]
-        # For inference with single item, it's often all False (no padding in input text usually)
-        # However, to be consistent with how `mask_from_seq_lengths` works:
-        src_key_padding_mask_inf = ~mask_from_seq_lengths(text_lengths, text.shape[1])
-
-
-        encoder_output = self.encoder_prenet(text)
-        pos_codes = self.pos_encoding(torch.arange(hp.max_mel_time, device=current_device))
-        text_s_dim = encoder_output.shape[1]
-        encoder_output = encoder_output + pos_codes[:text_s_dim]
-
-        encoder_output = self.encoder_block_1(encoder_output, key_padding_mask=src_key_padding_mask_inf)
-        encoder_output = self.encoder_block_2(encoder_output, key_padding_mask=src_key_padding_mask_inf)
-        encoder_output = self.encoder_block_3(encoder_output, key_padding_mask=src_key_padding_mask_inf)
-        memory = self.norm_memory(encoder_output)
-
-        # Decoder pass (iterative)
-        mel_input = torch.zeros((N, 1, hp.mel_freq), device=current_device) # SOS frame
-        generated_mel_frames = []
-
-        for i in range(max_length):
-            mel_lengths_inf = torch.tensor([mel_input.shape[1]], device=current_device)
-            # For decoder self-attention, causal mask is needed
-            tgt_mask_inf = torch.zeros((mel_input.shape[1], mel_input.shape[1]), device=current_device).masked_fill(
-                torch.triu(torch.full((mel_input.shape[1], mel_input.shape[1]), True, device=current_device, dtype=torch.bool), diagonal=1), float("-inf")
-            )
-            # Decoder input padding mask (all False as we build it frame by frame, no padding yet)
-            tgt_key_padding_mask_inf = torch.zeros((N, mel_input.shape[1]), device=current_device, dtype=torch.bool)
-
-
-            mel_x = self.decoder_prenet(mel_input)
-            mel_time_dim = mel_input.shape[1]
-            mel_x = mel_x + pos_codes[:mel_time_dim] # Positional encoding for current mel sequence
-
-            mel_x = self.decoder_block_1(x=mel_x, memory=memory, x_attn_mask=tgt_mask_inf, x_key_padding_mask=tgt_key_padding_mask_inf, memory_key_padding_mask=src_key_padding_mask_inf)
-            mel_x = self.decoder_block_2(x=mel_x, memory=memory, x_attn_mask=tgt_mask_inf, x_key_padding_mask=tgt_key_padding_mask_inf, memory_key_padding_mask=src_key_padding_mask_inf)
-            mel_x = self.decoder_block_3(x=mel_x, memory=memory, x_attn_mask=tgt_mask_inf, x_key_padding_mask=tgt_key_padding_mask_inf, memory_key_padding_mask=src_key_padding_mask_inf)
-
-            mel_linear_step = self.linear_1(mel_x[:, -1:, :]) # Predict only for the last frame
-            mel_postnet_residual_step = self.postnet(mel_linear_step)
-            current_mel_frame = mel_linear_step + mel_postnet_residual_step
-
-            generated_mel_frames.append(current_mel_frame)
-            mel_input = torch.cat([mel_input, current_mel_frame], dim=1) # Append to input for next step
-
-            # Stop token prediction (based on the last decoder output before linear to mel)
-            stop_token_logit = self.linear_2(mel_x[:, -1:, :]) # Stop token from last frame's decoder hidden state
-            stop_token_prob = torch.sigmoid(stop_token_logit.squeeze())
-
-            if stop_token_prob > stop_token_threshold:
-                # print(f"Stop token threshold reached at step {i+1}")
-                break
-            if mel_input.shape[1] > hp.max_mel_time -1: # Safety break based on max_mel_time
-                # print(f"Max mel time {hp.max_mel_time} almost reached.")
+    def inference(self, text, max_length=800, stop_token_threshold=0.5, with_tqdm=True):
+        self.eval(); self.train(False)
+        text_lengths = torch.tensor(text.shape[1]).unsqueeze(0).to(DEVICE)
+        N = 1
+        SOS = torch.zeros((N, 1, hp.mel_freq), device=DEVICE)
+        mel_padded = SOS
+        mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
+        stop_token_outputs = torch.FloatTensor([]).to(text.device)
+        iters = range(max_length)
+        for _ in iters:
+            mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
+            mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
+            if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
                 break
-
-
-        if not generated_mel_frames:
-            print("Warning: TTS inference produced no mel frames.")
-            return torch.zeros((N, 0, hp.mel_freq), device=current_device) # Return empty tensor
-
-        final_mel_output = torch.cat(generated_mel_frames, dim=1)
-        return final_mel_output # Removed stop_token_outputs as it's not used by caller
+            else:
+                stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
+                mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
+        return mel_postnet, stop_token_outputs
 # --- (End of your model definitions) ---
 
 # --- Part 2: Model Loading ---
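The restored `forward` builds additive attention masks: `float("-inf")` above the diagonal blocks attention to future mel frames, while the padding masks suppress padded positions. A small standalone example of the causal pattern used in this hunk:

    import torch
    TIME = 4
    tgt_mask = torch.zeros((TIME, TIME)).masked_fill(
        torch.triu(torch.full((TIME, TIME), True, dtype=torch.bool), diagonal=1),
        float("-inf"),
    )
    # Row t may attend only to columns <= t:
    # tensor([[0., -inf, -inf, -inf],
    #         [0.,   0., -inf, -inf],
    #         [0.,   0.,   0., -inf],
    #         [0.,   0.,   0.,   0.]])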
@@ -395,6 +334,7 @@ MARIAN_HUB_ID = "MoHamdyy/marian-ar-en-translation"
 
 
 
+
 # Wrap model loading in a function to clearly see when it happens or to potentially delay it.
 # For Spaces, global loading is fine and preferred as it happens once.
 print("--- Starting Model Loading ---")