# Hugging Face Spaces app (status: Running)
import functools
import os
import tempfile

import gradio as gr
import torch
import whisper
from pydub import AudioSegment
from transformers import pipeline
# Mapping of "Select Model" dropdown labels to openai-whisper checkpoint sizes
# (the value is what gets passed to whisper.load_model). Key order is the
# order in which the choices appear in the model dropdown.
MODELS = {
    "Tiny (Fastest)": "tiny",
    "Base (Faster)": "base",
    "Small (Balanced)": "small",
    "Medium (Accurate)": "medium",
    "Large (Most Accurate)": "large",
}
# Languages that are transcribed with a fine-tuned Hugging Face Whisper
# checkpoint instead of the stock openai-whisper models.
#   "model":    Hugging Face model id handed to transformers.pipeline
#   "language": language code used to force the decoder's output language
# Keys are compared against the language display name chosen in the UI, so
# they must match keys of LANGUAGE_NAME_TO_CODE.
FINE_TUNED_MODELS = {
    "Tamil": {
        "model": "vasista22/whisper-tamil-medium",
        "language": "ta",
    },
    # Add more fine-tuned models for other languages here.
}
# Mapping of language display names (the "Select Language" dropdown choices,
# in this order) to the language codes passed to Whisper's transcribe().
# "Auto Detect" is a sentinel, not a code: transcribe_audio checks for it
# before any lookup and lets Whisper detect the language itself.
LANGUAGE_NAME_TO_CODE = {
    "Auto Detect": "Auto Detect",
    "English": "en",
    "Chinese": "zh",
    "German": "de",
    "Spanish": "es",
    "Russian": "ru",
    "Korean": "ko",
    "French": "fr",
    "Japanese": "ja",
    "Portuguese": "pt",
    "Turkish": "tr",
    "Polish": "pl",
    "Catalan": "ca",
    "Dutch": "nl",
    "Arabic": "ar",
    "Swedish": "sv",
    "Italian": "it",
    "Indonesian": "id",
    "Hindi": "hi",
    "Finnish": "fi",
    "Vietnamese": "vi",
    "Hebrew": "he",
    "Ukrainian": "uk",
    "Greek": "el",
    "Malay": "ms",
    "Czech": "cs",
    "Romanian": "ro",
    "Danish": "da",
    "Hungarian": "hu",
    "Tamil": "ta",
    "Norwegian": "no",
    "Thai": "th",
    "Urdu": "ur",
    "Croatian": "hr",
    "Bulgarian": "bg",
    "Lithuanian": "lt",
    "Latin": "la",
    "Maori": "mi",
    "Malayalam": "ml",
    "Welsh": "cy",
    "Slovak": "sk",
    "Telugu": "te",
    "Persian": "fa",
    "Latvian": "lv",
    "Bengali": "bn",
    "Serbian": "sr",
    "Azerbaijani": "az",
    "Slovenian": "sl",
    "Kannada": "kn",
    "Estonian": "et",
    "Macedonian": "mk",
    "Breton": "br",
    "Basque": "eu",
    "Icelandic": "is",
    "Armenian": "hy",
    "Nepali": "ne",
    "Mongolian": "mn",
    "Bosnian": "bs",
    "Kazakh": "kk",
    "Albanian": "sq",
    "Swahili": "sw",
    "Galician": "gl",
    "Marathi": "mr",
    "Punjabi": "pa",
    "Sinhala": "si",  # Sinhala support
    "Khmer": "km",
    "Shona": "sn",
    "Yoruba": "yo",
    "Somali": "so",
    "Afrikaans": "af",
    "Occitan": "oc",
    "Georgian": "ka",
    "Belarusian": "be",
    "Tajik": "tg",
    "Sindhi": "sd",
    "Gujarati": "gu",
    "Amharic": "am",
    "Yiddish": "yi",
    "Lao": "lo",
    "Uzbek": "uz",
    "Faroese": "fo",
    "Haitian Creole": "ht",
    "Pashto": "ps",
    "Turkmen": "tk",
    "Nynorsk": "nn",
    "Maltese": "mt",
    "Sanskrit": "sa",
    "Luxembourgish": "lb",
    "Burmese": "my",
    "Tibetan": "bo",
    "Tagalog": "tl",
    "Malagasy": "mg",
    "Assamese": "as",
    "Tatar": "tt",
    "Hawaiian": "haw",
    "Lingala": "ln",
    "Hausa": "ha",
    "Bashkir": "ba",
    "Javanese": "jw",
    "Sundanese": "su",
}
@functools.lru_cache(maxsize=1)
def _load_whisper_model(size):
    """Load an openai-whisper checkpoint by size name, memoised across calls.

    maxsize=1 keeps only the most recently used checkpoint resident, so
    switching sizes reloads but memory stays bounded.
    """
    return whisper.load_model(size)


