MoHamdyy commited on
Commit
9e8a757
·
1 Parent(s): 614d169

initial commit

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +573 -0
  3. requirements.txt +24 -0
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
  title: Translation Stack
3
  emoji: 🏢
4
- colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.33.1
 
1
  ---
2
  title: Translation Stack
3
  emoji: 🏢
4
+ colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.33.1
app.py ADDED
@@ -0,0 +1,573 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import time
4
+ import random
5
+ import numpy as np
6
+ import math
7
+ import shutil
8
+ # import base64 # Not directly needed for Gradio filepath output
9
+
10
+ # Torch and Audio
11
+ import torch
12
+ import torch.nn as nn
13
+ # import torch.optim as optim # Not needed for inference
14
+ # from torch.utils.data import Dataset, DataLoader # Not needed for inference
15
+ import torch.nn.functional as F
16
+ import torchaudio
17
+ import librosa
18
+ # import librosa.display # Not used in pipeline
19
+
20
+ # Text and Audio Processing
21
+ from unidecode import unidecode
22
+ # from inflect import engine # Not explicitly used in pipeline, consider removing
23
+ # import pydub # Not explicitly used in pipeline, consider removing
24
+ import soundfile as sf
25
+
26
+ # Transformers
27
+ from transformers import (
28
+ WhisperProcessor, WhisperForConditionalGeneration,
29
+ MarianTokenizer, MarianMTModel,
30
+ )
31
+ from huggingface_hub import hf_hub_download
32
+
33
+ # Gradio and Hugging Face Spaces
34
+ import gradio as gr
35
+ import spaces # <<< --- ADD THIS IMPORT --- <<<
36
+
37
+ # --- Global Configuration & Device Setup ---
38
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
39
+ print(f"--- Initializing on device: {DEVICE} ---") # This will run when the Space builds/starts
40
+
41
+ # --- Part 1: TTS Model Components (Your Custom TTS) ---
42
+ # ... (Keep all your Hyperparams, text_to_seq, audio processing for TTS, and Model class definitions:
43
+ # EncoderBlock, DecoderBlock, EncoderPreNet, PostNet, DecoderPreNet, TransformerTTS)
44
+ # ... (Ensure TransformerTTS and its sub-modules are correctly defined as in your previous code)
45
+ # --- (Start of your model definitions - make sure this is complete from your previous code) ---
46
+ class Hyperparams:
47
+ seed = 42
48
+ csv_path = "path/to/metadata.csv"
49
+ wav_path = "path/to/wavs"
50
+ symbols = [
51
+ 'EOS', ' ', '!', ',', '-', '.', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f',
52
+ 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
53
+ 't', 'u', 'v', 'w', 'x', 'y', 'z', 'à', 'â', 'è', 'é', 'ê', 'ü',
54
+ '’', '“', '”'
55
+ ]
56
+ sr = 22050
57
+ n_fft = 2048
58
+ n_stft = int((n_fft//2) + 1)
59
+ hop_length = int(n_fft/8.0)
60
+ win_length = int(n_fft/2.0)
61
+ mel_freq = 128
62
+ max_mel_time = 1024
63
+ power = 2.0
64
+ text_num_embeddings = 2*len(symbols)
65
+ embedding_size = 256
66
+ encoder_embedding_size = 512
67
+ dim_feedforward = 1024
68
+ postnet_embedding_size = 1024
69
+ encoder_kernel_size = 3
70
+ postnet_kernel_size = 5
71
+ ampl_multiplier = 10.0
72
+ ampl_amin = 1e-10
73
+ db_multiplier = 1.0
74
+ ampl_ref = 1.0
75
+ ampl_power = 1.0
76
+ max_db = 100
77
+ scale_db = 10
78
+
79
+ hp = Hyperparams()
80
+
81
+ symbol_to_id = {s: i for i, s in enumerate(hp.symbols)}
82
+ def text_to_seq(text):
83
+ text = text.lower()
84
+ text = unidecode(text)
85
+ seq = []
86
+ for s in text:
87
+ _id = symbol_to_id.get(s, None)
88
+ if _id is not None:
89
+ seq.append(_id)
90
+ seq.append(symbol_to_id["EOS"])
91
+ return torch.IntTensor(seq)
92
+
93
+ spec_transform = torchaudio.transforms.Spectrogram(n_fft=hp.n_fft, win_length=hp.win_length, hop_length=hp.hop_length, power=hp.power)
94
+ mel_scale_transform = torchaudio.transforms.MelScale(n_mels=hp.mel_freq, sample_rate=hp.sr, n_stft=hp.n_stft)
95
+ mel_inverse_transform = torchaudio.transforms.InverseMelScale(n_mels=hp.mel_freq, sample_rate=hp.sr, n_stft=hp.n_stft).to(DEVICE)
96
+ griffnlim_transform = torchaudio.transforms.GriffinLim(n_fft=hp.n_fft, win_length=hp.win_length, hop_length=hp.hop_length).to(DEVICE)
97
+
98
+ def pow_to_db_mel_spec(mel_spec):
99
+ mel_spec = torchaudio.functional.amplitude_to_DB(mel_spec, multiplier=hp.ampl_multiplier, amin=hp.ampl_amin, db_multiplier=hp.db_multiplier, top_db=hp.max_db)
100
+ mel_spec = mel_spec/hp.scale_db
101
+ return mel_spec
102
+
103
+ def db_to_power_mel_spec(mel_spec):
104
+ mel_spec = mel_spec*hp.scale_db
105
+ mel_spec = torchaudio.functional.DB_to_amplitude(mel_spec, ref=hp.ampl_ref, power=hp.ampl_power)
106
+ return mel_spec
107
+
108
+ def inverse_mel_spec_to_wav(mel_spec):
109
+ power_mel_spec = db_to_power_mel_spec(mel_spec.to(DEVICE))
110
+ spectrogram = mel_inverse_transform(power_mel_spec)
111
+ pseudo_wav = griffnlim_transform(spectrogram)
112
+ return pseudo_wav
113
+
114
+ def mask_from_seq_lengths(sequence_lengths: torch.Tensor, max_length: int) -> torch.BoolTensor:
115
+ ones = sequence_lengths.new_ones(sequence_lengths.size(0), max_length)
116
+ range_tensor = ones.cumsum(dim=1)
117
+ return sequence_lengths.unsqueeze(1) >= range_tensor
118
+
119
+ class EncoderBlock(nn.Module): # Your EncoderBlock definition
120
+ def __init__(self):
121
+ super(EncoderBlock, self).__init__()
122
+ self.norm_1 = nn.LayerNorm(normalized_shape=hp.embedding_size)
123
+ self.attn = torch.nn.MultiheadAttention(embed_dim=hp.embedding_size, num_heads=4, dropout=0.1, batch_first=True)
124
+ self.dropout_1 = torch.nn.Dropout(0.1)
125
+ self.norm_2 = nn.LayerNorm(normalized_shape=hp.embedding_size)
126
+ self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
127
+ self.dropout_2 = torch.nn.Dropout(0.1)
128
+ self.linear_2 = nn.Linear(hp.dim_feedforward, hp.embedding_size)
129
+ self.dropout_3 = torch.nn.Dropout(0.1)
130
+ def forward(self, x, attn_mask=None, key_padding_mask=None):
131
+ x_out = self.norm_1(x)
132
+ x_out, _ = self.attn(query=x_out, key=x_out, value=x_out, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
133
+ x_out = self.dropout_1(x_out)
134
+ x = x + x_out
135
+ x_out = self.norm_2(x)
136
+ x_out = self.linear_1(x_out)
137
+ x_out = F.relu(x_out)
138
+ x_out = self.dropout_2(x_out)
139
+ x_out = self.linear_2(x_out)
140
+ x_out = self.dropout_3(x_out)
141
+ x = x + x_out
142
+ return x
143
+
144
+ class DecoderBlock(nn.Module): # Your DecoderBlock definition
145
+ def __init__(self):
146
+ super(DecoderBlock, self).__init__()
147
+ self.norm_1 = nn.LayerNorm(normalized_shape=hp.embedding_size)
148
+ self.self_attn = torch.nn.MultiheadAttention(embed_dim=hp.embedding_size, num_heads=4, dropout=0.1, batch_first=True)
149
+ self.dropout_1 = torch.nn.Dropout(0.1)
150
+ self.norm_2 = nn.LayerNorm(normalized_shape=hp.embedding_size)
151
+ self.attn = torch.nn.MultiheadAttention(embed_dim=hp.embedding_size, num_heads=4, dropout=0.1, batch_first=True)
152
+ self.dropout_2 = torch.nn.Dropout(0.1)
153
+ self.norm_3 = nn.LayerNorm(normalized_shape=hp.embedding_size)
154
+ self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
155
+ self.dropout_3 = torch.nn.Dropout(0.1)
156
+ self.linear_2 = nn.Linear(hp.dim_feedforward, hp.embedding_size)
157
+ self.dropout_4 = torch.nn.Dropout(0.1)
158
+ def forward(self, x, memory, x_attn_mask=None, x_key_padding_mask=None, memory_attn_mask=None, memory_key_padding_mask=None):
159
+ x_out, _ = self.self_attn(query=x, key=x, value=x, attn_mask=x_attn_mask, key_padding_mask=x_key_padding_mask)
160
+ x_out = self.dropout_1(x_out)
161
+ x = self.norm_1(x + x_out)
162
+ x_out, _ = self.attn(query=x, key=memory, value=memory, attn_mask=memory_attn_mask, key_padding_mask=memory_key_padding_mask)
163
+ x_out = self.dropout_2(x_out)
164
+ x = self.norm_2(x + x_out)
165
+ x_out = self.linear_1(x)
166
+ x_out = F.relu(x_out)
167
+ x_out = self.dropout_3(x_out)
168
+ x_out = self.linear_2(x_out)
169
+ x_out = self.dropout_4(x_out)
170
+ x = self.norm_3(x + x_out)
171
+ return x
172
+
173
+ class EncoderPreNet(nn.Module): # Your EncoderPreNet definition
174
+ def __init__(self):
175
+ super(EncoderPreNet, self).__init__()
176
+ self.embedding = nn.Embedding(num_embeddings=hp.text_num_embeddings, embedding_dim=hp.encoder_embedding_size)
177
+ self.linear_1 = nn.Linear(hp.encoder_embedding_size, hp.encoder_embedding_size)
178
+ self.linear_2 = nn.Linear(hp.encoder_embedding_size, hp.embedding_size)
179
+ self.conv_1 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, kernel_size=hp.encoder_kernel_size, stride=1, padding=int((hp.encoder_kernel_size - 1) / 2), dilation=1)
180
+ self.bn_1 = nn.BatchNorm1d(hp.encoder_embedding_size)
181
+ self.dropout_1 = torch.nn.Dropout(0.5)
182
+ self.conv_2 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, kernel_size=hp.encoder_kernel_size, stride=1, padding=int((hp.encoder_kernel_size - 1) / 2), dilation=1)
183
+ self.bn_2 = nn.BatchNorm1d(hp.encoder_embedding_size)
184
+ self.dropout_2 = torch.nn.Dropout(0.5)
185
+ self.conv_3 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, kernel_size=hp.encoder_kernel_size, stride=1, padding=int((hp.encoder_kernel_size - 1) / 2), dilation=1)
186
+ self.bn_3 = nn.BatchNorm1d(hp.encoder_embedding_size)
187
+ self.dropout_3 = torch.nn.Dropout(0.5)
188
+ def forward(self, text):
189
+ x = self.embedding(text)
190
+ x = self.linear_1(x)
191
+ x = x.transpose(2, 1)
192
+ x = self.conv_1(x)
193
+ x = self.bn_1(x); x = F.relu(x); x = self.dropout_1(x)
194
+ x = self.conv_2(x)
195
+ x = self.bn_2(x); x = F.relu(x); x = self.dropout_2(x)
196
+ x = self.conv_3(x)
197
+ x = self.bn_3(x); x = F.relu(x); x = self.dropout_3(x)
198
+ x = x.transpose(1, 2)
199
+ x = self.linear_2(x)
200
+ return x
201
+
202
+ class PostNet(nn.Module): # Your PostNet definition
203
+ def __init__(self):
204
+ super(PostNet, self).__init__()
205
+ self.conv_1 = nn.Conv1d(hp.mel_freq, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
206
+ self.bn_1 = nn.BatchNorm1d(hp.postnet_embedding_size)
207
+ self.dropout_1 = torch.nn.Dropout(0.5)
208
+ self.conv_2 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
209
+ self.bn_2 = nn.BatchNorm1d(hp.postnet_embedding_size)
210
+ self.dropout_2 = torch.nn.Dropout(0.5)
211
+ self.conv_3 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
212
+ self.bn_3 = nn.BatchNorm1d(hp.postnet_embedding_size)
213
+ self.dropout_3 = torch.nn.Dropout(0.5)
214
+ self.conv_4 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
215
+ self.bn_4 = nn.BatchNorm1d(hp.postnet_embedding_size)
216
+ self.dropout_4 = torch.nn.Dropout(0.5)
217
+ self.conv_5 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
218
+ self.bn_5 = nn.BatchNorm1d(hp.postnet_embedding_size)
219
+ self.dropout_5 = torch.nn.Dropout(0.5)
220
+ self.conv_6 = nn.Conv1d(hp.postnet_embedding_size, hp.mel_freq, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
221
+ self.bn_6 = nn.BatchNorm1d(hp.mel_freq)
222
+ self.dropout_6 = torch.nn.Dropout(0.5)
223
+ def forward(self, x):
224
+ x_orig = x # Store original for residual connection if postnet predicts residual
225
+ x = x.transpose(2, 1)
226
+ x = self.conv_1(x); x = self.bn_1(x); x = torch.tanh(x); x = self.dropout_1(x)
227
+ x = self.conv_2(x); x = self.bn_2(x); x = torch.tanh(x); x = self.dropout_2(x)
228
+ x = self.conv_3(x); x = self.bn_3(x); x = torch.tanh(x); x = self.dropout_3(x)
229
+ x = self.conv_4(x); x = self.bn_4(x); x = torch.tanh(x); x = self.dropout_4(x)
230
+ x = self.conv_5(x); x = self.bn_5(x); x = torch.tanh(x); x = self.dropout_5(x)
231
+ x = self.conv_6(x); x = self.bn_6(x); x = self.dropout_6(x) # No Tanh on last layer for mel usually
232
+ x = x.transpose(1, 2)
233
+ return x # This is the residual, should be added to original mel_linear
234
+
235
+ class DecoderPreNet(nn.Module): # Your DecoderPreNet definition
236
+ def __init__(self):
237
+ super(DecoderPreNet, self).__init__()
238
+ self.linear_1 = nn.Linear(hp.mel_freq, hp.embedding_size)
239
+ self.linear_2 = nn.Linear(hp.embedding_size, hp.embedding_size)
240
+ def forward(self, x):
241
+ x = self.linear_1(x)
242
+ x = F.relu(x)
243
+ x = F.dropout(x, p=0.5, training=self.training)
244
+ x = self.linear_2(x)
245
+ x = F.relu(x)
246
+ x = F.dropout(x, p=0.5, training=self.training)
247
+ return x
248
+
249
+ class TransformerTTS(nn.Module): # Your TransformerTTS definition
250
+ def __init__(self, device=DEVICE):
251
+ super(TransformerTTS, self).__init__()
252
+ self.encoder_prenet = EncoderPreNet()
253
+ self.decoder_prenet = DecoderPreNet()
254
+ self.postnet = PostNet()
255
+ self.pos_encoding = nn.Embedding(num_embeddings=hp.max_mel_time, embedding_dim=hp.embedding_size)
256
+ self.encoder_block_1 = EncoderBlock()
257
+ self.encoder_block_2 = EncoderBlock()
258
+ self.encoder_block_3 = EncoderBlock()
259
+ self.decoder_block_1 = DecoderBlock()
260
+ self.decoder_block_2 = DecoderBlock()
261
+ self.decoder_block_3 = DecoderBlock()
262
+ self.linear_1 = nn.Linear(hp.embedding_size, hp.mel_freq)
263
+ self.linear_2 = nn.Linear(hp.embedding_size, 1) # Stop token
264
+ self.norm_memory = nn.LayerNorm(normalized_shape=hp.embedding_size)
265
+ self.device = device
266
+
267
+ def forward(self, text, text_len, mel, mel_len): # For training/teacher-forcing
268
+ # ... (Your detailed forward pass for training, with all masks)
269
+ N = text.shape[0]; S = text.shape[1]; TIME = mel.shape[1]
270
+ current_device = text.device
271
+
272
+ src_key_padding_mask = torch.zeros((N, S), device=current_device, dtype=torch.bool).masked_fill(~mask_from_seq_lengths(text_len, max_length=S), True)
273
+ src_mask = None # Typically encoder self-attention doesn't use a causal mask
274
+
275
+ tgt_key_padding_mask = torch.zeros((N, TIME), device=current_device, dtype=torch.bool).masked_fill(~mask_from_seq_lengths(mel_len, max_length=TIME), True)
276
+ tgt_mask = torch.zeros((TIME, TIME), device=current_device).masked_fill(torch.triu(torch.full((TIME, TIME), True, device=current_device, dtype=torch.bool), diagonal=1), float("-inf"))
277
+ memory_mask = None # Cross-attention mask, typically not needed unless specific structure
278
+
279
+ text_x = self.encoder_prenet(text)
280
+ pos_codes = self.pos_encoding(torch.arange(hp.max_mel_time, device=current_device))
281
+ text_s_dim = text_x.shape[1]
282
+ text_x = text_x + pos_codes[:text_s_dim]
283
+
284
+ text_x = self.encoder_block_1(text_x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
285
+ text_x = self.encoder_block_2(text_x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
286
+ text_x = self.encoder_block_3(text_x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
287
+ memory = self.norm_memory(text_x)
288
+
289
+ mel_x = self.decoder_prenet(mel)
290
+ mel_time_dim = mel_x.shape[1]
291
+ mel_x = mel_x + pos_codes[:mel_time_dim]
292
+
293
+ mel_x = self.decoder_block_1(x=mel_x, memory=memory, x_attn_mask=tgt_mask, x_key_padding_mask=tgt_key_padding_mask, memory_attn_mask=memory_mask, memory_key_padding_mask=src_key_padding_mask)
294
+ mel_x = self.decoder_block_2(x=mel_x, memory=memory, x_attn_mask=tgt_mask, x_key_padding_mask=tgt_key_padding_mask, memory_attn_mask=memory_mask, memory_key_padding_mask=src_key_padding_mask)
295
+ mel_x = self.decoder_block_3(x=mel_x, memory=memory, x_attn_mask=tgt_mask, x_key_padding_mask=tgt_key_padding_mask, memory_attn_mask=memory_mask, memory_key_padding_mask=src_key_padding_mask)
296
+
297
+ mel_linear = self.linear_1(mel_x)
298
+ mel_postnet_residual = self.postnet(mel_linear) # Postnet predicts residual
299
+ mel_postnet = mel_linear + mel_postnet_residual
300
+
301
+ stop_token = self.linear_2(mel_x) # Sigmoid applied later
302
+
303
+ # Masking for training outputs
304
+ bool_mel_mask = tgt_key_padding_mask.unsqueeze(-1).repeat(1, 1, hp.mel_freq)
305
+ mel_linear = mel_linear.masked_fill(bool_mel_mask, 0.0)
306
+ mel_postnet = mel_postnet.masked_fill(bool_mel_mask, 0.0)
307
+ # Ensure stop_token is [N, TIME]
308
+ stop_token = stop_token.masked_fill(tgt_key_padding_mask.unsqueeze(-1) if stop_token.dim() == 3 else tgt_key_padding_mask, 1e3)
309
+ if stop_token.dim() == 3 and stop_token.shape[2] == 1:
310
+ stop_token = stop_token.squeeze(-1)
311
+
312
+
313
+ return mel_postnet, mel_linear, stop_token
314
+
315
+
316
+ @torch.no_grad()
317
+ def inference(self, text, max_length=800, stop_token_threshold=0.5): # text: [1, seq_len]
318
+ self.eval()
319
+ N = text.shape[0] # Should be 1
320
+ current_device = text.device
321
+ text_lengths = torch.tensor([text.shape[1]], device=current_device)
322
+
323
+ # Encoder pass (once)
324
+ src_key_padding_mask_inf = torch.zeros((N, text.shape[1]), device=current_device, dtype=torch.bool) # All False initially
325
+ # No, src_key_padding_mask should be based on actual text length, even if N=1, S=text.shape[1]
326
+ # For inference with single item, it's often all False (no padding in input text usually)
327
+ # However, to be consistent with how `mask_from_seq_lengths` works:
328
+ src_key_padding_mask_inf = ~mask_from_seq_lengths(text_lengths, text.shape[1])
329
+
330
+
331
+ encoder_output = self.encoder_prenet(text)
332
+ pos_codes = self.pos_encoding(torch.arange(hp.max_mel_time, device=current_device))
333
+ text_s_dim = encoder_output.shape[1]
334
+ encoder_output = encoder_output + pos_codes[:text_s_dim]
335
+
336
+ encoder_output = self.encoder_block_1(encoder_output, key_padding_mask=src_key_padding_mask_inf)
337
+ encoder_output = self.encoder_block_2(encoder_output, key_padding_mask=src_key_padding_mask_inf)
338
+ encoder_output = self.encoder_block_3(encoder_output, key_padding_mask=src_key_padding_mask_inf)
339
+ memory = self.norm_memory(encoder_output)
340
+
341
+ # Decoder pass (iterative)
342
+ mel_input = torch.zeros((N, 1, hp.mel_freq), device=current_device) # SOS frame
343
+ generated_mel_frames = []
344
+
345
+ for i in range(max_length):
346
+ mel_lengths_inf = torch.tensor([mel_input.shape[1]], device=current_device)
347
+ # For decoder self-attention, causal mask is needed
348
+ tgt_mask_inf = torch.zeros((mel_input.shape[1], mel_input.shape[1]), device=current_device).masked_fill(
349
+ torch.triu(torch.full((mel_input.shape[1], mel_input.shape[1]), True, device=current_device, dtype=torch.bool), diagonal=1), float("-inf")
350
+ )
351
+ # Decoder input padding mask (all False as we build it frame by frame, no padding yet)
352
+ tgt_key_padding_mask_inf = torch.zeros((N, mel_input.shape[1]), device=current_device, dtype=torch.bool)
353
+
354
+
355
+ mel_x = self.decoder_prenet(mel_input)
356
+ mel_time_dim = mel_input.shape[1]
357
+ mel_x = mel_x + pos_codes[:mel_time_dim] # Positional encoding for current mel sequence
358
+
359
+ mel_x = self.decoder_block_1(x=mel_x, memory=memory, x_attn_mask=tgt_mask_inf, x_key_padding_mask=tgt_key_padding_mask_inf, memory_key_padding_mask=src_key_padding_mask_inf)
360
+ mel_x = self.decoder_block_2(x=mel_x, memory=memory, x_attn_mask=tgt_mask_inf, x_key_padding_mask=tgt_key_padding_mask_inf, memory_key_padding_mask=src_key_padding_mask_inf)
361
+ mel_x = self.decoder_block_3(x=mel_x, memory=memory, x_attn_mask=tgt_mask_inf, x_key_padding_mask=tgt_key_padding_mask_inf, memory_key_padding_mask=src_key_padding_mask_inf)
362
+
363
+ mel_linear_step = self.linear_1(mel_x[:, -1:, :]) # Predict only for the last frame
364
+ mel_postnet_residual_step = self.postnet(mel_linear_step)
365
+ current_mel_frame = mel_linear_step + mel_postnet_residual_step
366
+
367
+ generated_mel_frames.append(current_mel_frame)
368
+ mel_input = torch.cat([mel_input, current_mel_frame], dim=1) # Append to input for next step
369
+
370
+ # Stop token prediction (based on the last decoder output before linear to mel)
371
+ stop_token_logit = self.linear_2(mel_x[:, -1:, :]) # Stop token from last frame's decoder hidden state
372
+ stop_token_prob = torch.sigmoid(stop_token_logit.squeeze())
373
+
374
+ if stop_token_prob > stop_token_threshold:
375
+ # print(f"Stop token threshold reached at step {i+1}")
376
+ break
377
+ if mel_input.shape[1] > hp.max_mel_time -1: # Safety break based on max_mel_time
378
+ # print(f"Max mel time {hp.max_mel_time} almost reached.")
379
+ break
380
+
381
+
382
+ if not generated_mel_frames:
383
+ print("Warning: TTS inference produced no mel frames.")
384
+ return torch.zeros((N, 0, hp.mel_freq), device=current_device) # Return empty tensor
385
+
386
+ final_mel_output = torch.cat(generated_mel_frames, dim=1)
387
+ return final_mel_output # Removed stop_token_outputs as it's not used by caller
388
+ # --- (End of your model definitions) ---
389
+
390
+ # --- Part 2: Model Loading ---
391
+ # (Same as before - ensure TTS_MODEL = TransformerTTS(device=DEVICE).to(DEVICE) is used)
392
+ TTS_MODEL_HUB_ID = "MoHamdyy/transformer-tts-ljspeech"
393
+ ASR_HUB_ID = "MoHamdyy/whisper-stt-model"
394
+ MARIAN_HUB_ID = "MoHamdyy/marian-ar-en-translation"
395
+
396
+ TTS_MODEL = None
397
+ stt_processor = None
398
+ stt_model = None
399
+ mt_tokenizer = None
400
+ mt_model = None
401
+
402
+ # Wrap model loading in a function to clearly see when it happens or to potentially delay it.
403
+ # For Spaces, global loading is fine and preferred as it happens once.
404
+ print("--- Starting Model Loading ---")
405
+ try:
406
+ print(f"Loading TTS model ({TTS_MODEL_HUB_ID}) to {DEVICE}...")
407
+ tts_model_path = hf_hub_download(repo_id=TTS_MODEL_HUB_ID, filename="train_SimpleTransfromerTTS.pt")
408
+ state = torch.load(tts_model_path, map_location=DEVICE) # Load to target device directly
409
+ TTS_MODEL = TransformerTTS(device=DEVICE).to(DEVICE)
410
+ model_state_dict = state.get("model", state.get("state_dict", state))
411
+ TTS_MODEL.load_state_dict(model_state_dict)
412
+ TTS_MODEL.eval()
413
+ print("TTS model loaded successfully.")
414
+ except Exception as e:
415
+ print(f"Error loading TTS model: {e}")
416
+
417
+ try:
418
+ print(f"Loading STT (Whisper) model ({ASR_HUB_ID}) to {DEVICE}...")
419
+ stt_processor = WhisperProcessor.from_pretrained(ASR_HUB_ID)
420
+ stt_model = WhisperForConditionalGeneration.from_pretrained(ASR_HUB_ID).to(DEVICE).eval()
421
+ print("STT model loaded successfully.")
422
+ except Exception as e:
423
+ print(f"Error loading STT model: {e}")
424
+
425
+ try:
426
+ print(f"Loading TTT (MarianMT) model ({MARIAN_HUB_ID}) to {DEVICE}...")
427
+ mt_tokenizer = MarianTokenizer.from_pretrained(MARIAN_HUB_ID)
428
+ mt_model = MarianMTModel.from_pretrained(MARIAN_HUB_ID).to(DEVICE).eval()
429
+ print("TTT model loaded successfully.")
430
+ except Exception as e:
431
+ print(f"Error loading TTT model: {e}")
432
+ print("--- Model Loading Complete ---")
433
+
434
+
435
+ # --- Part 3: Full Pipeline Function for Gradio ---
436
+ @spaces.GPU # <<< --- APPLY THE DECORATOR HERE --- <<<
437
+ def full_speech_translation_pipeline_gradio(audio_input_path):
438
+ # This print will show the device context *inside* the decorated function
439
+ # For ZeroGPU, this should ideally show 'cuda:X' when the function is executed
440
+ current_processing_device = next(stt_model.parameters()).device if stt_model else "CPU (STT model not loaded)"
441
+ print(f"--- @spaces.GPU function: Pipeline running on device: {current_processing_device} ---")
442
+
443
+
444
+ if not all([TTS_MODEL, stt_processor, stt_model, mt_tokenizer, mt_model]):
445
+ error_msg = "Critical Error: One or more models failed to load during Space initialization. Cannot process."
446
+ print(error_msg)
447
+ # Raising gr.Error is better for UI feedback
448
+ raise gr.Error(error_msg)
449
+
450
+
451
+ if audio_input_path is None:
452
+ # This case should ideally be handled by Gradio's input validation or a check before calling.
453
+ # If it still occurs, provide a clear message.
454
+ raise gr.Error("No audio file provided. Please upload an audio file.")
455
+
456
+ print(f"--- GRADIO PIPELINE START (GPU context): Processing {audio_input_path} ---")
457
+
458
+ # STT Stage
459
+ arabic_transcript = "STT Error: Processing failed."
460
+ try:
461
+ print("STT: Loading and resampling audio...")
462
+ wav, sr = torchaudio.load(audio_input_path)
463
+ if wav.size(0) > 1: wav = wav.mean(dim=0, keepdim=True)
464
+ target_sr_stt = stt_processor.feature_extractor.sampling_rate
465
+ if sr != target_sr_stt: wav = torchaudio.transforms.Resample(sr, target_sr_stt)(wav)
466
+ # Move wav to the STT model's device *before* converting to numpy if STT model is on GPU
467
+ audio_array_stt = wav.to(DEVICE).squeeze().cpu().numpy() # Process on DEVICE, then to CPU for numpy
468
+
469
+ print("STT: Extracting features and transcribing...")
470
+ # Ensure inputs are on the same device as the model
471
+ inputs_stt = stt_processor(audio_array_stt, sampling_rate=target_sr_stt, return_tensors="pt").input_features.to(DEVICE)
472
+ forced_ids = stt_processor.get_decoder_prompt_ids(language="arabic", task="transcribe")
473
+ with torch.no_grad():
474
+ generated_ids = stt_model.generate(inputs_stt, forced_decoder_ids=forced_ids, max_new_tokens=256)
475
+ arabic_transcript = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
476
+ print(f"STT Output: {arabic_transcript}")
477
+ except Exception as e:
478
+ print(f"STT Error: {e}")
479
+ raise gr.Error(f"STT processing failed: {e}")
480
+
481
+
482
+ # TTT Stage
483
+ english_translation = "TTT Error: Processing failed."
484
+ if arabic_transcript and not arabic_transcript.startswith("STT Error"):
485
+ try:
486
+ print("TTT: Translating to English...")
487
+ batch = mt_tokenizer(arabic_transcript, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
488
+ with torch.no_grad():
489
+ translated_ids = mt_model.generate(**batch, max_length=512) # max_new_tokens can also be used
490
+ english_translation = mt_tokenizer.batch_decode(translated_ids, skip_special_tokens=True)[0].strip()
491
+ print(f"TTT Output: {english_translation}")
492
+ except Exception as e:
493
+ print(f"TTT Error: {e}")
494
+ raise gr.Error(f"TTT processing failed: {e}")
495
+
496
+ else:
497
+ if not arabic_transcript or arabic_transcript.startswith("STT Error"):
498
+ english_translation = "(Skipped TTT due to STT failure or empty STT output)"
499
+ print(english_translation)
500
+
501
+
502
+ # TTS Stage
503
+ output_tts_audio_filepath = None
504
+ if english_translation and not english_translation.startswith("TTT Error") and TTS_MODEL:
505
+ try:
506
+ print("TTS: Synthesizing English speech...")
507
+ if not english_translation.strip():
508
+ print("TTS Warning: Empty string for synthesis. Skipping TTS.")
509
+ else:
510
+ sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
511
+ # max_length for TTS inference refers to max output mel frames
512
+ generated_mel = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time - 50, stop_token_threshold=0.5)
513
+
514
+ print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
515
+ if generated_mel is not None and generated_mel.numel() > 0 and generated_mel.shape[1] > 0 :
516
+ # TTS model's inverse_mel_spec_to_wav expects mel on DEVICE and returns wav on CPU
517
+ # The mel from inference should be [N, mel_len, mel_bins]
518
+ # inverse_mel_spec_to_wav might expect [mel_bins, mel_len]
519
+ mel_for_vocoder = generated_mel.detach().squeeze(0).transpose(0, 1) # to [mel_len, mel_bins]
520
+ audio_tensor = inverse_mel_spec_to_wav(mel_for_vocoder) # This function handles .to(DEVICE) internally
521
+ synthesized_audio_np = audio_tensor.cpu().numpy() # Ensure output is on CPU for soundfile
522
+ print(f"TTS: Synthesized audio shape: {synthesized_audio_np.shape}")
523
+
524
+ timestamp = int(time.time()*1000) # more unique
525
+ output_tts_audio_filepath = f"output_audio_{timestamp}.wav"
526
+ sf.write(output_tts_audio_filepath, synthesized_audio_np, hp.sr)
527
+ print(f"TTS: Synthesized audio saved to: {output_tts_audio_filepath}")
528
+ else:
529
+ print("TTS Warning: Generated mel spectrogram was empty or invalid.")
530
+ except Exception as e:
531
+ print(f"TTS Error: {e}")
532
+ # Do not raise gr.Error here if a partial result (text) is still useful
533
+ # output_tts_audio_filepath will remain None
534
+ english_translation += f" (TTS Error: {e})" # Append error to text
535
+ else:
536
+ if not TTS_MODEL: print("TTS SKIPPED: Model not loaded.")
537
+ elif not (english_translation and not english_translation.startswith("TTT Error")):
538
+ print("TTS SKIPPED: (Due to TTT failure or empty TTT output)")
539
+
540
+
541
+ print(f"--- GRADIO PIPELINE END (GPU context) ---")
542
+ return arabic_transcript, english_translation, output_tts_audio_filepath
543
+
544
+
545
+ # --- Part 4: Gradio Interface Definition ---
546
+ # (Same as before)
547
+ iface = gr.Interface(
548
+ fn=full_speech_translation_pipeline_gradio,
549
+ inputs=[
550
+ gr.Audio(type="filepath", label="Upload Arabic Speech")
551
+ ],
552
+ outputs=[
553
+ gr.Textbox(label="Arabic Transcript (STT)"),
554
+ gr.Textbox(label="English Translation (TTT)"),
555
+ gr.Audio(label="Synthesized English Speech (TTS)", type="filepath")
556
+ ],
557
+ title="Arabic to English Speech Translation (ZeroGPU)",
558
+ description="Upload an Arabic audio file. Transcribed to Arabic (Whisper), translated to English (MarianMT), synthesized to English speech (Custom TransformerTTS).",
559
+ allow_flagging="never",
560
+ # examples=[["sample.wav"]] # If you add a sample.wav to your repo
561
+ )
562
+
563
+ # --- Part 5: Launch for Spaces (and local testing) ---
564
+ if __name__ == '__main__':
565
+ # Clean up temp audio files from previous local runs
566
+ for f_name in os.listdir("."):
567
+ if f_name.startswith("output_audio_") and f_name.endswith(".wav"):
568
+ try:
569
+ os.remove(f_name)
570
+ except OSError:
571
+ pass # Ignore if file is already gone or locked
572
+ print("Starting Gradio interface locally with debug mode...")
573
+ iface.launch(debug=True, share=False) # share=False for local, Spaces handles public URL
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Machine Learning & Audio
2
+ torch
3
+ torchaudio
4
+ # For CPU, ensure correct PyTorch version if not using ZeroGPU:
5
+ # torch --index-url https://download.pytorch.org/whl/cpu
6
+ # For CUDA (ZeroGPU is CUDA-based, HF Spaces will handle this if ZeroGPU is selected)
7
+ # torch --index-url https://download.pytorch.org/whl/cu118
8
+ # (Check latest recommended CUDA version for ZeroGPU on HF docs)
9
+
10
+ transformers
11
+ librosa
12
+ soundfile
13
+ pydub
14
+ unidecode
15
+ inflect
16
+ huggingface_hub
17
+ sentencepiece # Often needed by tokenizers
18
+
19
+ # Gradio
20
+ gradio >=4.0.0 # Use a recent version of Gradio
21
+
22
+ # Other utilities
23
+ numpy
24
+ pandas # Though pandas is not explicitly used in the pipeline, it's in your imports