import os
import re
import time
import random
import numpy as np
import pandas as pd # Keep if hp or other parts use it, though not directly in pipeline
import math # Keep if hp or other parts use it
# import shutil # Not needed for Gradio file handling
# import base64 # Not needed for Gradio audio output
# Torch and Audio
import torch
import torch.nn as nn
# import torch.optim as optim # Not needed for inference
# from torch.utils.data import Dataset, DataLoader # Not needed for inference
import torch.nn.functional as F
import torchaudio
# import librosa # Not strictly needed if not plotting in Gradio
# import librosa.display # Not strictly needed if not plotting in Gradio
# Text and Audio Processing
from unidecode import unidecode
from inflect import engine as inflect_engine_tts # Renamed to avoid conflict
# import pydub # Not needed for Gradio audio output
# import soundfile as sf # Gradio handles audio output directly
# Transformers
from transformers import (
WhisperProcessor, WhisperForConditionalGeneration,
MarianTokenizer, MarianMTModel,
)
# Gradio
import gradio as gr
from huggingface_hub import hf_hub_download # For downloading models
# --- Configuration & Device ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# --- Hyperparams Class (verbatim from the training notebook) ---
class Hyperparams:
seed = 42
csv_path = "path/to/metadata.csv" # Not used directly
wav_path = "path/to/wavs" # Not used directly
symbols = [
'EOS', ' ', '!', ',', '-', '.', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f',
'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
't', 'u', 'v', 'w', 'x', 'y', 'z', 'à', 'â', 'è', 'é', 'ê', 'ü',
'’', '“', '”'
]
sr = 22050
n_fft = 2048
n_stft = int((n_fft//2) + 1)
hop_length = int(n_fft/8.0)
win_length = int(n_fft/2.0)
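    # Derived STFT settings: hop = 2048/8 = 256 samples (~11.6 ms at 22.05 kHz),
    # win = 2048/2 = 1024 samples.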
mel_freq = 128
max_mel_time = 1024
power = 2.0 # For spec_transform if used, not directly by inverse_mel
text_num_embeddings = 2*len(symbols)
embedding_size = 256
encoder_embedding_size = 512
dim_feedforward = 1024
postnet_embedding_size = 1024
encoder_kernel_size = 3
postnet_kernel_size = 5
ampl_multiplier = 10.0 # For pow_to_db_mel_spec
ampl_amin = 1e-10 # For pow_to_db_mel_spec
db_multiplier = 1.0 # For pow_to_db_mel_spec
ampl_ref = 1.0 # For db_to_power_mel_spec
ampl_power = 1.0 # For db_to_power_mel_spec
max_db = 100 # For pow_to_db_mel_spec
scale_db = 10 # For pow_to_db_mel_spec & db_to_power_mel_spec
hp = Hyperparams()
# --- TTS Text & Audio Processing (verbatim from the training notebook) ---
symbol_to_id = {s: i for i, s in enumerate(hp.symbols)}
def text_to_seq(text):
text = text.lower()
seq = []
for s in text:
_id = symbol_to_id.get(s, None)
if _id is not None:
seq.append(_id)
seq.append(symbol_to_id["EOS"])
return torch.IntTensor(seq)
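# Mel-to-waveform chain used as a simple vocoder: scaled-dB mel -> amplitude mel
# (db_to_power_mel_spec) -> linear-frequency spectrogram (InverseMelScale) ->
# waveform via Griffin-Lim phase reconstruction.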
mel_inverse_transform = torchaudio.transforms.InverseMelScale(n_mels=hp.mel_freq, sample_rate=hp.sr, n_stft=hp.n_stft).to(DEVICE)
griffin_lim_transform = torchaudio.transforms.GriffinLim(n_fft=hp.n_fft, win_length=hp.win_length, hop_length=hp.hop_length, power=1.0).to(DEVICE) # Explicit power=1.0 for magnitude
def db_to_power_mel_spec(mel_spec):
    mel_spec_scaled = mel_spec * hp.scale_db # rescale the compressed dB values before converting back to amplitude
mel_spec_amp = torchaudio.functional.DB_to_amplitude(mel_spec_scaled, ref=hp.ampl_ref, power=hp.ampl_power)
return mel_spec_amp
def inverse_mel_spec_to_wav(mel_spec): # Expects [Freq, Time]
mel_spec_on_device = mel_spec.to(DEVICE)
power_mel_spec = db_to_power_mel_spec(mel_spec_on_device) # This is amplitude
spectrogram = mel_inverse_transform(power_mel_spec) # Amplitude mel to linear amplitude
    pseudo_wav = griffin_lim_transform(spectrogram) # Linear amplitude to wav
return pseudo_wav
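# Boolean mask built from per-example lengths: True for valid positions, False for
# padding, with shape [batch, max_length].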
def mask_from_seq_lengths(sequence_lengths: torch.Tensor, max_length: int) -> torch.BoolTensor:
ones = sequence_lengths.new_ones(sequence_lengths.size(0), max_length)
range_tensor = ones.cumsum(dim=1)
return sequence_lengths.unsqueeze(1) >= range_tensor
# --- TransformerTTS Model Architecture (verbatim from the FastAPI version) ---
class EncoderBlock(nn.Module): # VERBATIM
def __init__(self):
super(EncoderBlock, self).__init__()
self.norm_1 = nn.LayerNorm(hp.embedding_size)
self.attn = torch.nn.MultiheadAttention(hp.embedding_size, 4, dropout=0.1, batch_first=True)
self.dropout_1 = torch.nn.Dropout(0.1)
self.norm_2 = nn.LayerNorm(hp.embedding_size)
self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
self.dropout_2 = torch.nn.Dropout(0.1)
self.linear_2 = nn.Linear(hp.dim_feedforward, hp.embedding_size)
self.dropout_3 = torch.nn.Dropout(0.1)
def forward(self, x, attn_mask=None, key_padding_mask=None):
x_out = self.norm_1(x); x_out, _ = self.attn(x_out, x_out, x_out, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
x_out = self.dropout_1(x_out); x = x + x_out
x_out = self.norm_2(x) ; x_out = self.linear_1(x_out); x_out = F.relu(x_out); x_out = self.dropout_2(x_out)
x_out = self.linear_2(x_out); x_out = self.dropout_3(x_out)
x = x + x_out; return x
class DecoderBlock(nn.Module): # VERBATIM
def __init__(self):
super(DecoderBlock, self).__init__()
self.norm_1 = nn.LayerNorm(hp.embedding_size)
self.self_attn = torch.nn.MultiheadAttention(hp.embedding_size, 4, dropout=0.1, batch_first=True)
self.dropout_1 = torch.nn.Dropout(0.1)
self.norm_2 = nn.LayerNorm(hp.embedding_size)
self.attn = torch.nn.MultiheadAttention(hp.embedding_size, 4, dropout=0.1, batch_first=True)
self.dropout_2 = torch.nn.Dropout(0.1)
self.norm_3 = nn.LayerNorm(hp.embedding_size)
self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
self.dropout_3 = torch.nn.Dropout(0.1)
self.linear_2 = nn.Linear(hp.dim_feedforward, hp.embedding_size)
self.dropout_4 = torch.nn.Dropout(0.1)
def forward(self, x, memory, x_attn_mask=None, x_key_padding_mask=None, memory_attn_mask=None, memory_key_padding_mask=None):
x_out, _ = self.self_attn(x, x, x, attn_mask=x_attn_mask, key_padding_mask=x_key_padding_mask)
x_out = self.dropout_1(x_out); x = self.norm_1(x + x_out)
x_out, _ = self.attn(x, memory, memory, attn_mask=memory_attn_mask, key_padding_mask=memory_key_padding_mask)
x_out = self.dropout_2(x_out); x = self.norm_2(x + x_out)
x_out = self.linear_1(x); x_out = F.relu(x_out); x_out = self.dropout_3(x_out)
x_out = self.linear_2(x_out); x_out = self.dropout_4(x_out)
x = self.norm_3(x + x_out); return x
class EncoderPreNet(nn.Module): # VERBATIM
def __init__(self):
super(EncoderPreNet, self).__init__()
self.embedding = nn.Embedding(hp.text_num_embeddings, hp.encoder_embedding_size)
self.linear_1 = nn.Linear(hp.encoder_embedding_size, hp.encoder_embedding_size)
self.linear_2 = nn.Linear(hp.encoder_embedding_size, hp.embedding_size)
self.conv_1 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, hp.encoder_kernel_size, 1, int((hp.encoder_kernel_size-1)/2),1)
self.bn_1 = nn.BatchNorm1d(hp.encoder_embedding_size); self.dropout_1 = nn.Dropout(0.5)
self.conv_2 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, hp.encoder_kernel_size, 1, int((hp.encoder_kernel_size-1)/2),1)
self.bn_2 = nn.BatchNorm1d(hp.encoder_embedding_size); self.dropout_2 = nn.Dropout(0.5)
self.conv_3 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, hp.encoder_kernel_size, 1, int((hp.encoder_kernel_size-1)/2),1)
self.bn_3 = nn.BatchNorm1d(hp.encoder_embedding_size); self.dropout_3 = nn.Dropout(0.5)
def forward(self, text):
x = self.embedding(text); x = self.linear_1(x); x = x.transpose(2,1)
x = self.conv_1(x); x = self.bn_1(x); x = F.relu(x); x = self.dropout_1(x)
x = self.conv_2(x); x = self.bn_2(x); x = F.relu(x); x = self.dropout_2(x)
x = self.conv_3(x); x = self.bn_3(x); x = F.relu(x); x = self.dropout_3(x)
x = x.transpose(1,2); x = self.linear_2(x); return x
class PostNet(nn.Module): # VERBATIM
def __init__(self):
super(PostNet, self).__init__()
self.conv_1 = nn.Conv1d(hp.mel_freq, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
self.bn_1 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_1 = nn.Dropout(0.5)
self.conv_2 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
self.bn_2 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_2 = nn.Dropout(0.5)
self.conv_3 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
self.bn_3 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_3 = nn.Dropout(0.5)
self.conv_4 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
self.bn_4 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_4 = nn.Dropout(0.5)
self.conv_5 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
self.bn_5 = nn.BatchNorm1d(hp.postnet_embedding_size); self.dropout_5 = nn.Dropout(0.5)
self.conv_6 = nn.Conv1d(hp.postnet_embedding_size, hp.mel_freq, hp.postnet_kernel_size, 1, int((hp.postnet_kernel_size-1)/2),1)
self.bn_6 = nn.BatchNorm1d(hp.mel_freq); self.dropout_6 = nn.Dropout(0.5)
def forward(self, x):
x_orig = x; x = x.transpose(2,1)
x = self.conv_1(x); x = self.bn_1(x); x = torch.tanh(x); x = self.dropout_1(x)
x = self.conv_2(x); x = self.bn_2(x); x = torch.tanh(x); x = self.dropout_2(x)
x = self.conv_3(x); x = self.bn_3(x); x = torch.tanh(x); x = self.dropout_3(x)
x = self.conv_4(x); x = self.bn_4(x); x = torch.tanh(x); x = self.dropout_4(x)
x = self.conv_5(x); x = self.bn_5(x); x = torch.tanh(x); x = self.dropout_5(x)
x = self.conv_6(x); x = self.bn_6(x); x = self.dropout_6(x)
x = x.transpose(1,2); return x # Original postnet in repo is residual, added in TransformerTTS.forward
class DecoderPreNet(nn.Module): # VERBATIM
def __init__(self):
super(DecoderPreNet, self).__init__()
self.linear_1 = nn.Linear(hp.mel_freq, hp.embedding_size)
self.linear_2 = nn.Linear(hp.embedding_size, hp.embedding_size)
def forward(self, x):
x = self.linear_1(x); x = F.relu(x)
x = F.dropout(x, p=0.5, training=True) # Dropout always on
x = self.linear_2(x); x = F.relu(x)
x = F.dropout(x, p=0.5, training=True) # Dropout always on
return x
class TransformerTTS(nn.Module): # VERBATIM (init had device=DEVICE, now model is moved after init)
def __init__(self): # Removed device=DEVICE from here
super(TransformerTTS, self).__init__()
self.encoder_prenet = EncoderPreNet()
self.decoder_prenet = DecoderPreNet()
self.postnet = PostNet()
self.pos_encoding = nn.Embedding(hp.max_mel_time, hp.embedding_size)
self.encoder_block_1 = EncoderBlock(); self.encoder_block_2 = EncoderBlock(); self.encoder_block_3 = EncoderBlock()
self.decoder_block_1 = DecoderBlock(); self.decoder_block_2 = DecoderBlock(); self.decoder_block_3 = DecoderBlock()
self.linear_1 = nn.Linear(hp.embedding_size, hp.mel_freq)
self.linear_2 = nn.Linear(hp.embedding_size, 1)
self.norm_memory = nn.LayerNorm(hp.embedding_size)
# Mask attributes will be set in forward pass, as per your code
self.src_key_padding_mask = None; self.src_mask = None
self.tgt_key_padding_mask = None; self.tgt_mask = None; self.memory_mask = None
def forward(self, text, text_len, mel, mel_len): # VERBATIM
N = text.shape[0]; S_text_in = text.shape[1]; TIME_mel_in = mel.shape[1]
current_device = text.device
self.src_key_padding_mask = torch.zeros((N, S_text_in), device=current_device).masked_fill(~mask_from_seq_lengths(text_len, max_length=S_text_in), float("-inf"))
self.src_mask = torch.zeros((S_text_in, S_text_in), device=current_device).masked_fill(torch.triu(torch.full((S_text_in, S_text_in), True, dtype=torch.bool, device=current_device), diagonal=1), float("-inf"))
self.tgt_key_padding_mask = torch.zeros((N, TIME_mel_in), device=current_device).masked_fill(~mask_from_seq_lengths(mel_len, max_length=TIME_mel_in), float("-inf"))
self.tgt_mask = torch.zeros((TIME_mel_in, TIME_mel_in), device=current_device).masked_fill(torch.triu(torch.full((TIME_mel_in, TIME_mel_in), True, device=current_device, dtype=torch.bool), diagonal=1), float("-inf"))
self.memory_mask = torch.zeros((TIME_mel_in, S_text_in), device=current_device).masked_fill(torch.triu(torch.full((TIME_mel_in, S_text_in), True, device=current_device, dtype=torch.bool), diagonal=1), float("-inf"))
text_x = self.encoder_prenet(text)
pos_codes = self.pos_encoding(torch.arange(hp.max_mel_time, device=current_device))
S_text_processed = text_x.shape[1]; text_x = text_x + pos_codes[:S_text_processed] # Use actual S after prenet
text_x = self.encoder_block_1(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
text_x = self.encoder_block_2(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
text_x = self.encoder_block_3(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
text_x = self.norm_memory(text_x)
mel_x = self.decoder_prenet(mel);
TIME_mel_processed = mel_x.shape[1]; mel_x = mel_x + pos_codes[:TIME_mel_processed] # Use actual T after prenet
mel_x = self.decoder_block_1(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
mel_x = self.decoder_block_2(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
mel_x = self.decoder_block_3(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
mel_linear = self.linear_1(mel_x)
postnet_residual_out = self.postnet(mel_linear) # PostNet output
mel_postnet = mel_linear + postnet_residual_out # Add residual
stop_token = self.linear_2(mel_x) # (N, TIME, 1)
# Masking output based on padding
# self.tgt_key_padding_mask is -inf for padded, 0 for unpadded.
# .ne(0) makes it True for padded, False for unpadded. This is correct for masked_fill.
bool_mel_padding_mask = self.tgt_key_padding_mask.ne(0)
mel_linear = mel_linear.masked_fill(bool_mel_padding_mask.unsqueeze(-1).expand_as(mel_linear), 0)
mel_postnet = mel_postnet.masked_fill(bool_mel_padding_mask.unsqueeze(-1).expand_as(mel_postnet), 0)
stop_token = stop_token.masked_fill(bool_mel_padding_mask.unsqueeze(-1).expand_as(stop_token), 1e3).squeeze(2)
return mel_postnet, mel_linear, stop_token
@torch.no_grad() # VERBATIM from your FastAPI code (with .item() fix)
def inference(self, text, max_length=800, stop_token_threshold=0.5, with_tqdm=False): # with_tqdm was False in pipeline call
self.eval(); self.train(False) # As per your original
model_device = next(self.parameters()).device
text_on_device = text.to(model_device)
        text_lengths = torch.tensor([text_on_device.shape[1]], dtype=torch.long).to(model_device) # 1-D [batch] lengths, matching forward()'s training-time contract
N = 1
SOS = torch.zeros((N, 1, hp.mel_freq), device=model_device)
mel_padded = SOS
        mel_lengths = torch.tensor([1], dtype=torch.long).to(model_device) # 1-D [batch] lengths
stop_token_outputs = torch.FloatTensor([]).to(model_device) # text.device might be CPU if text wasn't on device
# Use local tqdm to avoid conflict if tqdm is imported elsewhere
from tqdm import tqdm as tqdm_local
iters = tqdm_local(range(max_length), desc="TTS Inference") if with_tqdm else range(max_length)
final_mel_postnet_output = SOS # To store the output from the last forward pass
for _ in iters:
# mel_postnet is (N, T_current_input_len, Freq)
mel_postnet, mel_linear, stop_token = self(text_on_device, text_lengths, mel_padded, mel_lengths)
final_mel_postnet_output = mel_postnet # This is the full sequence predicted in this step
# Append last frame of mel_postnet for next input step
mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
            mel_lengths = torch.tensor([mel_padded.shape[1]], dtype=torch.long).to(model_device)
# stop_token is (N, T_current_input_len)
# Check stop condition for the last frame of the input sequence
if (torch.sigmoid(stop_token[:, -1].squeeze()) > stop_token_threshold).item():
break
else:
# stop_token[:, -1:] is (N, 1)
stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
# final_mel_postnet_output contains SOS. Strip it.
if final_mel_postnet_output.shape[1] > 1: # If more than just SOS frame
mel_to_return = final_mel_postnet_output[:, 1:, :]
else: # Only SOS was processed, or nothing
mel_to_return = torch.empty((N, 0, hp.mel_freq), device=model_device)
if mel_to_return.shape[1] == 0: # ensure stop_token_outputs is also empty
stop_token_outputs = torch.empty_like(stop_token_outputs[:,:0])
return mel_to_return, stop_token_outputs
# --- Part 3: Model Loading (from Hugging Face Hub, verbatim from the FastAPI version) ---
TTS_MODEL_HUB_ID = "MoHamdyy/transformer-tts-ljspeech"
ASR_HUB_ID = "MoHamdyy/whisper-stt-model"
MARIAN_HUB_ID = "MoHamdyy/marian-ar-en-translation"
print("Loading models from Hugging Face Hub to device:", DEVICE)
TTS_MODEL = None; stt_processor = None; stt_model = None; mt_tokenizer = None; mt_model = None
try:
print("Loading TTS model...")
tts_model_path = hf_hub_download(repo_id=TTS_MODEL_HUB_ID, filename="train_SimpleTransfromerTTS.pt")
state = torch.load(tts_model_path, map_location=DEVICE)
TTS_MODEL = TransformerTTS().to(DEVICE) # Create instance then move to DEVICE
if "model" in state: TTS_MODEL.load_state_dict(state["model"])
elif "state_dict" in state: TTS_MODEL.load_state_dict(state["state_dict"])
else: TTS_MODEL.load_state_dict(state)
TTS_MODEL.eval()
print("TTS model loaded successfully.")
except Exception as e: print(f"Error loading TTS model: {e}")
try:
print("Loading STT (Whisper) model...")
stt_processor = WhisperProcessor.from_pretrained(ASR_HUB_ID)
stt_model = WhisperForConditionalGeneration.from_pretrained(ASR_HUB_ID).to(DEVICE).eval()
print("STT model loaded successfully.")
except Exception as e: print(f"Error loading STT model: {e}")
try:
print("Loading TTT (MarianMT) model...")
mt_tokenizer = MarianTokenizer.from_pretrained(MARIAN_HUB_ID)
mt_model = MarianMTModel.from_pretrained(MARIAN_HUB_ID).to(DEVICE).eval()
print("TTT model loaded successfully.")
except Exception as e: print(f"Error loading TTT model: {e}")
# --- Part 4: Full Pipeline Function (verbatim from the FastAPI version, adapted for Gradio output) ---
def full_speech_translation_pipeline_gradio(audio_input_path: str): # Renamed for clarity
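    """Run the STT -> translation -> TTS pipeline on one audio file.

    Returns (arabic_transcript, english_text_with_tts_status, (sample_rate, waveform)),
    in the order expected by the three Gradio output components declared below.
    """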
print(f"--- PIPELINE START: Processing {audio_input_path} ---")
# Check if models are loaded
if not all([stt_processor, stt_model, mt_tokenizer, mt_model, TTS_MODEL]):
error_msg = "One or more models failed to load. Please check logs."
print(error_msg)
return error_msg, error_msg, (hp.sr, np.array([]).astype(np.float32))
if audio_input_path is None: # Gradio provides a path for uploaded/recorded audio
msg = "Error: No audio input received by Gradio."
print(msg)
return msg, "", (hp.sr, np.array([]).astype(np.float32))
if not os.path.exists(audio_input_path):
# This case might happen if Gradio passes a temp path that gets cleaned up too quickly,
# or if there's an issue with how Gradio handles file paths.
# For Gradio `type="filepath"`, the path should be valid.
msg = f"Error: Audio file path provided by Gradio does not exist: {audio_input_path}"
print(msg)
return msg, "", (hp.sr, np.array([]).astype(np.float32))
# STT Stage
arabic_transcript = "STT Error: Processing failed."
try:
print("STT: Loading and resampling audio...")
wav, sr = torchaudio.load(audio_input_path)
if wav.size(0) > 1: wav = wav.mean(dim=0, keepdim=True)
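        # Whisper's feature extractor expects mono audio at its native rate (16 kHz),
        # so the audio is downmixed above and resampled below when needed.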
target_sr_stt = stt_processor.feature_extractor.sampling_rate
if sr != target_sr_stt: wav = torchaudio.transforms.Resample(sr, target_sr_stt)(wav)
audio_array_stt = wav.squeeze().cpu().numpy()
print("STT: Extracting features and transcribing...")
inputs = stt_processor(audio_array_stt, sampling_rate=target_sr_stt, return_tensors="pt").input_features.to(DEVICE)
forced_ids = stt_processor.get_decoder_prompt_ids(language="arabic", task="transcribe")
with torch.no_grad():
generated_ids = stt_model.generate(inputs, forced_decoder_ids=forced_ids, max_length=448)
arabic_transcript = stt_processor.decode(generated_ids[0], skip_special_tokens=True).strip()
print(f"STT Output: {arabic_transcript}")
if not arabic_transcript: arabic_transcript = "(STT: No speech detected or empty transcript)"
except Exception as e:
print(f"STT Error: {e}"); import traceback; traceback.print_exc()
arabic_transcript = f"STT Error: {e}"
# TTT Stage
english_translation = "TTT Error: Processing failed."
tts_status_message = "" # For appending TTS status to English text
if arabic_transcript and not arabic_transcript.startswith("STT Error") and not arabic_transcript.startswith("(STT:"):
try:
print("TTT: Translating to English...")
batch = mt_tokenizer(arabic_transcript, return_tensors="pt", padding=True).to(DEVICE)
with torch.no_grad():
translated_ids = mt_model.generate(**batch, max_length=512)
english_translation = mt_tokenizer.batch_decode(translated_ids, skip_special_tokens=True)[0].strip()
print(f"TTT Output: {english_translation}")
if not english_translation: english_translation = "(TTT: Empty translation)"
except Exception as e:
print(f"TTT Error: {e}"); import traceback; traceback.print_exc()
english_translation = f"TTT Error: {e}"
elif arabic_transcript.startswith("STT Error") or arabic_transcript.startswith("(STT:"):
english_translation = "(Skipped TTT due to STT issue)"
print(english_translation)
else: # Should not happen if STT produces some output
english_translation = "(Skipped TTT: Unknown STT state)"
# TTS Stage
synthesized_audio_np = np.array([]).astype(np.float32)
if english_translation and not english_translation.startswith("TTT Error") and not english_translation.startswith("(Skipped") and not english_translation.startswith("(TTT:"):
try:
print("TTS: Synthesizing English speech...")
sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE) # Ensure input is on TTS_MODEL's device
# Make sure TTS_MODEL is on the correct device before inference
TTS_MODEL.to(DEVICE) # Redundant if already done, but safe
TTS_MODEL.eval() # Ensure eval mode
generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-20, stop_token_threshold=0.5, with_tqdm=False)
print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
if generated_mel is not None and generated_mel.numel() > 0 and generated_mel.shape[1] > 0: # Check if time dimension has frames
mel_for_vocoder = generated_mel.detach().squeeze(0).transpose(0, 1) # [F, T]
audio_tensor = inverse_mel_spec_to_wav(mel_for_vocoder)
synthesized_audio_np = audio_tensor.cpu().numpy()
print(f"TTS: Synthesized audio shape: {synthesized_audio_np.shape}")
else:
tts_status_message = "(TTS Error: Empty mel generated)"
print(tts_status_message)
except Exception as e:
print(f"TTS Error: {e}"); import traceback; traceback.print_exc()
tts_status_message = f"(TTS Error: {e})"
elif english_translation.startswith("TTT Error") or english_translation.startswith("(Skipped") or english_translation.startswith("(TTT:"):
tts_status_message = "(Skipped TTS due to TTT/Input issue)"
else: # Should not happen if TTT produces some output
tts_status_message = "(Skipped TTS: Unknown TTT state)"
print(f"--- PIPELINE END ---")
# Combine English translation with any TTS status message
final_english_display = english_translation
if tts_status_message:
final_english_display += f" {tts_status_message}"
return arabic_transcript, final_english_display.strip(), (hp.sr, synthesized_audio_np)
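# The three returned values map one-to-one onto the three output components of the
# Gradio interface defined below.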
# --- Part 5: Gradio Interface ---
print("Setting up Gradio interface...")
demo = gr.Interface(
fn=full_speech_translation_pipeline_gradio,
inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Arabic Speech Input"),
outputs=[
gr.Textbox(label="Arabic Transcript (STT)"),
gr.Textbox(label="English Translation & TTS Status"),
gr.Audio(label="Synthesized English Speech (TTS)", type="numpy") # type="numpy" expects (sr, data)
],
title="Arabic Speech-to-Text -> Translation -> English Text-to-Speech",
description="Upload an Arabic audio file or record from your microphone. The system will transcribe it to Arabic, translate it to English, and then synthesize the English text into audible speech.",
allow_flagging="never",
examples=[["/kaggle/input/testtt/test_audio.ogg"]] if os.path.exists("/kaggle/input/testtt/test_audio.ogg") else None # Optional example
)
if __name__ == "__main__":
    print("Launching Gradio app...")
    # debug=True enables more detailed Gradio logs. On Hugging Face Spaces the
    # Gradio runtime serves `demo` automatically; this guard covers local runs,
    # where a specific host/port can also be passed to launch() if needed.
    demo.launch(debug=True)