import gradio as gr from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor import requests import os from moviepy.editor import VideoFileClip import tempfile import re from urllib.parse import urlparse from gradio import Progress from pathlib import Path import torch import shutil # Import shutil for explicit temporary directory cleanup import soundfile as sf # Import soundfile for explicit audio loading # Load the audio classification model for English accents pipe = pipeline("audio-classification", model="dima806/english_accents_classification") # Load the language detection model language_detector = pipeline("text-classification", model="alexneakameni/language_detection") # Load a small ASR (Automatic Speech Recognition) model for transcribing audio clips # This is used to get text from audio for language detection. # Using 'openai/whisper-tiny.en' for a faster, English-focused transcription. # Ensure to move model to GPU if available for faster inference. device = 0 if torch.cuda.is_available() else -1 # Corrected ASR model ID to a valid Hugging Face model asr_model_id = "openai/whisper-tiny.en" asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(asr_model_id) asr_processor = AutoProcessor.from_pretrained(asr_model_id) asr_pipe = pipeline( "automatic-speech-recognition", model=asr_model, tokenizer=asr_processor.tokenizer, feature_extractor=asr_processor.feature_extractor, device=device ) def is_valid_url(url): """ Checks if the given URL is valid and from allowed domains (MP4, Loom, or Google Drive). Args: url (str): The URL to validate. Returns: bool: True if the URL is valid and allowed, False otherwise. """ if not url: return False try: result = urlparse(url) if not all([result.scheme, result.netloc]): return False allowed_domains = [ 'loom.com', 'cdn.loom.com', 'www.dropbox.com', 'dl.dropboxusercontent.com', 'drive.google.com' # Added Google Drive domain ] # Check if the domain is in our allowed list is_allowed_domain = any(domain in result.netloc.lower() for domain in allowed_domains) # Check if the path part of the URL ends with .mp4 ends_with_mp4 = result.path.lower().endswith('.mp4') if is_allowed_domain: if ends_with_mp4: return True elif 'drive.google.com' in result.netloc.lower(): # Check for typical Google Drive patterns for shared files or download links return '/file/d/' in result.path or '/uc' in result.path elif any(domain in result.netloc.lower() for domain in ['loom.com', 'cdn.loom.com']): return True # Allow Loom URLs even if they don't end in .mp4 elif ends_with_mp4: # Allow direct .mp4 links from other domains if they end with .mp4 return True return False except Exception: return False def is_valid_file(file_obj): """ Checks if the uploaded file object represents a valid video file format. Args: file_obj (gr.File): The Gradio file object. Returns: bool: True if the file is a supported video format, False otherwise. """ if not file_obj: return False # Get the file extension from the uploaded file object's name file_path = file_obj.name # Check if the file extension is one of the supported video formats return Path(file_path).suffix.lower() in ['.mp4', '.mov', '.avi', '.mkv'] def download_file(url, save_path, progress=Progress()): """ Downloads a video file from a given URL to a specified path. Raises ValueError if the URL is invalid, ConnectionError if download fails. Args: url (str): The URL of the video to download. save_path (str): The local path to save the downloaded video. progress (gradio.Progress): Gradio progress tracker for UI updates. """ if not is_valid_url(url): raise ValueError("Invalid URL. Only .mp4 files or Loom videos are accepted.") response = requests.get(url, stream=True) # Check if the download was successful (HTTP status code 200) if response.status_code != 200: raise ConnectionError(f"Failed to download video (HTTP {response.status_code})") # Get the total size of the file for progress tracking total_size = int(response.headers.get('content-length', 0)) downloaded = 0 # Write the downloaded content to the specified save path in chunks with open(save_path, 'wb') as f: for chunk in response.iter_content(chunk_size=8192): if chunk: # Filter out keep-alive new chunks f.write(chunk) downloaded += len(chunk) if total_size > 0: # Update progress bar based on downloaded percentage progress(downloaded / total_size, desc="📥 Downloading video...") else: # If total size is unknown, just show a general downloading message progress(0, desc="📥 Downloading video (size unknown)...") def extract_audio_full(video_path, progress=Progress()): """ Extracts the full duration of audio from a video file and saves it as a WAV file. Uses tempfile.NamedTemporaryFile to ensure the file persists for Gradio. Args: video_path (str): Path to the input video file. progress (gradio.Progress): Gradio progress tracker for UI updates. Returns: str: The path to the extracted audio file. """ try: progress(0, desc="🔊 Extracting full audio for playback...") video = VideoFileClip(video_path) # Create a temporary WAV file that Gradio can manage temp_audio_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) audio_path = temp_audio_file.name temp_audio_file.close() # Close the file handle immediately so moviepy can write to it audio_clip = video.audio audio_clip.write_audiofile(audio_path, fps=16000, logger=None) video.close() audio_clip.close() progress(1.0) return audio_path except Exception as e: raise Exception(f"Full audio extraction failed: {str(e)}") def extract_audio_clip(video_path, audio_path, duration, progress=Progress()): """ Extracts a specified duration of audio from a video file and saves it as a WAV file. Args: video_path (str): Path to the input video file. audio_path (str): Path to save the extracted audio WAV file. duration (int): The duration of audio to extract in seconds. progress (gradio.Progress): Gradio progress tracker for UI updates. Returns: str: The path to the extracted audio file. """ try: progress(0, desc=f"🔊 Extracting {duration} seconds of audio for analysis...") video = VideoFileClip(video_path) # Ensure the subclip duration does not exceed the video's actual duration clip_duration = min(duration, video.duration) audio_clip = video.audio.subclip(0, clip_duration) audio_clip.write_audiofile(audio_path, fps=16000, logger=None) video.close() audio_clip.close() progress(1.0) return audio_path except Exception as e: raise Exception(f"Audio clip extraction failed: {str(e)}") def transcribe_audio(audio_path_clip, progress=Progress()): """ Transcribes a short audio clip to text using the ASR pipeline. Args: audio_path_clip (str): Path to the short audio clip. Returns: str: The transcribed text. """ try: progress(0, desc="📝 Transcribing audio for language detection...") # Load audio using soundfile audio_input, sampling_rate = sf.read(audio_path_clip) # Ensure the audio is mono if the model expects it (Whisper typically does) if audio_input.ndim > 1: audio_input = audio_input.mean(axis=1) # Convert to mono # Process audio with the ASR processor # This handles resampling, padding, and feature extraction to match model requirements inputs = asr_processor(audio_input, sampling_rate=sampling_rate, return_tensors="pt") # Move inputs to the correct device if device != -1: inputs = {k: v.to(device) for k, v in inputs.items()} # Generate transcription with the ASR model with torch.no_grad(): # max_new_tokens can be adjusted based on expected transcription length # For short clips (15s), 128 is usually more than enough output_tokens = asr_model.generate(**inputs, max_new_tokens=128) text = asr_processor.tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0] progress(1.0) return text except Exception as e: print(f"Transcription failed: {e}") return "" # Return empty string on failure def classify_audio(audio_path, progress=Progress()): """ Classifies the accent in an audio file using the pre-loaded Hugging Face pipeline. Args: audio_path (str): Path to the input audio file. Returns: list: A list of dictionaries containing accent labels and confidence scores. """ try: progress(0, desc="🔍 Analyzing accent - please be patient...") result = pipe(audio_path) progress(1.0) # Mark completion return result except Exception as e: raise Exception(f"Classification failed: {str(e)}") def process_video_unified(video_source, analysis_duration, progress=Progress()): """ Processes either a video URL or an uploaded video file to classify accent. Includes language detection before accent classification. Args: video_source (str or gr.File): The input, either a URL string or a Gradio File object. analysis_duration (int): The duration of audio to analyze for accent classification in seconds. progress (gradio.Progress): Gradio progress tracker for UI updates. Returns: tuple: (language_status_html, html_output, audio_path, error_flag) language_status_html (str): HTML string displaying language detection status. html_output (str): HTML string displaying accent results or error. audio_path (str or None): Path to extracted full audio if successful, else None. error_flag (bool): True if an error occurred, False otherwise. """ temp_dir = None full_audio_path = None # Initialize to None try: temp_dir = tempfile.mkdtemp() # Create temp dir for intermediate files (video, clipped audio) video_path = os.path.join(temp_dir, "video.mp4") # Determine if input is a URL string or an uploaded Gradio File object if isinstance(video_source, str) and video_source.startswith(('http://', 'https://')): if not is_valid_url(video_source): raise ValueError("Invalid URL. Only .mp4 files or Loom videos are accepted.") download_file(video_source, video_path, progress) elif hasattr(video_source, 'name'): if not is_valid_file(video_source): raise ValueError("Invalid file format. Please upload a video file (MP4)") with open(video_source.name, 'rb') as src_file: with open(video_path, 'wb') as dest_file: dest_file.write(src_file.read()) else: raise ValueError("Unsupported input type. Please provide a video URL or upload a file.") # Verify that the video file exists after download/upload if not os.path.exists(video_path): raise Exception("Video processing failed: Video file not found after download/upload.") # Extract full audio for playback using tempfile.NamedTemporaryFile full_audio_path = extract_audio_full(video_path, progress) # Extract a short clip for transcription and language detection (e.g., first 15 seconds) transcription_clip_duration = 15 audio_for_transcription_path = os.path.join(temp_dir, "audio_for_transcription.wav") extract_audio_clip(video_path, audio_for_transcription_path, transcription_clip_duration, progress) if not os.path.exists(full_audio_path): raise Exception("Audio extraction failed: Full audio file not found.") if not os.path.exists(audio_for_transcription_path): raise Exception("Audio extraction failed: Clipped audio for transcription not found.") # Transcribe the short audio clip transcribed_text = transcribe_audio(audio_for_transcription_path, progress) if not transcribed_text.strip(): language_status_html = "
⚠️ Could not transcribe audio for language detection. Please ensure audio is clear.
" # If transcription fails, we can't detect language, so we'll proceed with accent classification # but provide a warning. Or, you could choose to stop here. For now, let's proceed. else: # Perform language detection lang_detection_result = language_detector(transcribed_text) detected_language = lang_detection_result[0]['label'] lang_confidence = lang_detection_result[0]['score'] # Check if detected language is English or eng_Latn with a reasonable confidence if (detected_language.lower() == 'english' or detected_language.lower() == 'eng_latn') and lang_confidence > 0.7: # Added 'eng_Latn' check language_status_html = f"✅ Verified English Language (Confidence: {lang_confidence*100:.2f}%)
" else: language_status_html = f"⚠️ Detected language: {detected_language.capitalize()} (Confidence: {lang_confidence*100:.2f}%). Please provide English audio for accent classification.
" # If not English, return early with an error message and skip accent classification return language_status_html, "", full_audio_path, True # Set error flag to True # Extract audio clip for accent classification (based on analysis_duration slider) audio_for_classification_path = os.path.join(temp_dir, "audio_for_classification.wav") extract_audio_clip(video_path, audio_for_classification_path, analysis_duration, progress) if not os.path.exists(audio_for_classification_path): raise Exception("Audio extraction failed: Clipped audio for classification not found.") # Classify the extracted audio for accent result = classify_audio(audio_for_classification_path, progress) if not result: return language_status_html, "⚠️ No accent prediction returned
", full_audio_path, True # Build results table for display # Adjusted table width to 'fit-content' and individual column widths table = """Rank | Accent | Confidence (%) | Score |
---|---|---|---|
#{i+1} | {label} |
{score_formatted_percent}
|
{score_formatted_raw} |
#{i+1} | {label} |
{score_formatted_percent}
|
{score_formatted_raw} |
⚠️ Error: {str(e)}
", None, True finally: # Explicitly clean up the temporary directory created for intermediate files. # The full_audio_path is now managed by NamedTemporaryFile and Gradio. if temp_dir and os.path.exists(temp_dir): shutil.rmtree(temp_dir) # Define a custom Gradio theme for improved aesthetics # This theme inherits from the default theme and overrides specific properties. my_theme = gr.themes.Default().set( # Background colors: A light grey for the primary background, white for inner blocks background_fill_primary="#f0f2f5", background_fill_secondary="#ffffff", # Border for a cleaner look border_color_primary="#e0e0e0", # Button styling for a consistent look # Changed primary button color to a darker, muted green button_primary_background_fill="#4CAF50", # A standard green button_primary_background_fill_hover="#66BB6A", # A slightly lighter green on hover button_primary_text_color="#ffffff", # White text for primary buttons # Changed secondary button color to a darker, muted green button_secondary_background_fill="#4CAF50", # A standard green button_secondary_background_fill_hover="#66BB6A", # A slightly lighter green on hover button_secondary_text_color="#ffffff", # White text for secondary buttons # Accent color for sliders and other accent elements color_accent="#2196F3", # Blue for accent elements like sliders color_accent_soft="#BBDEFB", # Lighter blue for soft accent elements ) # Gradio app interface definition with gr.Blocks(theme=my_theme) as app: # Apply the custom theme here gr.Markdown("""Analyze English accents from either:
The accent analysis will be performed on the first 60 seconds of audio by default, after language detection.
The analysis may take some time depending on the video size and your chosen analysis duration. Please be patient while we process your video.
Supported file formats: MP4
Note: This application requires FFmpeg to be installed on your system to process video and audio files.
⚠️ Unexpected Error: {str(e)}
", visible=True) ) def clear_inputs(): return ( "", # url_input None, # video_input 60, # analysis_duration (reset to default) "Waiting for video input...", # status_box gr.Slider(visible=False, value=0), # progress_bar (hidden and reset) "", # language_status_html (clear) "", # output_html (clear) gr.Audio(visible=True, value=None, label="Extracted Audio (Full Duration)"), "" # error_output (clear) ) submit_btn.click( fn=unified_processing_fn, inputs=[url_input, video_input, analysis_duration], outputs=[status_box, progress_bar, language_status_html, output_html, audio_player, error_output], api_name="classify_video" ) clear_btn.click( fn=clear_inputs, inputs=[], outputs=[url_input, video_input, analysis_duration, status_box, progress_bar, language_status_html, output_html, audio_player, error_output], ) if __name__ == "__main__": app.launch(share=True)