import os
import re
import time
import random
import numpy as np
# import pandas as pd # Not used by this inference pipeline
# import math # Not used by this inference pipeline
# import shutil # Not needed for Gradio file handling
# import base64 # Not needed for Gradio audio output

# Torch and Audio
import torch
import torch.nn as nn
# import torch.optim as optim # Not needed for inference
# from torch.utils.data import Dataset, DataLoader # Not needed for inference
import torch.nn.functional as F
import torchaudio
# import librosa # Not strictly needed if not plotting in Gradio
# import librosa.display # Not strictly needed if not plotting in Gradio

# Text and Audio Processing
from unidecode import unidecode
from inflect import engine as inflect_engine_tts # Renamed to avoid conflict
# import pydub # Not needed for Gradio audio output
# import soundfile as sf # Gradio handles audio output directly

# Transformers
from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    MarianTokenizer, MarianMTModel,
)

# Gradio
import gradio as gr
from huggingface_hub import hf_hub_download # For downloading models

# --- Configuration & Device ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
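
# Note: forcing deterministic cuDNN kernels and disabling benchmark mode trades some GPU
# throughput for run-to-run reproducibility of the seeded setup above.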

# --- Hyperparameters (kept identical to the training notebook) ---
class Hyperparams:
  seed = 42
  csv_path = "path/to/metadata.csv" # Not used directly
  wav_path = "path/to/wavs" # Not used directly
  symbols = [
    'EOS', ' ', '!', ',', '-', '.', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f', 
    'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 
    't', 'u', 'v', 'w', 'x', 'y', 'z', 'à', 'â', 'è', 'é', 'ê', 'ü', 
    '’', '“', '”' 
  ]
  sr = 22050
  n_fft = 2048
  n_stft = int((n_fft//2) + 1)
  hop_length = int(n_fft/8.0)
  win_length = int(n_fft/2.0)
  mel_freq = 128
  max_mel_time = 1024
  power = 2.0 # For spec_transform if used, not directly by inverse_mel
  text_num_embeddings = 2*len(symbols)  
  embedding_size = 256
  encoder_embedding_size = 512 
  dim_feedforward = 1024
  postnet_embedding_size = 1024
  encoder_kernel_size = 3
  postnet_kernel_size = 5
  ampl_multiplier = 10.0 # For pow_to_db_mel_spec
  ampl_amin = 1e-10 # For pow_to_db_mel_spec
  db_multiplier = 1.0 # For pow_to_db_mel_spec
  ampl_ref = 1.0 # For db_to_power_mel_spec
  ampl_power = 1.0 # For db_to_power_mel_spec
  max_db = 100 # For pow_to_db_mel_spec
  scale_db = 10 # For pow_to_db_mel_spec & db_to_power_mel_spec
hp = Hyperparams()
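# Derived STFT settings for reference: n_stft = 1025 frequency bins, hop_length = 256 samples
# and win_length = 1024 samples at sr = 22050 Hz.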

# --- TTS text and audio processing (kept identical to the training notebook) ---
symbol_to_id = {s: i for i, s in enumerate(hp.symbols)}
def text_to_seq(text):
  text = text.lower()
  seq = []
  for s in text:
    _id = symbol_to_id.get(s, None)
    if _id is not None:
      seq.append(_id)
  seq.append(symbol_to_id["EOS"])
  return torch.IntTensor(seq)
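# Illustrative example (derived from the symbol table above):
#   text_to_seq("hi.") -> tensor([15, 16, 5, 0], dtype=torch.int32)  # 'h', 'i', '.', EOS
# Characters outside hp.symbols are silently dropped before the EOS id is appended.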

mel_inverse_transform = torchaudio.transforms.InverseMelScale(n_mels=hp.mel_freq, sample_rate=hp.sr, n_stft=hp.n_stft).to(DEVICE)
griffnlim_transform = torchaudio.transforms.GriffinLim(n_fft=hp.n_fft, win_length=hp.win_length, hop_length=hp.hop_length, power=1.0).to(DEVICE) # Explicit power=1.0 for magnitude

def db_to_power_mel_spec(mel_spec):
  mel_spec_scaled = mel_spec * hp.scale_db # Corrected: use a different variable name
  mel_spec_amp = torchaudio.functional.DB_to_amplitude(mel_spec_scaled, ref=hp.ampl_ref, power=hp.ampl_power)  
  return mel_spec_amp

def inverse_mel_spec_to_wav(mel_spec): # Expects [Freq, Time]
  mel_spec_on_device = mel_spec.to(DEVICE)
  power_mel_spec = db_to_power_mel_spec(mel_spec_on_device) # This is amplitude
  spectrogram = mel_inverse_transform(power_mel_spec) # Amplitude mel to linear amplitude
  pseudo_wav = griffnlim_transform(spectrogram) # Linear amplitude to wav
  return pseudo_wav
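# Rough output-size check: a dB-scaled mel of shape [mel_freq, T] yields a waveform of
# approximately T * hop_length samples (roughly T * 256 / 22050 seconds at these settings).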

def mask_from_seq_lengths(sequence_lengths: torch.Tensor, max_length: int) -> torch.BoolTensor:
    ones = sequence_lengths.new_ones(sequence_lengths.size(0), max_length)
    range_tensor = ones.cumsum(dim=1)
    return sequence_lengths.unsqueeze(1) >= range_tensor
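# Example (1-D lengths, as used throughout this file):
#   mask_from_seq_lengths(torch.tensor([2]), max_length=4) -> tensor([[True, True, False, False]])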
    
# --- TransformerTTS model architecture (kept identical to the FastAPI version) ---
class EncoderBlock(nn.Module): # VERBATIM
    def __init__(self):
        super(EncoderBlock, self).__init__()
        self.norm_1 = nn.LayerNorm(hp.embedding_size)
        self.attn = torch.nn.MultiheadAttention(hp.embedding_size, 4, dropout=0.1, batch_first=True)
        self.dropout_1 = torch.nn.Dropout(0.1)
        self.norm_2 = nn.LayerNorm(hp.embedding_size)
        self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
        self.dropout_2 = torch.nn.Dropout(0.1)
        self.linear_2 = nn.Linear(hp.dim_feedforward, hp.embedding_size)
        self.dropout_3 = torch.nn.Dropout(0.1)
    def forward(self, x, attn_mask=None, key_padding_mask=None):
        x_out = self.norm_1(x); x_out, _ = self.attn(x_out, x_out, x_out, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
        x_out = self.dropout_1(x_out); x = x + x_out    
        x_out = self.norm_2(x) ; x_out = self.linear_1(x_out); x_out = F.relu(x_out); x_out = self.dropout_2(x_out)
        x_out = self.linear_2(x_out); x_out = self.dropout_3(x_out)
        x = x + x_out; return x

class DecoderBlock(nn.Module): # VERBATIM
    def __init__(self):
        super(DecoderBlock, self).__init__()
        self.norm_1 = nn.LayerNorm(hp.embedding_size)
        self.self_attn = torch.nn.MultiheadAttention(hp.embedding_size, 4, dropout=0.1, batch_first=True)
        self.dropout_1 = torch.nn.Dropout(0.1)
        self.norm_2 = nn.LayerNorm(hp.embedding_size)
        self.attn = torch.nn.MultiheadAttention(hp.embedding_size, 4, dropout=0.1, batch_first=True)    
        self.dropout_2 = torch.nn.Dropout(0.1)
        self.norm_3 = nn.LayerNorm(hp.embedding_size)
        self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
        self.dropout_3 = torch.nn.Dropout(0.1)
        self.linear_2 = nn.Linear(hp.dim_feedforward, hp.embedding_size)
        self.dropout_4 = torch.nn.Dropout(0.1)
    def forward(self, x, memory, x_attn_mask=None, x_key_padding_mask=None, memory_attn_mask=None, memory_key_padding_mask=None):
        x_out, _ = self.self_attn(x, x, x, attn_mask=x_attn_mask, key_padding_mask=x_key_padding_mask)
        x_out = self.dropout_1(x_out); x = self.norm_1(x + x_out)
        x_out, _ = self.attn(x, memory, memory, attn_mask=memory_attn_mask, key_padding_mask=memory_key_padding_mask)
        x_out = self.dropout_2(x_out); x = self.norm_2(x + x_out)
        x_out = self.linear_1(x); x_out = F.relu(x_out); x_out = self.dropout_3(x_out)
        x_out = self.linear_2(x_out); x_out = self.dropout_4(x_out)
        x = self.norm_3(x + x_out); return x

class EncoderPreNet(nn.Module): # VERBATIM
    def __init__(self):
        super(EncoderPreNet, self).__init__()
        self.embedding = nn.Embedding(hp.text_num_embeddings, hp.encoder_embedding_size)
        self.linear_1 = nn.Linear(hp.encoder_embedding_size, hp.encoder_embedding_size)
        self.linear_2 = nn.Linear(hp.encoder_embedding_size, hp.embedding_size)
        self.conv_1 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, hp.encoder_kernel_size, 1, int((hp.encoder_kernel_size-1)/2),1)
        self.bn_1 = nn.BatchNorm1d(hp.encoder_embedding_size); self.dropout_1 = nn.Dropout(0.5)
        self.conv_2 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, hp.encoder_kernel_size, 1, int((hp.encoder_kernel_size-1)/2),1)
        self.bn_2 = nn.BatchNorm1d(hp.encoder_embedding_size); self.dropout_2 = nn.Dropout(0.5)
        self.conv_3 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, hp.encoder_kernel_size, 1, int((hp.encoder_kernel_size-1)/2),1)
        self.bn_3 = nn.BatchNorm1d(hp.encoder_embedding_size); self.dropout_3 = nn.Dropout(0.5)    
    def forward(self, text):
        x = self.embedding(text); x = self.linear_1(x); x = x.transpose(2,1) 
        x = self.conv_1(x); x = self.bn_1(x); x = F.relu(x); x = self.dropout_1(x)
        x = self.conv_2(x); x = self.bn_2(x); x = F.relu(x); x = self.dropout_2(x)
        x = self.conv_3(x); x = self.bn_3(x); x = F.relu(x); x = self.dropout_3(x)
        x = x.transpose(1,2); x = self.linear_2(x); return x

class PostNet(nn.Module): # VERBATIM
    def __init__(self):
        super(PostNet, self).__init__()
        self.conv_1 = nn.Conv1d(hp.mel_freq, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
        self.bn_1 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_1 = nn.Dropout(0.5)
        self.conv_2 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
        self.bn_2 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_2 = nn.Dropout(0.5)
        self.conv_3 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
        self.bn_3 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_3 = nn.Dropout(0.5)
        self.conv_4 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
        self.bn_4 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_4 = nn.Dropout(0.5)
        self.conv_5 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
        self.bn_5 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_5 = nn.Dropout(0.5)
        self.conv_6 = nn.Conv1d(hp.postnet_embedding_size, hp.mel_freq, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
        self.bn_6 = nn.BatchNorm1d(hp.mel_freq); self.dropout_6 = nn.Dropout(0.5)
    def forward(self, x):
        x_orig = x; x = x.transpose(2,1)
        x = self.conv_1(x); x = self.bn_1(x); x = torch.tanh(x); x = self.dropout_1(x)
        x = self.conv_2(x); x = self.bn_2(x); x = torch.tanh(x); x = self.dropout_2(x)
        x = self.conv_3(x); x = self.bn_3(x); x = torch.tanh(x); x = self.dropout_3(x)
        x = self.conv_4(x); x = self.bn_4(x); x = torch.tanh(x); x = self.dropout_4(x)
        x = self.conv_5(x); x = self.bn_5(x); x = torch.tanh(x); x = self.dropout_5(x)
        x = self.conv_6(x); x = self.bn_6(x); x = self.dropout_6(x)
        x = x.transpose(1,2); return x # Original postnet in repo is residual, added in TransformerTTS.forward

class DecoderPreNet(nn.Module): # VERBATIM
    def __init__(self):
        super(DecoderPreNet, self).__init__()
        self.linear_1 = nn.Linear(hp.mel_freq, hp.embedding_size)
        self.linear_2 = nn.Linear(hp.embedding_size, hp.embedding_size)
    def forward(self, x):
        x = self.linear_1(x); x = F.relu(x)
        x = F.dropout(x, p=0.5, training=True) # Dropout always on
        x = self.linear_2(x); x = F.relu(x)    
        x = F.dropout(x, p=0.5, training=True) # Dropout always on
        return x    

class TransformerTTS(nn.Module): # VERBATIM (init had device=DEVICE, now model is moved after init)
    def __init__(self): # Removed device=DEVICE from here
        super(TransformerTTS, self).__init__()
        self.encoder_prenet = EncoderPreNet()
        self.decoder_prenet = DecoderPreNet()
        self.postnet = PostNet()
        self.pos_encoding = nn.Embedding(hp.max_mel_time, hp.embedding_size)
        self.encoder_block_1 = EncoderBlock(); self.encoder_block_2 = EncoderBlock(); self.encoder_block_3 = EncoderBlock()
        self.decoder_block_1 = DecoderBlock(); self.decoder_block_2 = DecoderBlock(); self.decoder_block_3 = DecoderBlock()
        self.linear_1 = nn.Linear(hp.embedding_size, hp.mel_freq) 
        self.linear_2 = nn.Linear(hp.embedding_size, 1)
        self.norm_memory = nn.LayerNorm(hp.embedding_size)
        # Mask attributes will be set in forward pass, as per your code
        self.src_key_padding_mask = None; self.src_mask = None
        self.tgt_key_padding_mask = None; self.tgt_mask = None; self.memory_mask = None

    def forward(self, text, text_len, mel, mel_len): # VERBATIM
        N = text.shape[0]; S_text_in = text.shape[1]; TIME_mel_in = mel.shape[1]
        current_device = text.device

        self.src_key_padding_mask = torch.zeros((N, S_text_in), device=current_device).masked_fill(~mask_from_seq_lengths(text_len, max_length=S_text_in), float("-inf"))
        self.src_mask = torch.zeros((S_text_in, S_text_in), device=current_device).masked_fill(torch.triu(torch.full((S_text_in, S_text_in), True, dtype=torch.bool, device=current_device), diagonal=1), float("-inf"))
        self.tgt_key_padding_mask = torch.zeros((N, TIME_mel_in), device=current_device).masked_fill(~mask_from_seq_lengths(mel_len, max_length=TIME_mel_in), float("-inf"))
        self.tgt_mask = torch.zeros((TIME_mel_in, TIME_mel_in), device=current_device).masked_fill(torch.triu(torch.full((TIME_mel_in, TIME_mel_in), True, device=current_device, dtype=torch.bool), diagonal=1), float("-inf"))
        self.memory_mask = torch.zeros((TIME_mel_in, S_text_in), device=current_device).masked_fill(torch.triu(torch.full((TIME_mel_in, S_text_in), True, device=current_device, dtype=torch.bool), diagonal=1), float("-inf"))    
        
        text_x = self.encoder_prenet(text) 
        pos_codes = self.pos_encoding(torch.arange(hp.max_mel_time, device=current_device))
        S_text_processed = text_x.shape[1]; text_x = text_x + pos_codes[:S_text_processed] # Use actual S after prenet
        
        text_x = self.encoder_block_1(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
        text_x = self.encoder_block_2(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)    
        text_x = self.encoder_block_3(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
        text_x = self.norm_memory(text_x)
        
        mel_x = self.decoder_prenet(mel); 
        TIME_mel_processed = mel_x.shape[1]; mel_x = mel_x + pos_codes[:TIME_mel_processed] # Use actual T after prenet

        mel_x = self.decoder_block_1(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
        mel_x = self.decoder_block_2(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
        mel_x = self.decoder_block_3(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
        
        mel_linear = self.linear_1(mel_x)
        postnet_residual_out = self.postnet(mel_linear) # PostNet output
        mel_postnet = mel_linear + postnet_residual_out # Add residual
        stop_token = self.linear_2(mel_x) # (N, TIME, 1)

        # Masking output based on padding
        # self.tgt_key_padding_mask is -inf for padded, 0 for unpadded.
        # .ne(0) makes it True for padded, False for unpadded. This is correct for masked_fill.
        bool_mel_padding_mask = self.tgt_key_padding_mask.ne(0) 
        
        mel_linear = mel_linear.masked_fill(bool_mel_padding_mask.unsqueeze(-1).expand_as(mel_linear), 0)
        mel_postnet = mel_postnet.masked_fill(bool_mel_padding_mask.unsqueeze(-1).expand_as(mel_postnet), 0)
        stop_token = stop_token.masked_fill(bool_mel_padding_mask.unsqueeze(-1).expand_as(stop_token), 1e3).squeeze(2)
        return mel_postnet, mel_linear, stop_token 
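        # Shapes returned above: mel_postnet and mel_linear are (N, TIME, hp.mel_freq);
        # stop_token is (N, TIME), with padded positions forced to a large positive logit (1e3).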

    @torch.no_grad() # Same inference loop as the FastAPI version (with the .item() fix)
    def inference(self, text, max_length=800, stop_token_threshold=0.5, with_tqdm=False): # with_tqdm was False in pipeline call
        self.eval()  # switches dropout/batch-norm layers to inference mode (equivalent to train(False))
        model_device = next(self.parameters()).device
        
        text_on_device = text.to(model_device)
        text_lengths = torch.tensor([text_on_device.shape[1]], dtype=torch.long, device=model_device) # 1-D (N,) lengths, as mask_from_seq_lengths expects
        
        N = 1
        SOS = torch.zeros((N, 1, hp.mel_freq), device=model_device)
        mel_padded = SOS
        mel_lengths = torch.tensor([1], dtype=torch.long, device=model_device) # 1-D (N,) lengths
        
        stop_token_outputs = torch.empty((N, 0), device=model_device) # will hold one stop logit per generated frame
        
        # Use local tqdm to avoid conflict if tqdm is imported elsewhere
        from tqdm import tqdm as tqdm_local 
        iters = tqdm_local(range(max_length), desc="TTS Inference") if with_tqdm else range(max_length)
        
        final_mel_postnet_output = SOS # To store the output from the last forward pass

        for _ in iters:
            # mel_postnet is (N, T_current_input_len, Freq)
            mel_postnet, mel_linear, stop_token = self(text_on_device, text_lengths, mel_padded, mel_lengths)
            final_mel_postnet_output = mel_postnet # This is the full sequence predicted in this step

            # Append last frame of mel_postnet for next input step
            mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
            mel_lengths = torch.tensor([mel_padded.shape[1]], dtype=torch.long, device=model_device)
            
            # stop_token is (N, T_current_input_len)
            # Check stop condition for the last frame of the input sequence
            if (torch.sigmoid(stop_token[:, -1].squeeze()) > stop_token_threshold).item():      
                break
            else:
                # stop_token[:, -1:] is (N, 1)
                stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
        
        # final_mel_postnet_output contains SOS. Strip it.
        if final_mel_postnet_output.shape[1] > 1: # More than just the SOS frame was generated
            mel_to_return = final_mel_postnet_output[:, 1:, :]
        else: # Only the SOS frame was processed; return empty mel and stop-token tensors
            mel_to_return = torch.empty((N, 0, hp.mel_freq), device=model_device)
            stop_token_outputs = torch.empty((N, 0), device=model_device)


        return mel_to_return, stop_token_outputs
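
# Illustrative use of the inference loop above (a sketch: `model` stands for any trained
# TransformerTTS instance, such as TTS_MODEL loaded below, and the sample text is arbitrary):
#   seq = text_to_seq("Hello world.").unsqueeze(0)                 # (1, S) int tensor
#   mel, stops = model.inference(seq, max_length=800)              # mel: (1, T, hp.mel_freq)
#   wav = inverse_mel_spec_to_wav(mel.squeeze(0).transpose(0, 1))  # expects [Freq, Time] input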
        
# --- Part 3: Model loading from the Hugging Face Hub (same as the FastAPI version) ---
TTS_MODEL_HUB_ID = "MoHamdyy/transformer-tts-ljspeech"
ASR_HUB_ID       = "MoHamdyy/whisper-stt-model"
MARIAN_HUB_ID    = "MoHamdyy/marian-ar-en-translation"

print("Loading models from Hugging Face Hub to device:", DEVICE)
TTS_MODEL = None; stt_processor = None; stt_model = None; mt_tokenizer = None; mt_model = None

try:
    print("Loading TTS model...")
    tts_model_path = hf_hub_download(repo_id=TTS_MODEL_HUB_ID, filename="train_SimpleTransfromerTTS.pt")
    state = torch.load(tts_model_path, map_location=DEVICE)
    TTS_MODEL = TransformerTTS().to(DEVICE) # Create instance then move to DEVICE
    if "model" in state: TTS_MODEL.load_state_dict(state["model"])
    elif "state_dict" in state: TTS_MODEL.load_state_dict(state["state_dict"])
    else: TTS_MODEL.load_state_dict(state)
    TTS_MODEL.eval()
    print("TTS model loaded successfully.")
except Exception as e: print(f"Error loading TTS model: {e}")

try:
    print("Loading STT (Whisper) model...")
    stt_processor = WhisperProcessor.from_pretrained(ASR_HUB_ID)
    stt_model = WhisperForConditionalGeneration.from_pretrained(ASR_HUB_ID).to(DEVICE).eval()
    print("STT model loaded successfully.")
except Exception as e: print(f"Error loading STT model: {e}")

try:
    print("Loading TTT (MarianMT) model...")
    mt_tokenizer = MarianTokenizer.from_pretrained(MARIAN_HUB_ID)
    mt_model = MarianMTModel.from_pretrained(MARIAN_HUB_ID).to(DEVICE).eval()
    print("TTT model loaded successfully.")
except Exception as e: print(f"Error loading TTT model: {e}")

# --- Part 4: Full pipeline function (same as the FastAPI version, adapted for Gradio outputs) ---
def full_speech_translation_pipeline_gradio(audio_input_path: str): # Renamed for clarity
    print(f"--- PIPELINE START: Processing {audio_input_path} ---")
    # Check if models are loaded
    if not all([stt_processor, stt_model, mt_tokenizer, mt_model, TTS_MODEL]):
        error_msg = "One or more models failed to load. Please check logs."
        print(error_msg)
        return error_msg, error_msg, (hp.sr, np.array([]).astype(np.float32))

    if audio_input_path is None: # Gradio provides a path for uploaded/recorded audio
        msg = "Error: No audio input received by Gradio."
        print(msg)
        return msg, "", (hp.sr, np.array([]).astype(np.float32))
    
    if not os.path.exists(audio_input_path):
        # This case might happen if Gradio passes a temp path that gets cleaned up too quickly,
        # or if there's an issue with how Gradio handles file paths.
        # For Gradio `type="filepath"`, the path should be valid.
        msg = f"Error: Audio file path provided by Gradio does not exist: {audio_input_path}"
        print(msg)
        return msg, "", (hp.sr, np.array([]).astype(np.float32))


    # STT Stage
    arabic_transcript = "STT Error: Processing failed."
    try:
        print("STT: Loading and resampling audio...")
        wav, sr = torchaudio.load(audio_input_path)
        if wav.size(0) > 1: wav = wav.mean(dim=0, keepdim=True)
        target_sr_stt = stt_processor.feature_extractor.sampling_rate
        if sr != target_sr_stt: wav = torchaudio.transforms.Resample(sr, target_sr_stt)(wav)
        audio_array_stt = wav.squeeze().cpu().numpy()
        
        print("STT: Extracting features and transcribing...")
        inputs = stt_processor(audio_array_stt, sampling_rate=target_sr_stt, return_tensors="pt").input_features.to(DEVICE)
        forced_ids = stt_processor.get_decoder_prompt_ids(language="arabic", task="transcribe")
        with torch.no_grad():
            generated_ids = stt_model.generate(inputs, forced_decoder_ids=forced_ids, max_length=448)
        arabic_transcript = stt_processor.decode(generated_ids[0], skip_special_tokens=True).strip()
        print(f"STT Output: {arabic_transcript}")
        if not arabic_transcript: arabic_transcript = "(STT: No speech detected or empty transcript)"
    except Exception as e:
        print(f"STT Error: {e}"); import traceback; traceback.print_exc()
        arabic_transcript = f"STT Error: {e}"


    # TTT Stage
    english_translation = "TTT Error: Processing failed."
    tts_status_message = "" # For appending TTS status to English text
    if arabic_transcript and not arabic_transcript.startswith("STT Error") and not arabic_transcript.startswith("(STT:"):
        try:
            print("TTT: Translating to English...")
            batch = mt_tokenizer(arabic_transcript, return_tensors="pt", padding=True).to(DEVICE)
            with torch.no_grad():
                translated_ids = mt_model.generate(**batch, max_length=512)
            english_translation = mt_tokenizer.batch_decode(translated_ids, skip_special_tokens=True)[0].strip()
            print(f"TTT Output: {english_translation}")
            if not english_translation: english_translation = "(TTT: Empty translation)"
        except Exception as e:
            print(f"TTT Error: {e}"); import traceback; traceback.print_exc()
            english_translation = f"TTT Error: {e}"
    elif arabic_transcript.startswith("STT Error") or arabic_transcript.startswith("(STT:"):
        english_translation = "(Skipped TTT due to STT issue)"
        print(english_translation)
    else: # Should not happen if STT produces some output
        english_translation = "(Skipped TTT: Unknown STT state)"


    # TTS Stage
    synthesized_audio_np = np.array([]).astype(np.float32)
    if english_translation and not english_translation.startswith("TTT Error") and not english_translation.startswith("(Skipped") and not english_translation.startswith("(TTT:"):
        try:
            print("TTS: Synthesizing English speech...")
            sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE) # Ensure input is on TTS_MODEL's device
            
            # Make sure TTS_MODEL is on the correct device before inference
            TTS_MODEL.to(DEVICE) # Redundant if already done, but safe
            TTS_MODEL.eval()     # Ensure eval mode

            generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-20, stop_token_threshold=0.5, with_tqdm=False)
            
            print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
            if generated_mel is not None and generated_mel.numel() > 0 and generated_mel.shape[1] > 0: # Check if time dimension has frames
                mel_for_vocoder = generated_mel.detach().squeeze(0).transpose(0, 1) # [F, T]
                audio_tensor = inverse_mel_spec_to_wav(mel_for_vocoder)
                synthesized_audio_np = audio_tensor.cpu().numpy()
                print(f"TTS: Synthesized audio shape: {synthesized_audio_np.shape}")
            else:
                tts_status_message = "(TTS Error: Empty mel generated)"
                print(tts_status_message)
        except Exception as e:
            print(f"TTS Error: {e}"); import traceback; traceback.print_exc()
            tts_status_message = f"(TTS Error: {e})"
    elif english_translation.startswith("TTT Error") or english_translation.startswith("(Skipped") or english_translation.startswith("(TTT:"):
        tts_status_message = "(Skipped TTS due to TTT/Input issue)"
    else: # Should not happen if TTT produces some output
        tts_status_message = "(Skipped TTS: Unknown TTT state)"
        
    print(f"--- PIPELINE END ---")
    # Combine English translation with any TTS status message
    final_english_display = english_translation
    if tts_status_message:
        final_english_display += f" {tts_status_message}"
        
    return arabic_transcript, final_english_display.strip(), (hp.sr, synthesized_audio_np)
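
# Optional local smoke test (a sketch, not executed by the app; it assumes an Arabic clip named
# "sample_ar.wav" exists next to this script, and that filename is only a placeholder):
#   ar_text, en_text, (out_sr, out_wav) = full_speech_translation_pipeline_gradio("sample_ar.wav")
#   print(ar_text, en_text, out_sr, out_wav.shape)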

# --- Part 5: Gradio Interface ---
print("Setting up Gradio interface...")
demo = gr.Interface(
    fn=full_speech_translation_pipeline_gradio,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Arabic Speech Input"),
    outputs=[
        gr.Textbox(label="Arabic Transcript (STT)"),
        gr.Textbox(label="English Translation & TTS Status"),
        gr.Audio(label="Synthesized English Speech (TTS)", type="numpy") # type="numpy" expects (sr, data)
    ],
    title="Arabic Speech-to-Text -> Translation -> English Text-to-Speech",
    description="Upload an Arabic audio file or record from your microphone. The system will transcribe it to Arabic, translate it to English, and then synthesize the English text into audible speech.",
    allow_flagging="never",
    examples=[["/kaggle/input/testtt/test_audio.ogg"]] if os.path.exists("/kaggle/input/testtt/test_audio.ogg") else None # Optional example
)
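# The three output components above map positionally onto the
# (arabic_transcript, final_english_display, (hp.sr, synthesized_audio_np)) tuple
# returned by full_speech_translation_pipeline_gradio.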

if __name__ == "__main__":
    print("Launching Gradio app...")
    # On Hugging Face Spaces the script is executed as __main__, so this launch runs there as well.
    # For local testing you can pass server_name / server_port to launch() if needed.
    demo.launch(debug=True)  # debug=True prints more detailed Gradio logs