import os
import re
import time
import random
import numpy as np
import pandas as pd  # Kept in case hp or other parts use it; not used directly in the pipeline
import math  # Kept in case hp or other parts use it
# import shutil  # Not needed for Gradio file handling
# import base64  # Not needed for Gradio audio output

# Torch and Audio
import torch
import torch.nn as nn
# import torch.optim as optim  # Not needed for inference
# from torch.utils.data import Dataset, DataLoader  # Not needed for inference
import torch.nn.functional as F
import torchaudio
# import librosa  # Not strictly needed if not plotting in Gradio
# import librosa.display  # Not strictly needed if not plotting in Gradio

# Text and Audio Processing
from unidecode import unidecode
from inflect import engine as inflect_engine_tts  # Renamed to avoid conflict
# import pydub  # Not needed for Gradio audio output
# import soundfile as sf  # Gradio handles audio output directly

# Transformers
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    MarianTokenizer,
    MarianMTModel,
)

# Gradio
import gradio as gr
from huggingface_hub import hf_hub_download  # For downloading models

# --- Configuration & Device ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# --- Hyperparams Class (verbatim from the original training notebook) ---
class Hyperparams:
    seed = 42
    csv_path = "path/to/metadata.csv"  # Not used directly
    wav_path = "path/to/wavs"  # Not used directly
    symbols = [
        'EOS', ' ', '!', ',', '-', '.', ';', '?',
        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
        'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
        'à', 'â', 'è', 'é', 'ê', 'ü', '’', '“', '”'
    ]
    sr = 22050
    n_fft = 2048
    n_stft = int((n_fft // 2) + 1)
    hop_length = int(n_fft / 8.0)
    win_length = int(n_fft / 2.0)
    mel_freq = 128
    max_mel_time = 1024
    power = 2.0  # For spec_transform if used, not directly by inverse_mel
    text_num_embeddings = 2 * len(symbols)
    embedding_size = 256
    encoder_embedding_size = 512
    dim_feedforward = 1024
    postnet_embedding_size = 1024
    encoder_kernel_size = 3
    postnet_kernel_size = 5
    ampl_multiplier = 10.0  # For pow_to_db_mel_spec
    ampl_amin = 1e-10  # For pow_to_db_mel_spec
    db_multiplier = 1.0  # For pow_to_db_mel_spec
    ampl_ref = 1.0  # For db_to_power_mel_spec
    ampl_power = 1.0  # For db_to_power_mel_spec
    max_db = 100  # For pow_to_db_mel_spec
    scale_db = 10  # For pow_to_db_mel_spec & db_to_power_mel_spec


hp = Hyperparams()

# --- TTS Text & Audio Processing (verbatim from the original training notebook) ---
symbol_to_id = {s: i for i, s in enumerate(hp.symbols)}


def text_to_seq(text):
    text = text.lower()
    seq = []
    for s in text:
        _id = symbol_to_id.get(s, None)
        if _id is not None:
            seq.append(_id)
    seq.append(symbol_to_id["EOS"])
    return torch.IntTensor(seq)


mel_inverse_transform = torchaudio.transforms.InverseMelScale(
    n_mels=hp.mel_freq, sample_rate=hp.sr, n_stft=hp.n_stft
).to(DEVICE)
griffnlim_transform = torchaudio.transforms.GriffinLim(
    n_fft=hp.n_fft, win_length=hp.win_length, hop_length=hp.hop_length, power=1.0
).to(DEVICE)  # Explicit power=1.0 for magnitude


def db_to_power_mel_spec(mel_spec):
    mel_spec_scaled = mel_spec * hp.scale_db  # Use a separate name so the input tensor is not reused by mistake
    mel_spec_amp = torchaudio.functional.DB_to_amplitude(mel_spec_scaled, ref=hp.ampl_ref, power=hp.ampl_power)
    return mel_spec_amp
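

# Minimal sanity-check sketch (illustration only; the helper name and the example inputs are
# assumptions, not part of the original app, and nothing in the pipeline calls it). It shows
# the expected shapes produced by text_to_seq and db_to_power_mel_spec defined above.
def _demo_frontend():
    seq = text_to_seq("hello world")            # 1-D IntTensor of symbol ids, ending with the EOS id (0)
    dummy_db_mel = torch.zeros(hp.mel_freq, 8)  # [mel_freq, time] mel spectrogram in the scaled-dB domain
    dummy_amp_mel = db_to_power_mel_spec(dummy_db_mel)  # same shape, converted back to the amplitude domain
    print(f"seq: {tuple(seq.shape)}, amp mel: {tuple(dummy_amp_mel.shape)}")
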

def inverse_mel_spec_to_wav(mel_spec):  # Expects [Freq, Time]
    mel_spec_on_device = mel_spec.to(DEVICE)
    power_mel_spec = db_to_power_mel_spec(mel_spec_on_device)  # This is amplitude
    spectrogram = mel_inverse_transform(power_mel_spec)  # Amplitude mel to linear amplitude
    pseudo_wav = griffnlim_transform(spectrogram)  # Linear amplitude to wav
    return pseudo_wav


def mask_from_seq_lengths(sequence_lengths: torch.Tensor, max_length: int) -> torch.BoolTensor:
    # True marks valid positions: e.g. lengths [2, 4] with max_length 4 give
    # [[True, True, False, False], [True, True, True, True]].
    ones = sequence_lengths.new_ones(sequence_lengths.size(0), max_length)
    range_tensor = ones.cumsum(dim=1)
    return sequence_lengths.unsqueeze(1) >= range_tensor


# --- TransformerTTS Model Architecture (verbatim from the original FastAPI app) ---
class EncoderBlock(nn.Module):  # VERBATIM
    def __init__(self):
        super(EncoderBlock, self).__init__()
        self.norm_1 = nn.LayerNorm(hp.embedding_size)
        self.attn = torch.nn.MultiheadAttention(hp.embedding_size, 4, dropout=0.1, batch_first=True)
        self.dropout_1 = torch.nn.Dropout(0.1)
        self.norm_2 = nn.LayerNorm(hp.embedding_size)
        self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
        self.dropout_2 = torch.nn.Dropout(0.1)
        self.linear_2 = nn.Linear(hp.dim_feedforward, hp.embedding_size)
        self.dropout_3 = torch.nn.Dropout(0.1)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        x_out = self.norm_1(x); x_out, _ = self.attn(x_out, x_out, x_out, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
        x_out = self.dropout_1(x_out); x = x + x_out
        x_out = self.norm_2(x); x_out = self.linear_1(x_out); x_out = F.relu(x_out); x_out = self.dropout_2(x_out)
        x_out = self.linear_2(x_out); x_out = self.dropout_3(x_out)
        x = x + x_out; return x


class DecoderBlock(nn.Module):  # VERBATIM
    def __init__(self):
        super(DecoderBlock, self).__init__()
        self.norm_1 = nn.LayerNorm(hp.embedding_size)
        self.self_attn = torch.nn.MultiheadAttention(hp.embedding_size, 4, dropout=0.1, batch_first=True)
        self.dropout_1 = torch.nn.Dropout(0.1)
        self.norm_2 = nn.LayerNorm(hp.embedding_size)
        self.attn = torch.nn.MultiheadAttention(hp.embedding_size, 4, dropout=0.1, batch_first=True)
        self.dropout_2 = torch.nn.Dropout(0.1)
        self.norm_3 = nn.LayerNorm(hp.embedding_size)
        self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
        self.dropout_3 = torch.nn.Dropout(0.1)
        self.linear_2 = nn.Linear(hp.dim_feedforward, hp.embedding_size)
        self.dropout_4 = torch.nn.Dropout(0.1)

    def forward(self, x, memory, x_attn_mask=None, x_key_padding_mask=None, memory_attn_mask=None, memory_key_padding_mask=None):
        x_out, _ = self.self_attn(x, x, x, attn_mask=x_attn_mask, key_padding_mask=x_key_padding_mask)
        x_out = self.dropout_1(x_out); x = self.norm_1(x + x_out)
        x_out, _ = self.attn(x, memory, memory, attn_mask=memory_attn_mask, key_padding_mask=memory_key_padding_mask)
        x_out = self.dropout_2(x_out); x = self.norm_2(x + x_out)
        x_out = self.linear_1(x); x_out = F.relu(x_out); x_out = self.dropout_3(x_out)
        x_out = self.linear_2(x_out); x_out = self.dropout_4(x_out)
        x = self.norm_3(x + x_out); return x


class EncoderPreNet(nn.Module):  # VERBATIM
    def __init__(self):
        super(EncoderPreNet, self).__init__()
        self.embedding = nn.Embedding(hp.text_num_embeddings, hp.encoder_embedding_size)
        self.linear_1 = nn.Linear(hp.encoder_embedding_size, hp.encoder_embedding_size)
        self.linear_2 = nn.Linear(hp.encoder_embedding_size, hp.embedding_size)
        self.conv_1 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, hp.encoder_kernel_size, 1, int((hp.encoder_kernel_size - 1) / 2), 1)
        self.bn_1 = nn.BatchNorm1d(hp.encoder_embedding_size); self.dropout_1 = nn.Dropout(0.5)
        self.conv_2 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, hp.encoder_kernel_size, 1, int((hp.encoder_kernel_size - 1) / 2), 1)
        self.bn_2 = nn.BatchNorm1d(hp.encoder_embedding_size); self.dropout_2 = nn.Dropout(0.5)
        self.conv_3 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, hp.encoder_kernel_size, 1, int((hp.encoder_kernel_size - 1) / 2), 1)
        self.bn_3 = nn.BatchNorm1d(hp.encoder_embedding_size); self.dropout_3 = nn.Dropout(0.5)

    def forward(self, text):
        x = self.embedding(text); x = self.linear_1(x); x = x.transpose(2, 1)
        x = self.conv_1(x); x = self.bn_1(x); x = F.relu(x); x = self.dropout_1(x)
        x = self.conv_2(x); x = self.bn_2(x); x = F.relu(x); x = self.dropout_2(x)
        x = self.conv_3(x); x = self.bn_3(x); x = F.relu(x); x = self.dropout_3(x)
        x = x.transpose(1, 2); x = self.linear_2(x); return x


class PostNet(nn.Module):  # VERBATIM
    def __init__(self):
        super(PostNet, self).__init__()
        self.conv_1 = nn.Conv1d(hp.mel_freq, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size - 1) / 2), 1)
        self.bn_1 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_1 = nn.Dropout(0.5)
        self.conv_2 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size - 1) / 2), 1)
        self.bn_2 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_2 = nn.Dropout(0.5)
        self.conv_3 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size - 1) / 2), 1)
        self.bn_3 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_3 = nn.Dropout(0.5)
        self.conv_4 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size - 1) / 2), 1)
        self.bn_4 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_4 = nn.Dropout(0.5)
        self.conv_5 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size - 1) / 2), 1)
        self.bn_5 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_5 = nn.Dropout(0.5)
        self.conv_6 = nn.Conv1d(hp.postnet_embedding_size, hp.mel_freq, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size - 1) / 2), 1)
        self.bn_6 = nn.BatchNorm1d(hp.mel_freq); self.dropout_6 = nn.Dropout(0.5)

    def forward(self, x):
        x_orig = x; x = x.transpose(2, 1)
        x = self.conv_1(x); x = self.bn_1(x); x = torch.tanh(x); x = self.dropout_1(x)
        x = self.conv_2(x); x = self.bn_2(x); x = torch.tanh(x); x = self.dropout_2(x)
        x = self.conv_3(x); x = self.bn_3(x); x = torch.tanh(x); x = self.dropout_3(x)
        x = self.conv_4(x); x = self.bn_4(x); x = torch.tanh(x); x = self.dropout_4(x)
        x = self.conv_5(x); x = self.bn_5(x); x = torch.tanh(x); x = self.dropout_5(x)
        x = self.conv_6(x); x = self.bn_6(x); x = self.dropout_6(x)
        x = x.transpose(1, 2); return x  # Original postnet in repo is residual, added in TransformerTTS.forward


class DecoderPreNet(nn.Module):  # VERBATIM
    def __init__(self):
        super(DecoderPreNet, self).__init__()
        self.linear_1 = nn.Linear(hp.mel_freq, hp.embedding_size)
        self.linear_2 = nn.Linear(hp.embedding_size, hp.embedding_size)

    def forward(self, x):
        x = self.linear_1(x); x = F.relu(x)
        x = F.dropout(x, p=0.5, training=True)  # Dropout always on
        x = self.linear_2(x); x = F.relu(x)
        x = F.dropout(x, p=0.5, training=True)  # Dropout always on
        return x
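

# Shape flow through TransformerTTS below, summarized from the code with the hp values
# configured above (embedding_size=256, mel_freq=128): text ids (N, S) -> EncoderPreNet
# -> (N, S, 256) -> 3 EncoderBlocks -> memory (N, S, 256); mel frames (N, T, 128)
# -> DecoderPreNet -> (N, T, 256) -> 3 DecoderBlocks attending over the memory;
# linear_1 then yields the coarse mel (N, T, 128), PostNet adds a residual refinement,
# and linear_2 yields a per-frame stop-token logit.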
class TransformerTTS(nn.Module):  # VERBATIM (init had device=DEVICE; the model is now moved to DEVICE after init)
    def __init__(self):  # device=DEVICE argument removed from here
        super(TransformerTTS, self).__init__()
        self.encoder_prenet = EncoderPreNet()
        self.decoder_prenet = DecoderPreNet()
        self.postnet = PostNet()
        self.pos_encoding = nn.Embedding(hp.max_mel_time, hp.embedding_size)
        self.encoder_block_1 = EncoderBlock(); self.encoder_block_2 = EncoderBlock(); self.encoder_block_3 = EncoderBlock()
        self.decoder_block_1 = DecoderBlock(); self.decoder_block_2 = DecoderBlock(); self.decoder_block_3 = DecoderBlock()
        self.linear_1 = nn.Linear(hp.embedding_size, hp.mel_freq)
        self.linear_2 = nn.Linear(hp.embedding_size, 1)
        self.norm_memory = nn.LayerNorm(hp.embedding_size)
        # Mask attributes are set in the forward pass, as in the original code
        self.src_key_padding_mask = None; self.src_mask = None
        self.tgt_key_padding_mask = None; self.tgt_mask = None; self.memory_mask = None

    def forward(self, text, text_len, mel, mel_len):  # VERBATIM
        N = text.shape[0]; S_text_in = text.shape[1]; TIME_mel_in = mel.shape[1]
        current_device = text.device
        self.src_key_padding_mask = torch.zeros((N, S_text_in), device=current_device).masked_fill(
            ~mask_from_seq_lengths(text_len, max_length=S_text_in), float("-inf"))
        self.src_mask = torch.zeros((S_text_in, S_text_in), device=current_device).masked_fill(
            torch.triu(torch.full((S_text_in, S_text_in), True, dtype=torch.bool, device=current_device), diagonal=1), float("-inf"))
        self.tgt_key_padding_mask = torch.zeros((N, TIME_mel_in), device=current_device).masked_fill(
            ~mask_from_seq_lengths(mel_len, max_length=TIME_mel_in), float("-inf"))
        self.tgt_mask = torch.zeros((TIME_mel_in, TIME_mel_in), device=current_device).masked_fill(
            torch.triu(torch.full((TIME_mel_in, TIME_mel_in), True, device=current_device, dtype=torch.bool), diagonal=1), float("-inf"))
        self.memory_mask = torch.zeros((TIME_mel_in, S_text_in), device=current_device).masked_fill(
            torch.triu(torch.full((TIME_mel_in, S_text_in), True, device=current_device, dtype=torch.bool), diagonal=1), float("-inf"))

        text_x = self.encoder_prenet(text)
        pos_codes = self.pos_encoding(torch.arange(hp.max_mel_time, device=current_device))
        S_text_processed = text_x.shape[1]; text_x = text_x + pos_codes[:S_text_processed]  # Use actual S after prenet
        text_x = self.encoder_block_1(text_x, attn_mask=self.src_mask, key_padding_mask=self.src_key_padding_mask)
        text_x = self.encoder_block_2(text_x, attn_mask=self.src_mask, key_padding_mask=self.src_key_padding_mask)
        text_x = self.encoder_block_3(text_x, attn_mask=self.src_mask, key_padding_mask=self.src_key_padding_mask)
        text_x = self.norm_memory(text_x)

        mel_x = self.decoder_prenet(mel); TIME_mel_processed = mel_x.shape[1]; mel_x = mel_x + pos_codes[:TIME_mel_processed]  # Use actual T after prenet
        mel_x = self.decoder_block_1(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask,
                                     memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
        mel_x = self.decoder_block_2(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask,
                                     memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
        mel_x = self.decoder_block_3(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask,
                                     memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)

        mel_linear = self.linear_1(mel_x)
        postnet_residual_out = self.postnet(mel_linear)  # PostNet output
        mel_postnet = mel_linear + postnet_residual_out  # Add residual
        stop_token = self.linear_2(mel_x)  # (N, TIME, 1)

        # Masking output based on padding:
        # self.tgt_key_padding_mask is -inf for padded positions and 0 for unpadded ones.
        # .ne(0) therefore flags padded positions as True, which is what masked_fill expects.
        bool_mel_padding_mask = self.tgt_key_padding_mask.ne(0)
        mel_linear = mel_linear.masked_fill(bool_mel_padding_mask.unsqueeze(-1).expand_as(mel_linear), 0)
        mel_postnet = mel_postnet.masked_fill(bool_mel_padding_mask.unsqueeze(-1).expand_as(mel_postnet), 0)
        stop_token = stop_token.masked_fill(bool_mel_padding_mask.unsqueeze(-1).expand_as(stop_token), 1e3).squeeze(2)
        return mel_postnet, mel_linear, stop_token

    @torch.no_grad()  # verbatim from the original FastAPI app (with .item() fix)
    def inference(self, text, max_length=800, stop_token_threshold=0.5, with_tqdm=False):  # with_tqdm was False in the pipeline call
        self.eval(); self.train(False)  # As in the original
        model_device = next(self.parameters()).device
        text_on_device = text.to(model_device)  # input text may arrive on CPU
        text_lengths = torch.tensor([text_on_device.shape[1]], dtype=torch.long).to(model_device)  # shape (N,), as expected by mask_from_seq_lengths
        N = 1
        SOS = torch.zeros((N, 1, hp.mel_freq), device=model_device)
        mel_padded = SOS
        mel_lengths = torch.tensor([1], dtype=torch.long).to(model_device)  # shape (N,)
        stop_token_outputs = torch.FloatTensor([]).to(model_device)

        # Use a local tqdm import to avoid conflict if tqdm is imported elsewhere
        from tqdm import tqdm as tqdm_local
        iters = tqdm_local(range(max_length), desc="TTS Inference") if with_tqdm else range(max_length)

        final_mel_postnet_output = SOS  # To store the output from the last forward pass
        for _ in iters:
            # mel_postnet is (N, T_current_input_len, Freq)
            mel_postnet, mel_linear, stop_token = self(text_on_device, text_lengths, mel_padded, mel_lengths)
            final_mel_postnet_output = mel_postnet  # This is the full sequence predicted in this step
            # Append the last frame of mel_postnet as input for the next step
            mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
            mel_lengths = torch.tensor([mel_padded.shape[1]], dtype=torch.long).to(model_device)
            # stop_token is (N, T_current_input_len)
            # Check the stop condition for the last frame of the input sequence
            if (torch.sigmoid(stop_token[:, -1].squeeze()) > stop_token_threshold).item():
                break
            else:
                # stop_token[:, -1:] is (N, 1)
                stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)

        # final_mel_postnet_output includes the SOS frame; strip it.
        if final_mel_postnet_output.shape[1] > 1:  # If more than just the SOS frame
            mel_to_return = final_mel_postnet_output[:, 1:, :]
        else:  # Only SOS was processed, or nothing
            mel_to_return = torch.empty((N, 0, hp.mel_freq), device=model_device)
        if mel_to_return.shape[1] == 0:  # ensure stop_token_outputs is also empty
            stop_token_outputs = torch.empty_like(stop_token_outputs[:, :0])
        return mel_to_return, stop_token_outputs


# --- Part 3: Model Loading (from Hugging Face Hub, verbatim from the original FastAPI app) ---
TTS_MODEL_HUB_ID = "MoHamdyy/transformer-tts-ljspeech"
ASR_HUB_ID = "MoHamdyy/whisper-stt-model"
MARIAN_HUB_ID = "MoHamdyy/marian-ar-en-translation"

print("Loading models from Hugging Face Hub to device:", DEVICE)
TTS_MODEL = None; stt_processor = None; stt_model = None; mt_tokenizer = None; mt_model = None

try:
    print("Loading TTS model...")
    tts_model_path = hf_hub_download(repo_id=TTS_MODEL_HUB_ID, filename="train_SimpleTransfromerTTS.pt")
    state = torch.load(tts_model_path, map_location=DEVICE)
    TTS_MODEL = TransformerTTS().to(DEVICE)  # Create instance, then move to DEVICE
    if "model" in state:
        TTS_MODEL.load_state_dict(state["model"])
    elif "state_dict" in state:
        TTS_MODEL.load_state_dict(state["state_dict"])
    else:
        TTS_MODEL.load_state_dict(state)
    TTS_MODEL.eval()
    print("TTS model loaded successfully.")
except Exception as e:
    print(f"Error loading TTS model: {e}")

try:
    print("Loading STT (Whisper) model...")
    stt_processor = WhisperProcessor.from_pretrained(ASR_HUB_ID)
    stt_model = WhisperForConditionalGeneration.from_pretrained(ASR_HUB_ID).to(DEVICE).eval()
    print("STT model loaded successfully.")
except Exception as e:
    print(f"Error loading STT model: {e}")

try:
    print("Loading TTT (MarianMT) model...")
    mt_tokenizer = MarianTokenizer.from_pretrained(MARIAN_HUB_ID)
    mt_model = MarianMTModel.from_pretrained(MARIAN_HUB_ID).to(DEVICE).eval()
    print("TTT model loaded successfully.")
except Exception as e:
    print(f"Error loading TTT model: {e}")


# --- Part 4: Full Pipeline Function (verbatim from the original FastAPI app, adapted for Gradio output) ---
def full_speech_translation_pipeline_gradio(audio_input_path: str):  # Renamed for clarity
    print(f"--- PIPELINE START: Processing {audio_input_path} ---")

    # Check that all models are loaded
    if not all([stt_processor, stt_model, mt_tokenizer, mt_model, TTS_MODEL]):
        error_msg = "One or more models failed to load. Please check logs."
        print(error_msg)
        return error_msg, error_msg, (hp.sr, np.array([]).astype(np.float32))

    if audio_input_path is None:  # Gradio provides a path for uploaded/recorded audio
        msg = "Error: No audio input received by Gradio."
        print(msg)
        return msg, "", (hp.sr, np.array([]).astype(np.float32))

    if not os.path.exists(audio_input_path):
        # This can happen if Gradio passes a temp path that gets cleaned up too quickly,
        # or if there is an issue with how Gradio handles file paths.
        # For Gradio `type="filepath"`, the path should normally be valid.
        msg = f"Error: Audio file path provided by Gradio does not exist: {audio_input_path}"
        print(msg)
        return msg, "", (hp.sr, np.array([]).astype(np.float32))
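
    # The three stages below run in sequence: Whisper STT (Arabic audio -> Arabic text),
    # MarianMT (Arabic text -> English text), then TransformerTTS + Griffin-Lim
    # (English text -> waveform). Each stage degrades gracefully: on failure, an error
    # string is propagated to the UI and the remaining stages are skipped.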
try: print("STT: Loading and resampling audio...") wav, sr = torchaudio.load(audio_input_path) if wav.size(0) > 1: wav = wav.mean(dim=0, keepdim=True) target_sr_stt = stt_processor.feature_extractor.sampling_rate if sr != target_sr_stt: wav = torchaudio.transforms.Resample(sr, target_sr_stt)(wav) audio_array_stt = wav.squeeze().cpu().numpy() print("STT: Extracting features and transcribing...") inputs = stt_processor(audio_array_stt, sampling_rate=target_sr_stt, return_tensors="pt").input_features.to(DEVICE) forced_ids = stt_processor.get_decoder_prompt_ids(language="arabic", task="transcribe") with torch.no_grad(): generated_ids = stt_model.generate(inputs, forced_decoder_ids=forced_ids, max_length=448) arabic_transcript = stt_processor.decode(generated_ids[0], skip_special_tokens=True).strip() print(f"STT Output: {arabic_transcript}") if not arabic_transcript: arabic_transcript = "(STT: No speech detected or empty transcript)" except Exception as e: print(f"STT Error: {e}"); import traceback; traceback.print_exc() arabic_transcript = f"STT Error: {e}" # TTT Stage english_translation = "TTT Error: Processing failed." tts_status_message = "" # For appending TTS status to English text if arabic_transcript and not arabic_transcript.startswith("STT Error") and not arabic_transcript.startswith("(STT:"): try: print("TTT: Translating to English...") batch = mt_tokenizer(arabic_transcript, return_tensors="pt", padding=True).to(DEVICE) with torch.no_grad(): translated_ids = mt_model.generate(**batch, max_length=512) english_translation = mt_tokenizer.batch_decode(translated_ids, skip_special_tokens=True)[0].strip() print(f"TTT Output: {english_translation}") if not english_translation: english_translation = "(TTT: Empty translation)" except Exception as e: print(f"TTT Error: {e}"); import traceback; traceback.print_exc() english_translation = f"TTT Error: {e}" elif arabic_transcript.startswith("STT Error") or arabic_transcript.startswith("(STT:"): english_translation = "(Skipped TTT due to STT issue)" print(english_translation) else: # Should not happen if STT produces some output english_translation = "(Skipped TTT: Unknown STT state)" # TTS Stage synthesized_audio_np = np.array([]).astype(np.float32) if english_translation and not english_translation.startswith("TTT Error") and not english_translation.startswith("(Skipped") and not english_translation.startswith("(TTT:"): try: print("TTS: Synthesizing English speech...") sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE) # Ensure input is on TTS_MODEL's device # Make sure TTS_MODEL is on the correct device before inference TTS_MODEL.to(DEVICE) # Redundant if already done, but safe TTS_MODEL.eval() # Ensure eval mode generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-20, stop_token_threshold=0.5, with_tqdm=False) print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}") if generated_mel is not None and generated_mel.numel() > 0 and generated_mel.shape[1] > 0: # Check if time dimension has frames mel_for_vocoder = generated_mel.detach().squeeze(0).transpose(0, 1) # [F, T] audio_tensor = inverse_mel_spec_to_wav(mel_for_vocoder) synthesized_audio_np = audio_tensor.cpu().numpy() print(f"TTS: Synthesized audio shape: {synthesized_audio_np.shape}") else: tts_status_message = "(TTS Error: Empty mel generated)" print(tts_status_message) except Exception as e: print(f"TTS Error: {e}"); import traceback; traceback.print_exc() tts_status_message = f"(TTS Error: {e})" elif 
english_translation.startswith("TTT Error") or english_translation.startswith("(Skipped") or english_translation.startswith("(TTT:"): tts_status_message = "(Skipped TTS due to TTT/Input issue)" else: # Should not happen if TTT produces some output tts_status_message = "(Skipped TTS: Unknown TTT state)" print(f"--- PIPELINE END ---") # Combine English translation with any TTS status message final_english_display = english_translation if tts_status_message: final_english_display += f" {tts_status_message}" return arabic_transcript, final_english_display.strip(), (hp.sr, synthesized_audio_np) # --- Part 5: Gradio Interface --- print("Setting up Gradio interface...") demo = gr.Interface( fn=full_speech_translation_pipeline_gradio, inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Arabic Speech Input"), outputs=[ gr.Textbox(label="Arabic Transcript (STT)"), gr.Textbox(label="English Translation & TTS Status"), gr.Audio(label="Synthesized English Speech (TTS)", type="numpy") # type="numpy" expects (sr, data) ], title="Arabic Speech-to-Text -> Translation -> English Text-to-Speech", description="Upload an Arabic audio file or record from your microphone. The system will transcribe it to Arabic, translate it to English, and then synthesize the English text into audible speech.", allow_flagging="never", examples=[["/kaggle/input/testtt/test_audio.ogg"]] if os.path.exists("/kaggle/input/testtt/test_audio.ogg") else None # Optional example ) demo.launch(debug=True) # if __name__ == "__main__": # print("Launching Gradio app...") # # When running on Hugging Face Spaces, HF handles the launch. # # For local testing, you might need a specific host/port. # # HF Spaces will look for a `demo.launch()` or `iface.launch()` # # demo.launch(debug=True) # debug=True for more detailed Gradio logs
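
# Optional local smoke test (a sketch, not part of the original app; left commented out because
# demo.launch() above blocks). "sample_arabic.wav" is a hypothetical path and soundfile must be
# installed separately; it simply exercises the pipeline function without the Gradio UI.
# if __name__ == "__main__":
#     import soundfile as sf
#     ar_text, en_text, (sr_out, audio_np) = full_speech_translation_pipeline_gradio("sample_arabic.wav")
#     print("AR:", ar_text)
#     print("EN:", en_text)
#     if audio_np.size:
#         sf.write("synthesized_english.wav", audio_np, sr_out)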