import gradio as gr
import whisper
import torch
import os
from pydub import AudioSegment, silence
from faster_whisper import WhisperModel
import numpy as np
from scipy.io import wavfile
from scipy.signal import correlate
import tempfile
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Mapping of model names to Whisper model sizes
MODELS = {
    "Tiny (Fastest)": "tiny",
    "Base (Faster)": "base",
    "Small (Balanced)": "small",
    "Medium (Accurate)": "medium",
    "Large (Most Accurate)": "large",
    "Faster Whisper Large v3": "Systran/faster-whisper-large-v3",  # Default
}
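# Note: the plain size names above are loaded with openai-whisper's
# whisper.load_model(), while the "Systran/faster-whisper-large-v3" repo id is
# handed to faster_whisper.WhisperModel, which fetches the weights from the
# Hugging Face Hub on first use and caches them locally.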

# Mapping of full language names to language codes
LANGUAGE_NAME_TO_CODE = {
    "Auto Detect": "Auto Detect",
    "English": "en",
    "Chinese": "zh",
    "German": "de",
    "Spanish": "es",
    "Russian": "ru",
    "Korean": "ko",
    "French": "fr",
    "Japanese": "ja",
    "Portuguese": "pt",
    "Turkish": "tr",
    "Polish": "pl",
    "Catalan": "ca",
    "Dutch": "nl",
    "Arabic": "ar",
    "Swedish": "sv",
    "Italian": "it",
    "Indonesian": "id",
    "Hindi": "hi",
    "Finnish": "fi",
    "Vietnamese": "vi",
    "Hebrew": "he",
    "Ukrainian": "uk",
    "Greek": "el",
    "Malay": "ms",
    "Czech": "cs",
    "Romanian": "ro",
    "Danish": "da",
    "Hungarian": "hu",
    "Tamil": "ta",
    "Norwegian": "no",
    "Thai": "th",
    "Urdu": "ur",
    "Croatian": "hr",
    "Bulgarian": "bg",
    "Lithuanian": "lt",
    "Latin": "la",
    "Maori": "mi",
    "Malayalam": "ml",
    "Welsh": "cy",
    "Slovak": "sk",
    "Telugu": "te",
    "Persian": "fa",
    "Latvian": "lv",
    "Bengali": "bn",
    "Serbian": "sr",
    "Azerbaijani": "az",
    "Slovenian": "sl",
    "Kannada": "kn",
    "Estonian": "et",
    "Macedonian": "mk",
    "Breton": "br",
    "Basque": "eu",
    "Icelandic": "is",
    "Armenian": "hy",
    "Nepali": "ne",
    "Mongolian": "mn",
    "Bosnian": "bs",
    "Kazakh": "kk",
    "Albanian": "sq",
    "Swahili": "sw",
    "Galician": "gl",
    "Marathi": "mr",
    "Punjabi": "pa",
    "Sinhala": "si",  # Sinhala support
    "Khmer": "km",
    "Shona": "sn",
    "Yoruba": "yo",
    "Somali": "so",
    "Afrikaans": "af",
    "Occitan": "oc",
    "Georgian": "ka",
    "Belarusian": "be",
    "Tajik": "tg",
    "Sindhi": "sd",
    "Gujarati": "gu",
    "Amharic": "am",
    "Yiddish": "yi",
    "Lao": "lo",
    "Uzbek": "uz",
    "Faroese": "fo",
    "Haitian Creole": "ht",
    "Pashto": "ps",
    "Turkmen": "tk",
    "Nynorsk": "nn",
    "Maltese": "mt",
    "Sanskrit": "sa",
    "Luxembourgish": "lb",
    "Burmese": "my",
    "Tibetan": "bo",
    "Tagalog": "tl",
    "Malagasy": "mg",
    "Assamese": "as",
    "Tatar": "tt",
    "Hawaiian": "haw",
    "Lingala": "ln",
    "Hausa": "ha",
    "Bashkir": "ba",
    "Javanese": "jw",
    "Sundanese": "su",
}

# Reverse mapping of language codes to full language names
CODE_TO_LANGUAGE_NAME = {v: k for k, v in LANGUAGE_NAME_TO_CODE.items()}
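# Whisper reports language codes rather than names (mostly two-letter ISO 639-1
# codes, with a few exceptions such as "haw"), so the reverse map is used to show
# a human-readable language name in the UI.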

def convert_to_wav(audio_file):
    """Convert any audio file to WAV format."""
    audio = AudioSegment.from_file(audio_file)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
        wav_path = temp_wav.name
    audio.export(wav_path, format="wav")
    return wav_path


def resample_audio(audio_segment, target_sample_rate):
    """Resample an audio segment to the target sample rate."""
    return audio_segment.set_frame_rate(target_sample_rate)

def detect_language(audio_file):
    """Detect the language of the audio file."""
    if audio_file is None:
        return "Error: No audio file uploaded."
    try:
        # Convert audio to WAV format
        wav_path = convert_to_wav(audio_file)

        # Define device and compute type for faster-whisper
        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float32" if device == "cuda" else "int8"

        # Load the faster-whisper model for language detection
        model = WhisperModel(MODELS["Faster Whisper Large v3"], device=device, compute_type=compute_type)

        # Detect the language with faster-whisper. transcribe() returns a lazy
        # segment generator plus a TranscriptionInfo object; since the segments
        # are never iterated here, only the language-identification pass runs.
        segments, info = model.transcribe(wav_path, task="translate", language=None)
        detected_language_code = info.language

        # Get the full language name from the code
        detected_language = CODE_TO_LANGUAGE_NAME.get(detected_language_code, "Unknown Language")

        # Clean up temporary WAV file
        os.remove(wav_path)

        return f"Detected Language: {detected_language}"
    except Exception as e:
        logger.error(f"Error in detect_language: {str(e)}")
        return f"Error: {str(e)}"

def remove_silence(audio_file, silence_threshold=-40, min_silence_len=500):
    """
    Remove silence from the audio file using pydub's silence detection.

    Args:
        audio_file (str): Path to the input audio file.
        silence_threshold (int): Silence threshold in dBFS. Default is -40 dBFS.
        min_silence_len (int): Minimum length of silence to remove in milliseconds. Default is 500 ms.

    Returns:
        str: Path to the output audio file with silence removed.
    """
    if audio_file is None:
        return None
    try:
        # Convert audio to WAV format
        wav_path = convert_to_wav(audio_file)

        # Load the audio file
        audio = AudioSegment.from_file(wav_path)

        # Detect silent chunks (anything quieter than silence_threshold for at
        # least min_silence_len milliseconds)
        silent_chunks = silence.detect_silence(
            audio,
            min_silence_len=min_silence_len,
            silence_thresh=silence_threshold
        )

        # Stitch together the non-silent parts
        non_silent_audio = AudioSegment.empty()
        start = 0
        for chunk_start, chunk_end in silent_chunks:
            non_silent_audio += audio[start:chunk_start]  # Add non-silent part
            start = chunk_end  # Move to the end of the silent chunk
        non_silent_audio += audio[start:]  # Add the remaining part

        # Export the processed audio
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_output:
            output_path = temp_output.name
        non_silent_audio.export(output_path, format="wav")

        # Clean up temporary WAV file
        os.remove(wav_path)

        return output_path
    except Exception as e:
        logger.error(f"Error in remove_silence: {str(e)}")
        return f"Error: {str(e)}"

def detect_and_trim_audio(main_audio, target_audio, threshold=0.5):
    """
    Detect the target audio in the main audio and trim the main audio to include only the detected segment.

    Args:
        main_audio (str): Path to the main audio file.
        target_audio (str): Path to the target audio file.
        threshold (float): Detection threshold (0 to 1). Higher values mean stricter detection.

    Returns:
        str: Path to the trimmed audio file.
        str: Detected timestamps in the format "start-end" (in seconds).
    """
    if main_audio is None or target_audio is None:
        return None, "Error: Please upload both main and target audio files."
    try:
        # Convert audio files to WAV format
        main_wav_path = convert_to_wav(main_audio)
        target_wav_path = convert_to_wav(target_audio)

        # Load audio files
        main_rate, main_data = wavfile.read(main_wav_path)
        target_rate, target_data = wavfile.read(target_wav_path)

        # Ensure both audio files have the same sample rate
        if main_rate != target_rate:
            logger.warning(f"Sample rates differ: main_audio={main_rate}, target_audio={target_rate}. Resampling target audio.")
            target_segment = AudioSegment.from_file(target_wav_path)
            target_segment = resample_audio(target_segment, main_rate)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_resampled:
                resampled_path = temp_resampled.name
            target_segment.export(resampled_path, format="wav")
            target_rate, target_data = wavfile.read(resampled_path)

        # Normalize audio data and mix stereo down to mono so the 1-D
        # cross-correlation below is well defined
        main_data = main_data.astype(np.float32) / np.iinfo(main_data.dtype).max
        target_data = target_data.astype(np.float32) / np.iinfo(target_data.dtype).max
        if main_data.ndim > 1:
            main_data = main_data.mean(axis=1)
        if target_data.ndim > 1:
            target_data = target_data.mean(axis=1)

        # Cross-correlate to find where the target best aligns with the main audio
        correlation = correlate(main_data, target_data, mode='valid')
        correlation = np.abs(correlation)
        peak_index = int(np.argmax(correlation))
        peak_value = correlation[peak_index]

        # Normalize the peak by the energies of the target and the matched window
        # so the score lies in [0, 1] and can be compared against the threshold
        window = main_data[peak_index:peak_index + len(target_data)]
        denom = np.linalg.norm(window) * np.linalg.norm(target_data)
        peak_score = peak_value / denom if denom > 0 else 0.0
        if peak_score < threshold:
            return None, "Error: Target audio not detected in the main audio."

        # Calculate the start and end times of the target audio in the main audio
        start_time = peak_index / main_rate
        end_time = (peak_index + len(target_data)) / main_rate

        # Trim the main audio to include only the detected segment
        main_audio_segment = AudioSegment.from_file(main_wav_path)
        start_ms = int(start_time * 1000)
        end_ms = int(end_time * 1000)
        trimmed_audio = main_audio_segment[start_ms:end_ms]

        # Export the trimmed audio
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_output:
            output_path = temp_output.name
        trimmed_audio.export(output_path, format="wav")

        # Format timestamps
        timestamps_str = f"{start_time:.2f}-{end_time:.2f}"

        # Clean up temporary WAV files
        os.remove(main_wav_path)
        os.remove(target_wav_path)
        if 'resampled_path' in locals():
            os.remove(resampled_path)

        return output_path, timestamps_str
    except Exception as e:
        logger.error(f"Error in detect_and_trim_audio: {str(e)}")
        return None, f"Error: {str(e)}"

def transcribe_audio(audio_file, language="Auto Detect", model_size="Faster Whisper Large v3"):
    """Transcribe the audio file."""
    if audio_file is None:
        return "Error: No audio file uploaded."
    try:
        # Convert audio to WAV format
        wav_path = convert_to_wav(audio_file)

        # Convert audio to 16 kHz mono for better compatibility
        audio = AudioSegment.from_file(wav_path)
        audio = audio.set_frame_rate(16000).set_channels(1)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_processed:
            processed_audio_path = temp_processed.name
        audio.export(processed_audio_path, format="wav")

        # Load the appropriate model
        if model_size == "Faster Whisper Large v3":
            # Define device and compute type for faster-whisper
            device = "cuda" if torch.cuda.is_available() else "cpu"
            compute_type = "float32" if device == "cuda" else "int8"

            # Use faster-whisper for the Systran model
            model = WhisperModel(MODELS[model_size], device=device, compute_type=compute_type)

            # Honour an explicit language selection; None lets the model auto-detect
            language_code = None if language == "Auto Detect" else LANGUAGE_NAME_TO_CODE.get(language)

            # The temperature list is a fallback schedule: higher values are only
            # tried when decoding at a lower temperature fails the quality checks
            segments, info = model.transcribe(
                processed_audio_path,
                task="transcribe",
                language=language_code,
                word_timestamps=True,
                repetition_penalty=1.1,
                temperature=[0.0, 0.1, 0.2, 0.3, 0.4, 0.6, 0.8, 1.0],
            )
            transcription = " ".join([segment.text for segment in segments])
            detected_language = CODE_TO_LANGUAGE_NAME.get(info.language, "Unknown Language")
        else:
            # Use the standard Whisper model
            model = whisper.load_model(MODELS[model_size])

            # Transcribe the audio
            if language == "Auto Detect":
                result = model.transcribe(processed_audio_path, fp16=False)  # Auto-detect language
                detected_language_code = result.get("language", "unknown")
                detected_language = CODE_TO_LANGUAGE_NAME.get(detected_language_code, "Unknown Language")
            else:
                language_code = LANGUAGE_NAME_TO_CODE.get(language, "en")  # Default to English if not found
                result = model.transcribe(processed_audio_path, language=language_code, fp16=False)
                detected_language = language
            transcription = result["text"]

        # Clean up temporary audio files
        os.remove(processed_audio_path)
        os.remove(wav_path)

        # Return transcription and detected language
        return f"Detected Language: {detected_language}\n\nTranscription:\n{transcription}"
    except Exception as e:
        logger.error(f"Error in transcribe_audio: {str(e)}")
        return f"Error: {str(e)}"

# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Audio Processing Tool")

    with gr.Tab("Detect Language"):
        gr.Markdown("Upload an audio file to detect its language.")
        detect_audio_input = gr.Audio(type="filepath", label="Upload Audio File")
        detect_language_output = gr.Textbox(label="Detected Language")
        detect_language_button = gr.Button("Detect Language")

    with gr.Tab("Transcribe Audio"):
        gr.Markdown("Upload an audio file, select a language (or choose 'Auto Detect'), and choose a model for transcription.")
        transcribe_audio_input = gr.Audio(type="filepath", label="Upload Audio File")
        language_dropdown = gr.Dropdown(
            choices=list(LANGUAGE_NAME_TO_CODE.keys()),  # Full language names
            label="Select Language",
            value="Auto Detect"
        )
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),  # Model options
            label="Select Model",
            value="Faster Whisper Large v3",  # Default to "Faster Whisper Large v3"
            interactive=True  # Allow model selection by default
        )
        transcribe_output = gr.Textbox(label="Transcription and Detected Language")
        transcribe_button = gr.Button("Transcribe Audio")

    with gr.Tab("Remove Silence"):
        gr.Markdown("Upload an audio file to remove silence.")
        silence_audio_input = gr.Audio(type="filepath", label="Upload Audio File")
        silence_threshold_slider = gr.Slider(
            minimum=-60, maximum=-20, value=-40, step=1,
            label="Silence Threshold (dBFS)",
            info="Lower values treat only quieter sounds as silence."
        )
        min_silence_len_slider = gr.Slider(
            minimum=100, maximum=2000, value=500, step=100,
            label="Minimum Silence Length (ms)",
            info="Minimum duration of silence to remove."
        )
        silence_output = gr.Audio(label="Processed Audio (Silence Removed)", type="filepath")
        silence_button = gr.Button("Remove Silence")

    with gr.Tab("Detect and Trim Audio"):
        gr.Markdown("Upload a main audio file and a target audio file. The app will detect the target audio in the main audio and trim it.")
        main_audio_input = gr.Audio(type="filepath", label="Upload Main Audio File")
        target_audio_input = gr.Audio(type="filepath", label="Upload Target Audio File")
        threshold_slider = gr.Slider(
            minimum=0.1, maximum=1.0, value=0.5, step=0.1,
            label="Detection Threshold",
            info="Higher values mean stricter detection."
        )
        trimmed_audio_output = gr.Audio(label="Trimmed Audio", type="filepath")
        timestamps_output = gr.Textbox(label="Detected Timestamps (in seconds)")
        detect_trim_button = gr.Button("Detect and Trim")

    # Link buttons to functions (the two detect buttons need distinct variable
    # names so each tab's button is wired to its own handler)
    detect_language_button.click(detect_language, inputs=detect_audio_input, outputs=detect_language_output)
    transcribe_button.click(
        transcribe_audio,
        inputs=[transcribe_audio_input, language_dropdown, model_dropdown],
        outputs=transcribe_output
    )
    silence_button.click(
        remove_silence,
        inputs=[silence_audio_input, silence_threshold_slider, min_silence_len_slider],
        outputs=silence_output
    )
    detect_trim_button.click(
        detect_and_trim_audio,
        inputs=[main_audio_input, target_audio_input, threshold_slider],
        outputs=[trimmed_audio_output, timestamps_output]
    )

# Launch the Gradio interface
demo.launch()
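
# Note: when running locally rather than on a hosted platform such as a Hugging
# Face Space, demo.launch(share=True) can be used to expose a temporary public
# URL for testing.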