MoHamdyy committed on
Commit ab9f341
1 Parent(s): b28da2b

Changed app to Gradio

Files changed (1)
  1. app.py +280 -299
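At a glance, this commit removes the FastAPI upload endpoint (with its CORS, base64 and static-file plumbing) and serves the same STT → MT → TTS pipeline through a Gradio Interface. A condensed sketch of the new wiring, using the names that appear in the diff below (illustration only, not the full file):

    import gradio as gr

    # full_speech_translation_pipeline_gradio(path) returns
    # (arabic_transcript, english_translation, (sample_rate, waveform)).
    demo = gr.Interface(
        fn=full_speech_translation_pipeline_gradio,
        inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
        outputs=[gr.Textbox(), gr.Textbox(), gr.Audio(type="numpy")],
    )
    demo.launch()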
app.py CHANGED
@@ -3,26 +3,26 @@ import re
3
  import time
4
  import random
5
  import numpy as np
6
- import pandas as pd
7
- import math
8
- import shutil
9
- import base64
10
 
11
  # Torch and Audio
12
  import torch
13
  import torch.nn as nn
14
- import torch.optim as optim
15
- from torch.utils.data import Dataset, DataLoader
16
  import torch.nn.functional as F
17
  import torchaudio
18
- import librosa
19
- import librosa.display
20
 
21
  # Text and Audio Processing
22
  from unidecode import unidecode
23
- from inflect import engine
24
- import pydub
25
- import soundfile as sf
26
 
27
  # Transformers
28
  from transformers import (
@@ -30,22 +30,27 @@ from transformers import (
30
  MarianTokenizer, MarianMTModel,
31
  )
32
 
33
- # API Server
34
- from fastapi import FastAPI, UploadFile, File
35
- from fastapi.middleware.cors import CORSMiddleware
36
- from fastapi.staticfiles import StaticFiles # <--- ADD THIS IMPORT
37
40
- # Part 2: TTS Model Components (from your notebook)
41
-
42
-
43
- # Hyperparameters
44
  class Hyperparams:
45
  seed = 42
46
- # We won't use these dataset paths, but keep them for hp object integrity
47
- csv_path = "path/to/metadata.csv"
48
- wav_path = "path/to/wavs"
49
  symbols = [
50
  'EOS', ' ', '!', ',', '-', '.', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f',
51
  'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
@@ -59,7 +64,7 @@ class Hyperparams:
59
  win_length = int(n_fft/2.0)
60
  mel_freq = 128
61
  max_mel_time = 1024
62
- power = 2.0
63
  text_num_embeddings = 2*len(symbols)
64
  embedding_size = 256
65
  encoder_embedding_size = 512
@@ -67,17 +72,16 @@ class Hyperparams:
67
  postnet_embedding_size = 1024
68
  encoder_kernel_size = 3
69
  postnet_kernel_size = 5
70
- ampl_multiplier = 10.0
71
- ampl_amin = 1e-10
72
- db_multiplier = 1.0
73
- ampl_ref = 1.0
74
- ampl_power = 1.0
75
- max_db = 100
76
- scale_db = 10
77
-
78
  hp = Hyperparams()
79
 
80
- # Text to Sequence
81
  symbol_to_id = {s: i for i, s in enumerate(hp.symbols)}
82
  def text_to_seq(text):
83
  text = text.lower()
@@ -89,27 +93,19 @@ def text_to_seq(text):
89
  seq.append(symbol_to_id["EOS"])
90
  return torch.IntTensor(seq)
91
 
92
- # Audio Processing
93
- spec_transform = torchaudio.transforms.Spectrogram(n_fft=hp.n_fft, win_length=hp.win_length, hop_length=hp.hop_length, power=hp.power)
94
- mel_scale_transform = torchaudio.transforms.MelScale(n_mels=hp.mel_freq, sample_rate=hp.sr, n_stft=hp.n_stft)
95
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
96
  mel_inverse_transform = torchaudio.transforms.InverseMelScale(n_mels=hp.mel_freq, sample_rate=hp.sr, n_stft=hp.n_stft).to(DEVICE)
97
- griffnlim_transform = torchaudio.transforms.GriffinLim(n_fft=hp.n_fft, win_length=hp.win_length, hop_length=hp.hop_length).to(DEVICE)
98
-
99
- def pow_to_db_mel_spec(mel_spec):
100
- mel_spec = torchaudio.functional.amplitude_to_DB(mel_spec, multiplier=hp.ampl_multiplier, amin=hp.ampl_amin, db_multiplier=hp.db_multiplier, top_db=hp.max_db)
101
- mel_spec = mel_spec/hp.scale_db
102
- return mel_spec
103
 
104
  def db_to_power_mel_spec(mel_spec):
105
- mel_spec = mel_spec*hp.scale_db
106
- mel_spec = torchaudio.functional.DB_to_amplitude(mel_spec, ref=hp.ampl_ref, power=hp.ampl_power)
107
- return mel_spec
108
-
109
- def inverse_mel_spec_to_wav(mel_spec):
110
- power_mel_spec = db_to_power_mel_spec(mel_spec.to(DEVICE))
111
- spectrogram = mel_inverse_transform(power_mel_spec)
112
- pseudo_wav = griffnlim_transform(spectrogram)
 
113
  return pseudo_wav
114
 
115
  def mask_from_seq_lengths(sequence_lengths: torch.Tensor, max_length: int) -> torch.BoolTensor:
@@ -117,287 +113,266 @@ def mask_from_seq_lengths(sequence_lengths: torch.Tensor, max_length: int) -> to
117
  range_tensor = ones.cumsum(dim=1)
118
  return sequence_lengths.unsqueeze(1) >= range_tensor
119
 
120
- # --- TransformerTTS Model Architecture (Copied from notebook)
121
- class EncoderBlock(nn.Module):
122
  def __init__(self):
123
  super(EncoderBlock, self).__init__()
124
- self.norm_1 = nn.LayerNorm(normalized_shape=hp.embedding_size)
125
- self.attn = torch.nn.MultiheadAttention(embed_dim=hp.embedding_size, num_heads=4, dropout=0.1, batch_first=True)
126
  self.dropout_1 = torch.nn.Dropout(0.1)
127
- self.norm_2 = nn.LayerNorm(normalized_shape=hp.embedding_size)
128
  self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
129
  self.dropout_2 = torch.nn.Dropout(0.1)
130
  self.linear_2 = nn.Linear(hp.dim_feedforward, hp.embedding_size)
131
  self.dropout_3 = torch.nn.Dropout(0.1)
132
  def forward(self, x, attn_mask=None, key_padding_mask=None):
133
- x_out = self.norm_1(x)
134
- x_out, _ = self.attn(query=x_out, key=x_out, value=x_out, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
135
- x_out = self.dropout_1(x_out)
136
- x = x + x_out
137
- x_out = self.norm_2(x)
138
- x_out = self.linear_1(x_out)
139
- x_out = F.relu(x_out)
140
- x_out = self.dropout_2(x_out)
141
- x_out = self.linear_2(x_out)
142
- x_out = self.dropout_3(x_out)
143
- x = x + x_out
144
- return x
145
-
146
- class DecoderBlock(nn.Module):
147
  def __init__(self):
148
  super(DecoderBlock, self).__init__()
149
- self.norm_1 = nn.LayerNorm(normalized_shape=hp.embedding_size)
150
- self.self_attn = torch.nn.MultiheadAttention(embed_dim=hp.embedding_size, num_heads=4, dropout=0.1, batch_first=True)
151
  self.dropout_1 = torch.nn.Dropout(0.1)
152
- self.norm_2 = nn.LayerNorm(normalized_shape=hp.embedding_size)
153
- self.attn = torch.nn.MultiheadAttention(embed_dim=hp.embedding_size, num_heads=4, dropout=0.1, batch_first=True)
154
  self.dropout_2 = torch.nn.Dropout(0.1)
155
- self.norm_3 = nn.LayerNorm(normalized_shape=hp.embedding_size)
156
  self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
157
  self.dropout_3 = torch.nn.Dropout(0.1)
158
  self.linear_2 = nn.Linear(hp.dim_feedforward, hp.embedding_size)
159
  self.dropout_4 = torch.nn.Dropout(0.1)
160
  def forward(self, x, memory, x_attn_mask=None, x_key_padding_mask=None, memory_attn_mask=None, memory_key_padding_mask=None):
161
- x_out, _ = self.self_attn(query=x, key=x, value=x, attn_mask=x_attn_mask, key_padding_mask=x_key_padding_mask)
162
- x_out = self.dropout_1(x_out)
163
- x = self.norm_1(x + x_out)
164
- x_out, _ = self.attn(query=x, key=memory, value=memory, attn_mask=memory_attn_mask, key_padding_mask=memory_key_padding_mask)
165
- x_out = self.dropout_2(x_out)
166
- x = self.norm_2(x + x_out)
167
- x_out = self.linear_1(x)
168
- x_out = F.relu(x_out)
169
- x_out = self.dropout_3(x_out)
170
- x_out = self.linear_2(x_out)
171
- x_out = self.dropout_4(x_out)
172
- x = self.norm_3(x + x_out)
173
- return x
174
-
175
- class EncoderPreNet(nn.Module):
176
  def __init__(self):
177
  super(EncoderPreNet, self).__init__()
178
- self.embedding = nn.Embedding(num_embeddings=hp.text_num_embeddings, embedding_dim=hp.encoder_embedding_size)
179
  self.linear_1 = nn.Linear(hp.encoder_embedding_size, hp.encoder_embedding_size)
180
  self.linear_2 = nn.Linear(hp.encoder_embedding_size, hp.embedding_size)
181
- self.conv_1 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, kernel_size=hp.encoder_kernel_size, stride=1, padding=int((hp.encoder_kernel_size - 1) / 2), dilation=1)
182
- self.bn_1 = nn.BatchNorm1d(hp.encoder_embedding_size)
183
- self.dropout_1 = torch.nn.Dropout(0.5)
184
- self.conv_2 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, kernel_size=hp.encoder_kernel_size, stride=1, padding=int((hp.encoder_kernel_size - 1) / 2), dilation=1)
185
- self.bn_2 = nn.BatchNorm1d(hp.encoder_embedding_size)
186
- self.dropout_2 = torch.nn.Dropout(0.5)
187
- self.conv_3 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, kernel_size=hp.encoder_kernel_size, stride=1, padding=int((hp.encoder_kernel_size - 1) / 2), dilation=1)
188
- self.bn_3 = nn.BatchNorm1d(hp.encoder_embedding_size)
189
- self.dropout_3 = torch.nn.Dropout(0.5)
190
  def forward(self, text):
191
- x = self.embedding(text)
192
- x = self.linear_1(x)
193
- x = x.transpose(2, 1)
194
- x = self.conv_1(x)
195
- x = self.bn_1(x)
196
- x = F.relu(x)
197
- x = self.dropout_1(x)
198
- x = self.conv_2(x)
199
- x = self.bn_2(x)
200
- x = F.relu(x)
201
- x = self.dropout_2(x)
202
- x = self.conv_3(x)
203
- x = self.bn_3(x)
204
- x = F.relu(x)
205
- x = self.dropout_3(x)
206
- x = x.transpose(1, 2)
207
- x = self.linear_2(x)
208
- return x
209
-
210
- class PostNet(nn.Module):
211
  def __init__(self):
212
  super(PostNet, self).__init__()
213
- self.conv_1 = nn.Conv1d(hp.mel_freq, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
214
- self.bn_1 = nn.BatchNorm1d(hp.postnet_embedding_size)
215
- self.dropout_1 = torch.nn.Dropout(0.5)
216
- self.conv_2 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
217
- self.bn_2 = nn.BatchNorm1d(hp.postnet_embedding_size)
218
- self.dropout_2 = torch.nn.Dropout(0.5)
219
- self.conv_3 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
220
- self.bn_3 = nn.BatchNorm1d(hp.postnet_embedding_size)
221
- self.dropout_3 = torch.nn.Dropout(0.5)
222
- self.conv_4 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
223
- self.bn_4 = nn.BatchNorm1d(hp.postnet_embedding_size)
224
- self.dropout_4 = torch.nn.Dropout(0.5)
225
- self.conv_5 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
226
- self.bn_5 = nn.BatchNorm1d(hp.postnet_embedding_size)
227
- self.dropout_5 = torch.nn.Dropout(0.5)
228
- self.conv_6 = nn.Conv1d(hp.postnet_embedding_size, hp.mel_freq, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
229
- self.bn_6 = nn.BatchNorm1d(hp.mel_freq)
230
- self.dropout_6 = torch.nn.Dropout(0.5)
231
  def forward(self, x):
232
- x = x.transpose(2, 1)
233
- x = self.conv_1(x)
234
- x = self.bn_1(x); x = torch.tanh(x); x = self.dropout_1(x)
235
- x = self.conv_2(x)
236
- x = self.bn_2(x); x = torch.tanh(x); x = self.dropout_2(x)
237
- x = self.conv_3(x)
238
- x = self.bn_3(x); x = torch.tanh(x); x = self.dropout_3(x)
239
- x = self.conv_4(x)
240
- x = self.bn_4(x); x = torch.tanh(x); x = self.dropout_4(x)
241
- x = self.conv_5(x)
242
- x = self.bn_5(x); x = torch.tanh(x); x = self.dropout_5(x)
243
- x = self.conv_6(x)
244
- x = self.bn_6(x); x = self.dropout_6(x)
245
- x = x.transpose(1, 2)
246
- return x
247
-
248
- class DecoderPreNet(nn.Module):
249
  def __init__(self):
250
  super(DecoderPreNet, self).__init__()
251
  self.linear_1 = nn.Linear(hp.mel_freq, hp.embedding_size)
252
  self.linear_2 = nn.Linear(hp.embedding_size, hp.embedding_size)
253
  def forward(self, x):
254
- x = self.linear_1(x)
255
- x = F.relu(x)
256
- x = F.dropout(x, p=0.5, training=True)
257
- x = self.linear_2(x)
258
- x = F.relu(x)
259
- x = F.dropout(x, p=0.5, training=True)
260
  return x
261
 
262
- class TransformerTTS(nn.Module):
263
- def __init__(self, device=DEVICE):
264
  super(TransformerTTS, self).__init__()
265
  self.encoder_prenet = EncoderPreNet()
266
  self.decoder_prenet = DecoderPreNet()
267
  self.postnet = PostNet()
268
- self.pos_encoding = nn.Embedding(num_embeddings=hp.max_mel_time, embedding_dim=hp.embedding_size)
269
- self.encoder_block_1 = EncoderBlock()
270
- self.encoder_block_2 = EncoderBlock()
271
- self.encoder_block_3 = EncoderBlock()
272
- self.decoder_block_1 = DecoderBlock()
273
- self.decoder_block_2 = DecoderBlock()
274
- self.decoder_block_3 = DecoderBlock()
275
  self.linear_1 = nn.Linear(hp.embedding_size, hp.mel_freq)
276
  self.linear_2 = nn.Linear(hp.embedding_size, 1)
277
- self.norm_memory = nn.LayerNorm(normalized_shape=hp.embedding_size)
278
- def forward(self, text, text_len, mel, mel_len):
279
- N = text.shape[0]; S = text.shape[1]; TIME = mel.shape[1]
280
- self.src_key_padding_mask = torch.zeros((N, S), device=text.device).masked_fill(~mask_from_seq_lengths(text_len, max_length=S), float("-inf"))
281
- self.src_mask = torch.zeros((S, S), device=text.device).masked_fill(torch.triu(torch.full((S, S), True, dtype=torch.bool), diagonal=1).to(text.device), float("-inf"))
282
- self.tgt_key_padding_mask = torch.zeros((N, TIME), device=mel.device).masked_fill(~mask_from_seq_lengths(mel_len, max_length=TIME), float("-inf"))
283
- self.tgt_mask = torch.zeros((TIME, TIME), device=mel.device).masked_fill(torch.triu(torch.full((TIME, TIME), True, device=mel.device, dtype=torch.bool), diagonal=1), float("-inf"))
284
- self.memory_mask = torch.zeros((TIME, S), device=mel.device).masked_fill(torch.triu(torch.full((TIME, S), True, device=mel.device, dtype=torch.bool), diagonal=1), float("-inf"))
285
  text_x = self.encoder_prenet(text)
286
- pos_codes = self.pos_encoding(torch.arange(hp.max_mel_time).to(mel.device))
287
- S = text_x.shape[1]; text_x = text_x + pos_codes[:S]
 
288
  text_x = self.encoder_block_1(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
289
  text_x = self.encoder_block_2(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
290
  text_x = self.encoder_block_3(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
291
  text_x = self.norm_memory(text_x)
292
- mel_x = self.decoder_prenet(mel); mel_x = mel_x + pos_codes[:TIME]
293
  mel_x = self.decoder_block_1(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
294
  mel_x = self.decoder_block_2(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
295
  mel_x = self.decoder_block_3(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
 
296
  mel_linear = self.linear_1(mel_x)
297
- mel_postnet = self.postnet(mel_linear)
298
- mel_postnet = mel_linear + mel_postnet
299
- stop_token = self.linear_2(mel_x)
300
- bool_mel_mask = self.tgt_key_padding_mask.ne(0).unsqueeze(-1).repeat(1, 1, hp.mel_freq)
301
- mel_linear = mel_linear.masked_fill(bool_mel_mask, 0)
302
- mel_postnet = mel_postnet.masked_fill(bool_mel_mask, 0)
303
- stop_token = stop_token.masked_fill(bool_mel_mask[:, :, 0].unsqueeze(-1), 1e3).squeeze(2)
304
  return mel_postnet, mel_linear, stop_token
305
 
306
- @torch.no_grad()
307
- def inference(self, text, max_length=800, stop_token_threshold=0.5, with_tqdm=True):
308
- self.eval(); self.train(False)
309
- text_lengths = torch.tensor(text.shape[1]).unsqueeze(0).to(DEVICE)
310
  N = 1
311
- SOS = torch.zeros((N, 1, hp.mel_freq), device=DEVICE)
312
  mel_padded = SOS
313
- mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
314
- stop_token_outputs = torch.FloatTensor([]).to(text.device)
315
- iters = range(max_length)
316
  for _ in iters:
317
- mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
318
  mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
319
- if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
320
  break
321
  else:
 
322
  stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
323
- mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
324
- return mel_postnet, stop_token_outputs
325
 
326
- # Part 3: Model Loading
327
-
328
 
329
- # IMPORTANT: These paths assume you have placed the downloaded models
330
- # into a 'models' subfolder in your project directory.
331
- # ---------------------------------
332
- # --- Part 3: Model Loading (from Hugging Face Hub)
333
- # ---------------------------------
334
 
335
- # IMPORTANT: Replace "your-username" with your Hugging Face username
336
- # and the model names with the ones you created on the Hub.
 
337
  TTS_MODEL_HUB_ID = "MoHamdyy/transformer-tts-ljspeech"
338
  ASR_HUB_ID = "MoHamdyy/whisper-stt-model"
339
  MARIAN_HUB_ID = "MoHamdyy/marian-ar-en-translation"
340
 
341
- # Helper function to download the TTS model file from the Hub
342
- from huggingface_hub import hf_hub_download
343
-
344
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
345
  print("Loading models from Hugging Face Hub to device:", DEVICE)
 
346
 
347
- # Load TTS Model from Hub
348
  try:
349
  print("Loading TTS model...")
350
- # Download the .pt file from its repo
351
  tts_model_path = hf_hub_download(repo_id=TTS_MODEL_HUB_ID, filename="train_SimpleTransfromerTTS.pt")
352
  state = torch.load(tts_model_path, map_location=DEVICE)
353
- TTS_MODEL = TransformerTTS().to(DEVICE)
354
- # Check for the correct key in the state dictionary
355
- if "model" in state:
356
- TTS_MODEL.load_state_dict(state["model"])
357
- elif "state_dict" in state:
358
- TTS_MODEL.load_state_dict(state["state_dict"])
359
- else:
360
- TTS_MODEL.load_state_dict(state) # Assume the whole file is the state_dict
361
  TTS_MODEL.eval()
362
  print("TTS model loaded successfully.")
363
- except Exception as e:
364
- print(f"Error loading TTS model: {e}")
365
- TTS_MODEL = None
366
 
367
- # Load STT (Whisper) Model from Hub
368
  try:
369
  print("Loading STT (Whisper) model...")
370
  stt_processor = WhisperProcessor.from_pretrained(ASR_HUB_ID)
371
  stt_model = WhisperForConditionalGeneration.from_pretrained(ASR_HUB_ID).to(DEVICE).eval()
372
  print("STT model loaded successfully.")
373
- except Exception as e:
374
- print(f"Error loading STT model: {e}")
375
- stt_processor = None
376
- stt_model = None
377
 
378
- # Load TTT (MarianMT) Model from Hub
379
  try:
380
  print("Loading TTT (MarianMT) model...")
381
  mt_tokenizer = MarianTokenizer.from_pretrained(MARIAN_HUB_ID)
382
  mt_model = MarianMTModel.from_pretrained(MARIAN_HUB_ID).to(DEVICE).eval()
383
  print("TTT model loaded successfully.")
384
- except Exception as e:
385
- print(f"Error loading TTT model: {e}")
386
- mt_tokenizer = None
387
- mt_model = None
388
-
389
-
390
-
391
- # Part 4: Full Pipeline Function
392
-
393
 
394
- def full_speech_translation_pipeline(audio_input_path: str):
 
395
  print(f"--- PIPELINE START: Processing {audio_input_path} ---")
396
- if audio_input_path is None or not os.path.exists(audio_input_path):
397
- msg = "Error: Audio file not provided or not found."
398
  print(msg)
399
- # Return empty/default values
400
- return "Error: No file", "", (hp.sr, np.array([]).astype(np.float32))
 
402
  # STT Stage
403
  arabic_transcript = "STT Error: Processing failed."
@@ -416,12 +391,16 @@ def full_speech_translation_pipeline(audio_input_path: str):
416
  generated_ids = stt_model.generate(inputs, forced_decoder_ids=forced_ids, max_length=448)
417
  arabic_transcript = stt_processor.decode(generated_ids[0], skip_special_tokens=True).strip()
418
  print(f"STT Output: {arabic_transcript}")
 
419
  except Exception as e:
420
- print(f"STT Error: {e}")
 
422
  # TTT Stage
423
  english_translation = "TTT Error: Processing failed."
424
- if arabic_transcript and not arabic_transcript.startswith("STT Error"):
 
425
  try:
426
  print("TTT: Translating to English...")
427
  batch = mt_tokenizer(arabic_transcript, return_tensors="pt", padding=True).to(DEVICE)
@@ -429,72 +408,74 @@ def full_speech_translation_pipeline(audio_input_path: str):
429
  translated_ids = mt_model.generate(**batch, max_length=512)
430
  english_translation = mt_tokenizer.batch_decode(translated_ids, skip_special_tokens=True)[0].strip()
431
  print(f"TTT Output: {english_translation}")
 
432
  except Exception as e:
433
- print(f"TTT Error: {e}")
434
- else:
435
- english_translation = "(Skipped TTT due to STT failure)"
 
436
  print(english_translation)
437
 
438
  # TTS Stage
439
  synthesized_audio_np = np.array([]).astype(np.float32)
440
- if english_translation and not english_translation.startswith("TTT Error"):
441
  try:
442
  print("TTS: Synthesizing English speech...")
443
- sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
444
  generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-20, stop_token_threshold=0.5, with_tqdm=False)
445
 
446
  print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
447
- if generated_mel is not None and generated_mel.numel() > 0:
448
- mel_for_vocoder = generated_mel.detach().squeeze(0).transpose(0, 1)
449
  audio_tensor = inverse_mel_spec_to_wav(mel_for_vocoder)
450
  synthesized_audio_np = audio_tensor.cpu().numpy()
451
  print(f"TTS: Synthesized audio shape: {synthesized_audio_np.shape}")
452
  except Exception as e:
453
- print(f"TTS Error: {e}")
454
-
455
  print(f"--- PIPELINE END ---")
456
- return arabic_transcript, english_translation, (hp.sr, synthesized_audio_np)
457
-
458
-
459
- # Part 5: FastAPI Application
460
-
461
- app = FastAPI()
462
-
463
- # Allow Cross-Origin Resource Sharing (CORS) for your frontend
464
- app.add_middleware(
465
- CORSMiddleware,
466
- allow_origins=["*"], # Allows all origins
467
- allow_credentials=True,
468
- allow_methods=["*"], # Allows all methods
469
- allow_headers=["*"], # Allows all headers
470
  )
471
 
472
- @app.post("/process-speech/")
473
- async def create_upload_file(file: UploadFile = File(...)):
474
- # Save the uploaded file temporarily
475
- temp_path = f"/tmp/{file.filename}"
476
- with open(temp_path, "wb") as buffer:
477
- shutil.copyfileobj(file.file, buffer)
478
-
479
- # Run the full pipeline
480
- arabic, english, (sr, audio_np) = full_speech_translation_pipeline(temp_path)
481
-
482
- # Prepare the audio to be sent back as base64
483
- audio_base64 = ""
484
- if audio_np.size > 0:
485
- temp_wav_path = "/tmp/output.wav"
486
- sf.write(temp_wav_path, audio_np, sr)
487
- with open(temp_wav_path, "rb") as wav_file:
488
- audio_bytes = wav_file.read()
489
- audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
490
-
491
- # Return all results in a single JSON response
492
- return {
493
- "arabic_transcript": arabic,
494
- "english_translation": english,
495
- "audio_data": {
496
- "sample_rate": sr,
497
- "base64": audio_base64
498
- }
499
- }
500
- app.mount("/", StaticFiles(directory="static", html=True), name="static")
 
3
  import time
4
  import random
5
  import numpy as np
6
+ import pandas as pd # Keep if hp or other parts use it, though not directly in pipeline
7
+ import math # Keep if hp or other parts use it
8
+ # import shutil # Not needed for Gradio file handling
9
+ # import base64 # Not needed for Gradio audio output
10
 
11
  # Torch and Audio
12
  import torch
13
  import torch.nn as nn
14
+ # import torch.optim as optim # Not needed for inference
15
+ # from torch.utils.data import Dataset, DataLoader # Not needed for inference
16
  import torch.nn.functional as F
17
  import torchaudio
18
+ # import librosa # Not strictly needed if not plotting in Gradio
19
+ # import librosa.display # Not strictly needed if not plotting in Gradio
20
 
21
  # Text and Audio Processing
22
  from unidecode import unidecode
23
+ from inflect import engine as inflect_engine_tts # Renamed to avoid conflict
24
+ # import pydub # Not needed for Gradio audio output
25
+ # import soundfile as sf # Gradio handles audio output directly
26
 
27
  # Transformers
28
  from transformers import (
 
30
  MarianTokenizer, MarianMTModel,
31
  )
32
 
33
+ # Gradio
34
+ import gradio as gr
35
+ from huggingface_hub import hf_hub_download # For downloading models
 
36
 
37
+ # --- Configuration & Device ---
38
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
39
+ print(f"Using device: {DEVICE}")
40
 
41
+ torch.manual_seed(42)
42
+ np.random.seed(42)
43
+ random.seed(42)
44
+ if torch.cuda.is_available():
45
+ torch.cuda.manual_seed_all(42)
46
+ torch.backends.cudnn.deterministic = True
47
+ torch.backends.cudnn.benchmark = False
48
 
49
+ # --- Hyperparams Class (VERBATIM from your notebook) ---
50
  class Hyperparams:
51
  seed = 42
52
+ csv_path = "path/to/metadata.csv" # Not used directly
53
+ wav_path = "path/to/wavs" # Not used directly
 
54
  symbols = [
55
  'EOS', ' ', '!', ',', '-', '.', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f',
56
  'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
 
64
  win_length = int(n_fft/2.0)
65
  mel_freq = 128
66
  max_mel_time = 1024
67
+ power = 2.0 # For spec_transform if used, not directly by inverse_mel
68
  text_num_embeddings = 2*len(symbols)
69
  embedding_size = 256
70
  encoder_embedding_size = 512
 
72
  postnet_embedding_size = 1024
73
  encoder_kernel_size = 3
74
  postnet_kernel_size = 5
75
+ ampl_multiplier = 10.0 # For pow_to_db_mel_spec
76
+ ampl_amin = 1e-10 # For pow_to_db_mel_spec
77
+ db_multiplier = 1.0 # For pow_to_db_mel_spec
78
+ ampl_ref = 1.0 # For db_to_power_mel_spec
79
+ ampl_power = 1.0 # For db_to_power_mel_spec
80
+ max_db = 100 # For pow_to_db_mel_spec
81
+ scale_db = 10 # For pow_to_db_mel_spec & db_to_power_mel_spec
 
82
  hp = Hyperparams()
83
 
84
+ # --- TTS Text & Audio Processing (VERBATIM from your notebook) ---
85
  symbol_to_id = {s: i for i, s in enumerate(hp.symbols)}
86
  def text_to_seq(text):
87
  text = text.lower()
 
93
  seq.append(symbol_to_id["EOS"])
94
  return torch.IntTensor(seq)
95
96
  mel_inverse_transform = torchaudio.transforms.InverseMelScale(n_mels=hp.mel_freq, sample_rate=hp.sr, n_stft=hp.n_stft).to(DEVICE)
97
+ griffnlim_transform = torchaudio.transforms.GriffinLim(n_fft=hp.n_fft, win_length=hp.win_length, hop_length=hp.hop_length, power=1.0).to(DEVICE) # Explicit power=1.0 for magnitude
98
 
99
  def db_to_power_mel_spec(mel_spec):
100
+ mel_spec_scaled = mel_spec * hp.scale_db # use a new name instead of reassigning the input argument
101
+ mel_spec_amp = torchaudio.functional.DB_to_amplitude(mel_spec_scaled, ref=hp.ampl_ref, power=hp.ampl_power)
102
+ return mel_spec_amp
103
+
104
+ def inverse_mel_spec_to_wav(mel_spec): # Expects [Freq, Time]
105
+ mel_spec_on_device = mel_spec.to(DEVICE)
106
+ power_mel_spec = db_to_power_mel_spec(mel_spec_on_device) # This is amplitude
107
+ spectrogram = mel_inverse_transform(power_mel_spec) # Amplitude mel to linear amplitude
108
+ pseudo_wav = griffnlim_transform(spectrogram) # Linear amplitude to wav
109
  return pseudo_wav
110
 
111
  def mask_from_seq_lengths(sequence_lengths: torch.Tensor, max_length: int) -> torch.BoolTensor:
 
113
  range_tensor = ones.cumsum(dim=1)
114
  return sequence_lengths.unsqueeze(1) >= range_tensor
115
 
116
+ # --- TransformerTTS Model Architecture (VERBATIM from your FastAPI code) ---
117
+ class EncoderBlock(nn.Module): # VERBATIM
118
  def __init__(self):
119
  super(EncoderBlock, self).__init__()
120
+ self.norm_1 = nn.LayerNorm(hp.embedding_size)
121
+ self.attn = torch.nn.MultiheadAttention(hp.embedding_size, 4, dropout=0.1, batch_first=True)
122
  self.dropout_1 = torch.nn.Dropout(0.1)
123
+ self.norm_2 = nn.LayerNorm(hp.embedding_size)
124
  self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
125
  self.dropout_2 = torch.nn.Dropout(0.1)
126
  self.linear_2 = nn.Linear(hp.dim_feedforward, hp.embedding_size)
127
  self.dropout_3 = torch.nn.Dropout(0.1)
128
  def forward(self, x, attn_mask=None, key_padding_mask=None):
129
+ x_out = self.norm_1(x); x_out, _ = self.attn(x_out, x_out, x_out, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
130
+ x_out = self.dropout_1(x_out); x = x + x_out
131
+ x_out = self.norm_2(x) ; x_out = self.linear_1(x_out); x_out = F.relu(x_out); x_out = self.dropout_2(x_out)
132
+ x_out = self.linear_2(x_out); x_out = self.dropout_3(x_out)
133
+ x = x + x_out; return x
134
+
135
+ class DecoderBlock(nn.Module): # VERBATIM
 
136
  def __init__(self):
137
  super(DecoderBlock, self).__init__()
138
+ self.norm_1 = nn.LayerNorm(hp.embedding_size)
139
+ self.self_attn = torch.nn.MultiheadAttention(hp.embedding_size, 4, dropout=0.1, batch_first=True)
140
  self.dropout_1 = torch.nn.Dropout(0.1)
141
+ self.norm_2 = nn.LayerNorm(hp.embedding_size)
142
+ self.attn = torch.nn.MultiheadAttention(hp.embedding_size, 4, dropout=0.1, batch_first=True)
143
  self.dropout_2 = torch.nn.Dropout(0.1)
144
+ self.norm_3 = nn.LayerNorm(hp.embedding_size)
145
  self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
146
  self.dropout_3 = torch.nn.Dropout(0.1)
147
  self.linear_2 = nn.Linear(hp.dim_feedforward, hp.embedding_size)
148
  self.dropout_4 = torch.nn.Dropout(0.1)
149
  def forward(self, x, memory, x_attn_mask=None, x_key_padding_mask=None, memory_attn_mask=None, memory_key_padding_mask=None):
150
+ x_out, _ = self.self_attn(x, x, x, attn_mask=x_attn_mask, key_padding_mask=x_key_padding_mask)
151
+ x_out = self.dropout_1(x_out); x = self.norm_1(x + x_out)
152
+ x_out, _ = self.attn(x, memory, memory, attn_mask=memory_attn_mask, key_padding_mask=memory_key_padding_mask)
153
+ x_out = self.dropout_2(x_out); x = self.norm_2(x + x_out)
154
+ x_out = self.linear_1(x); x_out = F.relu(x_out); x_out = self.dropout_3(x_out)
155
+ x_out = self.linear_2(x_out); x_out = self.dropout_4(x_out)
156
+ x = self.norm_3(x + x_out); return x
157
+
158
+ class EncoderPreNet(nn.Module): # VERBATIM
159
  def __init__(self):
160
  super(EncoderPreNet, self).__init__()
161
+ self.embedding = nn.Embedding(hp.text_num_embeddings, hp.encoder_embedding_size)
162
  self.linear_1 = nn.Linear(hp.encoder_embedding_size, hp.encoder_embedding_size)
163
  self.linear_2 = nn.Linear(hp.encoder_embedding_size, hp.embedding_size)
164
+ self.conv_1 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, hp.encoder_kernel_size, 1, int((hp.encoder_kernel_size-1)/2),1)
165
+ self.bn_1 = nn.BatchNorm1d(hp.encoder_embedding_size); self.dropout_1 = nn.Dropout(0.5)
166
+ self.conv_2 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, hp.encoder_kernel_size, 1, int((hp.encoder_kernel_size-1)/2),1)
167
+ self.bn_2 = nn.BatchNorm1d(hp.encoder_embedding_size); self.dropout_2 = nn.Dropout(0.5)
168
+ self.conv_3 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, hp.encoder_kernel_size, 1, int((hp.encoder_kernel_size-1)/2),1)
169
+ self.bn_3 = nn.BatchNorm1d(hp.encoder_embedding_size); self.dropout_3 = nn.Dropout(0.5)
170
  def forward(self, text):
171
+ x = self.embedding(text); x = self.linear_1(x); x = x.transpose(2,1)
172
+ x = self.conv_1(x); x = self.bn_1(x); x = F.relu(x); x = self.dropout_1(x)
173
+ x = self.conv_2(x); x = self.bn_2(x); x = F.relu(x); x = self.dropout_2(x)
174
+ x = self.conv_3(x); x = self.bn_3(x); x = F.relu(x); x = self.dropout_3(x)
175
+ x = x.transpose(1,2); x = self.linear_2(x); return x
176
+
177
+ class PostNet(nn.Module): # VERBATIM
178
  def __init__(self):
179
  super(PostNet, self).__init__()
180
+ self.conv_1 = nn.Conv1d(hp.mel_freq, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
181
+ self.bn_1 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_1 = nn.Dropout(0.5)
182
+ self.conv_2 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
183
+ self.bn_2 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_2 = nn.Dropout(0.5)
184
+ self.conv_3 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
185
+ self.bn_3 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_3 = nn.Dropout(0.5)
186
+ self.conv_4 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
187
+ self.bn_4 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_4 = nn.Dropout(0.5)
188
+ self.conv_5 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
189
+ self.bn_5 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_5 = nn.Dropout(0.5)
190
+ self.conv_6 = nn.Conv1d(hp.postnet_embedding_size, hp.mel_freq, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
191
+ self.bn_6 = nn.BatchNorm1d(hp.mel_freq); self.dropout_6 = nn.Dropout(0.5)
192
  def forward(self, x):
193
+ x_orig = x; x = x.transpose(2,1)
194
+ x = self.conv_1(x); x = self.bn_1(x); x = torch.tanh(x); x = self.dropout_1(x)
195
+ x = self.conv_2(x); x = self.bn_2(x); x = torch.tanh(x); x = self.dropout_2(x)
196
+ x = self.conv_3(x); x = self.bn_3(x); x = torch.tanh(x); x = self.dropout_3(x)
197
+ x = self.conv_4(x); x = self.bn_4(x); x = torch.tanh(x); x = self.dropout_4(x)
198
+ x = self.conv_5(x); x = self.bn_5(x); x = torch.tanh(x); x = self.dropout_5(x)
199
+ x = self.conv_6(x); x = self.bn_6(x); x = self.dropout_6(x)
200
+ x = x.transpose(1,2); return x # Original postnet in repo is residual, added in TransformerTTS.forward
201
+
202
+ class DecoderPreNet(nn.Module): # VERBATIM
203
  def __init__(self):
204
  super(DecoderPreNet, self).__init__()
205
  self.linear_1 = nn.Linear(hp.mel_freq, hp.embedding_size)
206
  self.linear_2 = nn.Linear(hp.embedding_size, hp.embedding_size)
207
  def forward(self, x):
208
+ x = self.linear_1(x); x = F.relu(x)
209
+ x = F.dropout(x, p=0.5, training=True) # Dropout always on
210
+ x = self.linear_2(x); x = F.relu(x)
211
+ x = F.dropout(x, p=0.5, training=True) # Dropout always on
 
 
212
  return x
213
 
214
+ class TransformerTTS(nn.Module): # VERBATIM (init had device=DEVICE, now model is moved after init)
215
+ def __init__(self): # Removed device=DEVICE from here
216
  super(TransformerTTS, self).__init__()
217
  self.encoder_prenet = EncoderPreNet()
218
  self.decoder_prenet = DecoderPreNet()
219
  self.postnet = PostNet()
220
+ self.pos_encoding = nn.Embedding(hp.max_mel_time, hp.embedding_size)
221
+ self.encoder_block_1 = EncoderBlock(); self.encoder_block_2 = EncoderBlock(); self.encoder_block_3 = EncoderBlock()
222
+ self.decoder_block_1 = DecoderBlock(); self.decoder_block_2 = DecoderBlock(); self.decoder_block_3 = DecoderBlock()
223
  self.linear_1 = nn.Linear(hp.embedding_size, hp.mel_freq)
224
  self.linear_2 = nn.Linear(hp.embedding_size, 1)
225
+ self.norm_memory = nn.LayerNorm(hp.embedding_size)
226
+ # Mask attributes will be set in forward pass, as per your code
227
+ self.src_key_padding_mask = None; self.src_mask = None
228
+ self.tgt_key_padding_mask = None; self.tgt_mask = None; self.memory_mask = None
229
+
230
+ def forward(self, text, text_len, mel, mel_len): # VERBATIM
231
+ N = text.shape[0]; S_text_in = text.shape[1]; TIME_mel_in = mel.shape[1]
232
+ current_device = text.device
233
+
234
+ self.src_key_padding_mask = torch.zeros((N, S_text_in), device=current_device).masked_fill(~mask_from_seq_lengths(text_len, max_length=S_text_in), float("-inf"))
235
+ self.src_mask = torch.zeros((S_text_in, S_text_in), device=current_device).masked_fill(torch.triu(torch.full((S_text_in, S_text_in), True, dtype=torch.bool, device=current_device), diagonal=1), float("-inf"))
236
+ self.tgt_key_padding_mask = torch.zeros((N, TIME_mel_in), device=current_device).masked_fill(~mask_from_seq_lengths(mel_len, max_length=TIME_mel_in), float("-inf"))
237
+ self.tgt_mask = torch.zeros((TIME_mel_in, TIME_mel_in), device=current_device).masked_fill(torch.triu(torch.full((TIME_mel_in, TIME_mel_in), True, device=current_device, dtype=torch.bool), diagonal=1), float("-inf"))
238
+ self.memory_mask = torch.zeros((TIME_mel_in, S_text_in), device=current_device).masked_fill(torch.triu(torch.full((TIME_mel_in, S_text_in), True, device=current_device, dtype=torch.bool), diagonal=1), float("-inf"))
239
+
240
  text_x = self.encoder_prenet(text)
241
+ pos_codes = self.pos_encoding(torch.arange(hp.max_mel_time, device=current_device))
242
+ S_text_processed = text_x.shape[1]; text_x = text_x + pos_codes[:S_text_processed] # Use actual S after prenet
243
+
244
  text_x = self.encoder_block_1(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
245
  text_x = self.encoder_block_2(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
246
  text_x = self.encoder_block_3(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
247
  text_x = self.norm_memory(text_x)
248
+
249
+ mel_x = self.decoder_prenet(mel);
250
+ TIME_mel_processed = mel_x.shape[1]; mel_x = mel_x + pos_codes[:TIME_mel_processed] # Use actual T after prenet
251
+
252
  mel_x = self.decoder_block_1(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
253
  mel_x = self.decoder_block_2(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
254
  mel_x = self.decoder_block_3(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
255
+
256
  mel_linear = self.linear_1(mel_x)
257
+ postnet_residual_out = self.postnet(mel_linear) # PostNet output
258
+ mel_postnet = mel_linear + postnet_residual_out # Add residual
259
+ stop_token = self.linear_2(mel_x) # (N, TIME, 1)
260
+
261
+ # Masking output based on padding
262
+ # self.tgt_key_padding_mask is -inf for padded, 0 for unpadded.
263
+ # .ne(0) makes it True for padded, False for unpadded. This is correct for masked_fill.
264
+ bool_mel_padding_mask = self.tgt_key_padding_mask.ne(0)
265
+
266
+ mel_linear = mel_linear.masked_fill(bool_mel_padding_mask.unsqueeze(-1).expand_as(mel_linear), 0)
267
+ mel_postnet = mel_postnet.masked_fill(bool_mel_padding_mask.unsqueeze(-1).expand_as(mel_postnet), 0)
268
+ stop_token = stop_token.masked_fill(bool_mel_padding_mask.unsqueeze(-1).expand_as(stop_token), 1e3).squeeze(2)
269
  return mel_postnet, mel_linear, stop_token
270
 
271
+ @torch.no_grad() # VERBATIM from your FastAPI code (with .item() fix)
272
+ def inference(self, text, max_length=800, stop_token_threshold=0.5, with_tqdm=False): # with_tqdm was False in pipeline call
273
+ self.eval(); self.train(False) # As per your original
274
+ model_device = next(self.parameters()).device
275
+
276
+ text_on_device = text.to(model_device)
277
+ text_lengths = torch.tensor([text_on_device.shape[1]], dtype=torch.long).to(model_device) # shape [1], matching the original forward-pass usage (an extra unsqueeze would make it [1, 1])
278
+
279
  N = 1
280
+ SOS = torch.zeros((N, 1, hp.mel_freq), device=model_device)
281
  mel_padded = SOS
282
+ mel_lengths = torch.tensor([1], dtype=torch.long).to(model_device) # shape [1], consistent with text_lengths
283
+
284
+ stop_token_outputs = torch.FloatTensor([]).to(model_device) # text.device might be CPU if text wasn't on device
285
+
286
+ # Use local tqdm to avoid conflict if tqdm is imported elsewhere
287
+ from tqdm import tqdm as tqdm_local
288
+ iters = tqdm_local(range(max_length), desc="TTS Inference") if with_tqdm else range(max_length)
289
+
290
+ final_mel_postnet_output = SOS # To store the output from the last forward pass
291
+
292
  for _ in iters:
293
+ # mel_postnet is (N, T_current_input_len, Freq)
294
+ mel_postnet, mel_linear, stop_token = self(text_on_device, text_lengths, mel_padded, mel_lengths)
295
+ final_mel_postnet_output = mel_postnet # This is the full sequence predicted in this step
296
+
297
+ # Append last frame of mel_postnet for next input step
298
  mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
299
+ mel_lengths = torch.tensor([mel_padded.shape[1]], dtype=torch.long).to(model_device) # keep shape [1]
300
+
301
+ # stop_token is (N, T_current_input_len)
302
+ # Check stop condition for the last frame of the input sequence
303
+ if (torch.sigmoid(stop_token[:, -1].squeeze()) > stop_token_threshold).item():
304
  break
305
  else:
306
+ # stop_token[:, -1:] is (N, 1)
307
  stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
 
 
308
 
309
+ # final_mel_postnet_output contains SOS. Strip it.
310
+ if final_mel_postnet_output.shape[1] > 1: # If more than just SOS frame
311
+ mel_to_return = final_mel_postnet_output[:, 1:, :]
312
+ else: # Only SOS was processed, or nothing
313
+ mel_to_return = torch.empty((N, 0, hp.mel_freq), device=model_device)
314
+ if mel_to_return.shape[1] == 0: # ensure stop_token_outputs is also empty
315
+ stop_token_outputs = stop_token_outputs.new_zeros((N, 0)) # empty (N, 0); safe even when nothing was appended yet
316
317
 
318
+ return mel_to_return, stop_token_outputs
319
+
320
+ # --- Part 3: Model Loading (from Hugging Face Hub - VERBATIM from your FastAPI code) ---
321
  TTS_MODEL_HUB_ID = "MoHamdyy/transformer-tts-ljspeech"
322
  ASR_HUB_ID = "MoHamdyy/whisper-stt-model"
323
  MARIAN_HUB_ID = "MoHamdyy/marian-ar-en-translation"
324
325
  print("Loading models from Hugging Face Hub to device:", DEVICE)
326
+ TTS_MODEL = None; stt_processor = None; stt_model = None; mt_tokenizer = None; mt_model = None
327
 
 
328
  try:
329
  print("Loading TTS model...")
 
330
  tts_model_path = hf_hub_download(repo_id=TTS_MODEL_HUB_ID, filename="train_SimpleTransfromerTTS.pt")
331
  state = torch.load(tts_model_path, map_location=DEVICE)
332
+ TTS_MODEL = TransformerTTS().to(DEVICE) # Create instance then move to DEVICE
333
+ if "model" in state: TTS_MODEL.load_state_dict(state["model"])
334
+ elif "state_dict" in state: TTS_MODEL.load_state_dict(state["state_dict"])
335
+ else: TTS_MODEL.load_state_dict(state)
336
  TTS_MODEL.eval()
337
  print("TTS model loaded successfully.")
338
+ except Exception as e: print(f"Error loading TTS model: {e}")
 
 
339
 
 
340
  try:
341
  print("Loading STT (Whisper) model...")
342
  stt_processor = WhisperProcessor.from_pretrained(ASR_HUB_ID)
343
  stt_model = WhisperForConditionalGeneration.from_pretrained(ASR_HUB_ID).to(DEVICE).eval()
344
  print("STT model loaded successfully.")
345
+ except Exception as e: print(f"Error loading STT model: {e}")
346
 
 
347
  try:
348
  print("Loading TTT (MarianMT) model...")
349
  mt_tokenizer = MarianTokenizer.from_pretrained(MARIAN_HUB_ID)
350
  mt_model = MarianMTModel.from_pretrained(MARIAN_HUB_ID).to(DEVICE).eval()
351
  print("TTT model loaded successfully.")
352
+ except Exception as e: print(f"Error loading TTT model: {e}")
353
 
354
+ # --- Part 4: Full Pipeline Function (VERBATIM from your FastAPI code, adapted for Gradio output) ---
355
+ def full_speech_translation_pipeline_gradio(audio_input_path: str): # Renamed for clarity
356
  print(f"--- PIPELINE START: Processing {audio_input_path} ---")
357
+ # Check if models are loaded
358
+ if not all([stt_processor, stt_model, mt_tokenizer, mt_model, TTS_MODEL]):
359
+ error_msg = "One or more models failed to load. Please check logs."
360
+ print(error_msg)
361
+ return error_msg, error_msg, (hp.sr, np.array([]).astype(np.float32))
362
+
363
+ if audio_input_path is None: # Gradio provides a path for uploaded/recorded audio
364
+ msg = "Error: No audio input received by Gradio."
365
  print(msg)
366
+ return msg, "", (hp.sr, np.array([]).astype(np.float32))
367
+
368
+ if not os.path.exists(audio_input_path):
369
+ # This case might happen if Gradio passes a temp path that gets cleaned up too quickly,
370
+ # or if there's an issue with how Gradio handles file paths.
371
+ # For Gradio `type="filepath"`, the path should be valid.
372
+ msg = f"Error: Audio file path provided by Gradio does not exist: {audio_input_path}"
373
+ print(msg)
374
+ return msg, "", (hp.sr, np.array([]).astype(np.float32))
375
+
376
 
377
  # STT Stage
378
  arabic_transcript = "STT Error: Processing failed."
 
391
  generated_ids = stt_model.generate(inputs, forced_decoder_ids=forced_ids, max_length=448)
392
  arabic_transcript = stt_processor.decode(generated_ids[0], skip_special_tokens=True).strip()
393
  print(f"STT Output: {arabic_transcript}")
394
+ if not arabic_transcript: arabic_transcript = "(STT: No speech detected or empty transcript)"
395
  except Exception as e:
396
+ print(f"STT Error: {e}"); import traceback; traceback.print_exc()
397
+ arabic_transcript = f"STT Error: {e}"
398
+
399
 
400
  # TTT Stage
401
  english_translation = "TTT Error: Processing failed."
402
+ tts_status_message = "" # For appending TTS status to English text
403
+ if arabic_transcript and not arabic_transcript.startswith("STT Error") and not arabic_transcript.startswith("(STT:"):
404
  try:
405
  print("TTT: Translating to English...")
406
  batch = mt_tokenizer(arabic_transcript, return_tensors="pt", padding=True).to(DEVICE)
 
408
  translated_ids = mt_model.generate(**batch, max_length=512)
409
  english_translation = mt_tokenizer.batch_decode(translated_ids, skip_special_tokens=True)[0].strip()
410
  print(f"TTT Output: {english_translation}")
411
+ if not english_translation: english_translation = "(TTT: Empty translation)"
412
  except Exception as e:
413
+ print(f"TTT Error: {e}"); import traceback; traceback.print_exc()
414
+ english_translation = f"TTT Error: {e}"
415
+ elif arabic_transcript.startswith("STT Error") or arabic_transcript.startswith("(STT:"):
416
+ english_translation = "(Skipped TTT due to STT issue)"
417
  print(english_translation)
418
+ else: # Should not happen if STT produces some output
419
+ english_translation = "(Skipped TTT: Unknown STT state)"
420
+
421
 
422
  # TTS Stage
423
  synthesized_audio_np = np.array([]).astype(np.float32)
424
+ if english_translation and not english_translation.startswith("TTT Error") and not english_translation.startswith("(Skipped") and not english_translation.startswith("(TTT:"):
425
  try:
426
  print("TTS: Synthesizing English speech...")
427
+ sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE) # Ensure input is on TTS_MODEL's device
428
+
429
+ # Make sure TTS_MODEL is on the correct device before inference
430
+ TTS_MODEL.to(DEVICE) # Redundant if already done, but safe
431
+ TTS_MODEL.eval() # Ensure eval mode
432
+
433
  generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-20, stop_token_threshold=0.5, with_tqdm=False)
434
 
435
  print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
436
+ if generated_mel is not None and generated_mel.numel() > 0 and generated_mel.shape[1] > 0: # Check if time dimension has frames
437
+ mel_for_vocoder = generated_mel.detach().squeeze(0).transpose(0, 1) # [F, T]
438
  audio_tensor = inverse_mel_spec_to_wav(mel_for_vocoder)
439
  synthesized_audio_np = audio_tensor.cpu().numpy()
440
  print(f"TTS: Synthesized audio shape: {synthesized_audio_np.shape}")
441
+ else:
442
+ tts_status_message = "(TTS Error: Empty mel generated)"
443
+ print(tts_status_message)
444
  except Exception as e:
445
+ print(f"TTS Error: {e}"); import traceback; traceback.print_exc()
446
+ tts_status_message = f"(TTS Error: {e})"
447
+ elif english_translation.startswith("TTT Error") or english_translation.startswith("(Skipped") or english_translation.startswith("(TTT:"):
448
+ tts_status_message = "(Skipped TTS due to TTT/Input issue)"
449
+ else: # Should not happen if TTT produces some output
450
+ tts_status_message = "(Skipped TTS: Unknown TTT state)"
451
+
452
  print(f"--- PIPELINE END ---")
453
+ # Combine English translation with any TTS status message
454
+ final_english_display = english_translation
455
+ if tts_status_message:
456
+ final_english_display += f" {tts_status_message}"
457
+
458
+ return arabic_transcript, final_english_display.strip(), (hp.sr, synthesized_audio_np)
459
+
460
+ # --- Part 5: Gradio Interface ---
461
+ print("Setting up Gradio interface...")
462
+ demo = gr.Interface(
463
+ fn=full_speech_translation_pipeline_gradio,
464
+ inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Arabic Speech Input"),
465
+ outputs=[
466
+ gr.Textbox(label="Arabic Transcript (STT)"),
467
+ gr.Textbox(label="English Translation & TTS Status"),
468
+ gr.Audio(label="Synthesized English Speech (TTS)", type="numpy") # type="numpy" expects (sr, data)
469
+ ],
470
+ title="Arabic Speech-to-Text -> Translation -> English Text-to-Speech",
471
+ description="Upload an Arabic audio file or record from your microphone. The system will transcribe it to Arabic, translate it to English, and then synthesize the English text into audible speech.",
472
+ allow_flagging="never",
473
+ examples=[["/kaggle/input/testtt/test_audio.ogg"]] if os.path.exists("/kaggle/input/testtt/test_audio.ogg") else None # Optional example
474
  )
475
 
476
+ if __name__ == "__main__":
477
+ print("Launching Gradio app...")
478
+ # When running on Hugging Face Spaces, HF handles the launch.
479
+ # For local testing, you might need a specific host/port.
480
+ # HF Spaces will look for a `demo.launch()` or `iface.launch()`
481
+ demo.launch(debug=True) # debug=True for more detailed Gradio logs
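For a quick check outside Spaces, the pipeline function can also be called directly instead of going through the Gradio UI. A minimal sketch, assuming the three models above loaded successfully and that sample_arabic.wav is a placeholder path to a local Arabic recording:

    # Hypothetical smoke test; sample_arabic.wav is an assumed local file, not part of this repo.
    ar_text, en_text, (sr, wav) = full_speech_translation_pipeline_gradio("sample_arabic.wav")
    print("Arabic transcript:", ar_text)
    print("English translation:", en_text)
    print("Synthesized audio:", sr, "Hz,", wav.shape)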