@functools.lru_cache(maxsize=1)
def _load_fine_tuned_pipeline(language):
    """Build the HF ASR pipeline for a FINE_TUNED_MODELS entry, memoised."""
    spec = FINE_TUNED_MODELS[language]
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    asr = pipeline(
        task="automatic-speech-recognition",
        model=spec["model"],
        chunk_length_s=30,
        device=device,
    )
    # Force the decoder to transcribe in the model's language rather than
    # letting it pick one.
    asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
        language=spec["language"],
        task="transcribe",
    )
    return asr


def _to_16k_mono_wav(audio_file, dest_path):
    """Re-encode *audio_file* as a 16 kHz mono WAV written to *dest_path*."""
    audio = AudioSegment.from_file(audio_file)
    audio.set_frame_rate(16000).set_channels(1).export(dest_path, format="wav")


def transcribe_audio(audio_file, language="Auto Detect", model_size="Base (Faster)"):
    """Transcribe an uploaded audio file and report the language used.

    The input is re-encoded to a 16 kHz mono WAV inside a private temporary
    directory. That directory is unique per call, so concurrent requests do
    not overwrite each other. It is removed on exit even if transcription
    raises. The WAV is then passed either to the fine-tuned pipeline
    registered for *language* in FINE_TUNED_MODELS or to the openai-whisper
    model selected by *model_size*. Models are cached between calls instead
    of being reloaded per request.

    Args:
        audio_file: Path to the uploaded audio file. None or empty when
            nothing was uploaded.
        language: Display name from LANGUAGE_NAME_TO_CODE, or "Auto Detect".
            Names missing from the mapping fall back to English ("en").
        model_size: Key of MODELS. Ignored for fine-tuned languages. Unknown
            keys fall back to the "base" checkpoint.

    Returns:
        A string of the form "Detected Language: <lang>", a blank line,
        "Transcription:", then the text. <lang> is Whisper's detected code
        for "Auto Detect", the language code for an explicitly chosen
        language, or the display name for a fine-tuned model. If no file
        was uploaded, a prompt asking for one is returned instead.
    """
    if not audio_file:
        return "Please upload an audio file before transcribing."

    with tempfile.TemporaryDirectory() as tmp_dir:
        processed_audio_path = os.path.join(tmp_dir, "processed_audio.wav")
        _to_16k_mono_wav(audio_file, processed_audio_path)

        if language in FINE_TUNED_MODELS:
            result = _load_fine_tuned_pipeline(language)(processed_audio_path)
            detected_language = language
        else:
            model = _load_whisper_model(MODELS.get(model_size, "base"))
            if language == "Auto Detect":
                # No language hint: Whisper detects it and reports it back.
                result = model.transcribe(processed_audio_path, fp16=False)
                detected_language = result.get("language", "unknown")
            else:
                language_code = LANGUAGE_NAME_TO_CODE.get(language, "en")
                result = model.transcribe(
                    processed_audio_path, language=language_code, fp16=False
                )
                detected_language = language_code
        transcription = result["text"]

    return f"Detected Language: {detected_language}\n\nTranscription:\n{transcription}"
def update_model_dropdown(language):
    """Return a model-dropdown update matching the selected *language*.

    For a language in FINE_TUNED_MODELS, the dropdown is locked to a single
    "Fine-Tuned <language> Model" entry, because transcribe_audio ignores
    model_size on that path. For any other language, the Whisper size choices
    are restored with "Base (Faster)" selected.
    """
    if language in FINE_TUNED_MODELS:
        label = f"Fine-Tuned {language} Model"
        # The label must be one of the choices: Gradio validates a dropdown's
        # value against `choices` when it is used as an event input.
        return gr.Dropdown(choices=[label], value=label, interactive=False)
    return gr.Dropdown(choices=list(MODELS.keys()), interactive=True, value="Base (Faster)")


# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Audio Transcription with Fine-Tuned Models")
    with gr.Tab("Transcribe Audio"):
        gr.Markdown("Upload an audio file, select a language (or choose 'Auto Detect'), and choose a model for transcription.")
        transcribe_audio_input = gr.Audio(type="filepath", label="Upload Audio File")
        language_dropdown = gr.Dropdown(
            choices=list(LANGUAGE_NAME_TO_CODE.keys()),  # Full language names
            label="Select Language",
            value="Auto Detect",
        )
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),  # Whisper model options
            label="Select Model",
            value="Base (Faster)",  # Default to the "Base" model
            interactive=True,
        )
        transcribe_output = gr.Textbox(label="Transcription and Detected Language")
        transcribe_button = gr.Button("Transcribe Audio")

        # Lock or unlock the model dropdown whenever the language changes.
        language_dropdown.change(
            update_model_dropdown,
            inputs=language_dropdown,
            outputs=model_dropdown,
        )
        transcribe_button.click(
            transcribe_audio,
            inputs=[transcribe_audio_input, language_dropdown, model_dropdown],
            outputs=transcribe_output,
        )

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()