import gradio as gr
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
import requests
import os
from moviepy.editor import VideoFileClip
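# NOTE (version assumption): 'moviepy.editor' is the moviepy 1.x import path; moviepy 2.x removed
# that module (it uses 'from moviepy import VideoFileClip'), so this app assumes moviepy<2.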
import tempfile
import re
from urllib.parse import urlparse
from gradio import Progress
from pathlib import Path
import torch
import shutil # Import shutil for explicit temporary directory cleanup
import soundfile as sf # Import soundfile for explicit audio loading
# Load the audio classification model for English accents
pipe = pipeline("audio-classification", model="dima806/english_accents_classification")
# Load the language detection model
language_detector = pipeline("text-classification", model="alexneakameni/language_detection")
# Load a small ASR (Automatic Speech Recognition) model for transcribing audio clips
# This is used to get text from audio for language detection.
# Using 'openai/whisper-tiny.en' for a faster, English-focused transcription.
# Ensure to move model to GPU if available for faster inference.
device = 0 if torch.cuda.is_available() else -1
# English-only Whisper checkpoint, used only to transcribe a short clip for language detection
asr_model_id = "openai/whisper-tiny.en"
asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(asr_model_id)
asr_processor = AutoProcessor.from_pretrained(asr_model_id)
asr_pipe = pipeline(
"automatic-speech-recognition",
model=asr_model,
tokenizer=asr_processor.tokenizer,
feature_extractor=asr_processor.feature_extractor,
device=device
)
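# Note: transcribe_audio() below bypasses asr_pipe and calls asr_model/asr_processor directly
# with an explicit max_new_tokens limit. Creating the pipeline with device=device also moves
# asr_model onto the GPU when one is available, which transcribe_audio() relies on.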
def is_valid_url(url):
"""
    Checks if the given URL is valid and from an allowed source (direct .mp4 link, Loom, Dropbox, or Google Drive).
Args:
url (str): The URL to validate.
Returns:
bool: True if the URL is valid and allowed, False otherwise.
"""
if not url:
return False
try:
result = urlparse(url)
if not all([result.scheme, result.netloc]):
return False
allowed_domains = [
'loom.com',
'cdn.loom.com',
'www.dropbox.com',
'dl.dropboxusercontent.com',
'drive.google.com' # Added Google Drive domain
]
# Check if the domain is in our allowed list
is_allowed_domain = any(domain in result.netloc.lower() for domain in allowed_domains)
# Check if the path part of the URL ends with .mp4
ends_with_mp4 = result.path.lower().endswith('.mp4')
if is_allowed_domain:
if ends_with_mp4:
return True
elif 'drive.google.com' in result.netloc.lower():
# Check for typical Google Drive patterns for shared files or download links
return '/file/d/' in result.path or '/uc' in result.path
elif any(domain in result.netloc.lower() for domain in ['loom.com', 'cdn.loom.com']):
return True # Allow Loom URLs even if they don't end in .mp4
elif ends_with_mp4:
# Allow direct .mp4 links from other domains if they end with .mp4
return True
return False
except Exception:
return False
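# Examples (hypothetical URLs) of what is_valid_url() accepts, given the rules above:
#   https://cdn.loom.com/sessions/abc123        -> True  (Loom domain, no .mp4 needed)
#   https://drive.google.com/file/d/<id>/view   -> True  (Google Drive file link)
#   https://example.com/talk.mp4                -> True  (direct .mp4 link from any domain)
#   https://example.com/talk.webm               -> False (unsupported)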
def is_valid_file(file_obj):
"""
Checks if the uploaded file object represents a valid video file format.
Args:
file_obj (gr.File): The Gradio file object.
Returns:
bool: True if the file is a supported video format, False otherwise.
"""
if not file_obj:
return False
# Get the file extension from the uploaded file object's name
file_path = file_obj.name
# Check if the file extension is one of the supported video formats
return Path(file_path).suffix.lower() in ['.mp4', '.mov', '.avi', '.mkv']
def download_file(url, save_path, progress=Progress()):
"""
Downloads a video file from a given URL to a specified path.
Raises ValueError if the URL is invalid, ConnectionError if download fails.
Args:
url (str): The URL of the video to download.
save_path (str): The local path to save the downloaded video.
progress (gradio.Progress): Gradio progress tracker for UI updates.
"""
if not is_valid_url(url):
raise ValueError("Invalid URL. Only .mp4 files or Loom videos are accepted.")
response = requests.get(url, stream=True)
# Check if the download was successful (HTTP status code 200)
if response.status_code != 200:
raise ConnectionError(f"Failed to download video (HTTP {response.status_code})")
# Get the total size of the file for progress tracking
total_size = int(response.headers.get('content-length', 0))
downloaded = 0
# Write the downloaded content to the specified save path in chunks
with open(save_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk: # Filter out keep-alive new chunks
f.write(chunk)
downloaded += len(chunk)
if total_size > 0:
# Update progress bar based on downloaded percentage
                    progress(downloaded / total_size, desc="📥 Downloading video...")
else:
# If total size is unknown, just show a general downloading message
                    progress(0, desc="📥 Downloading video (size unknown)...")
def extract_audio_full(video_path, progress=Progress()):
"""
Extracts the full duration of audio from a video file and saves it as a WAV file.
Uses tempfile.NamedTemporaryFile to ensure the file persists for Gradio.
Args:
video_path (str): Path to the input video file.
progress (gradio.Progress): Gradio progress tracker for UI updates.
Returns:
str: The path to the extracted audio file.
"""
try:
        progress(0, desc="🔊 Extracting full audio for playback...")
video = VideoFileClip(video_path)
# Create a temporary WAV file that Gradio can manage
temp_audio_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
audio_path = temp_audio_file.name
temp_audio_file.close() # Close the file handle immediately so moviepy can write to it
audio_clip = video.audio
audio_clip.write_audiofile(audio_path, fps=16000, logger=None)
video.close()
audio_clip.close()
progress(1.0)
return audio_path
except Exception as e:
raise Exception(f"Full audio extraction failed: {str(e)}")
def extract_audio_clip(video_path, audio_path, duration, progress=Progress()):
"""
Extracts a specified duration of audio from a video file and saves it as a WAV file.
Args:
video_path (str): Path to the input video file.
audio_path (str): Path to save the extracted audio WAV file.
duration (int): The duration of audio to extract in seconds.
progress (gradio.Progress): Gradio progress tracker for UI updates.
Returns:
str: The path to the extracted audio file.
"""
try:
progress(0, desc=f"π Extracting {duration} seconds of audio for analysis...")
video = VideoFileClip(video_path)
# Ensure the subclip duration does not exceed the video's actual duration
clip_duration = min(duration, video.duration)
audio_clip = video.audio.subclip(0, clip_duration)
audio_clip.write_audiofile(audio_path, fps=16000, logger=None)
video.close()
audio_clip.close()
progress(1.0)
return audio_path
except Exception as e:
raise Exception(f"Audio clip extraction failed: {str(e)}")
def transcribe_audio(audio_path_clip, progress=Progress()):
"""
Transcribes a short audio clip to text using the ASR pipeline.
Args:
audio_path_clip (str): Path to the short audio clip.
Returns:
str: The transcribed text.
"""
try:
        progress(0, desc="📝 Transcribing audio for language detection...")
# Load audio using soundfile
audio_input, sampling_rate = sf.read(audio_path_clip)
# Ensure the audio is mono if the model expects it (Whisper typically does)
if audio_input.ndim > 1:
audio_input = audio_input.mean(axis=1) # Convert to mono
        # Run the ASR processor (feature extraction and padding) to match the model's expected input.
        # Note: the Whisper feature extractor expects 16 kHz input and does not resample;
        # extract_audio_clip() already writes the clip at 16 kHz, so the rates match.
        inputs = asr_processor(audio_input, sampling_rate=sampling_rate, return_tensors="pt")
# Move inputs to the correct device
if device != -1:
inputs = {k: v.to(device) for k, v in inputs.items()}
# Generate transcription with the ASR model
with torch.no_grad():
# max_new_tokens can be adjusted based on expected transcription length
# For short clips (15s), 128 is usually more than enough
output_tokens = asr_model.generate(**inputs, max_new_tokens=128)
text = asr_processor.tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0]
progress(1.0)
return text
except Exception as e:
print(f"Transcription failed: {e}")
return "" # Return empty string on failure
def classify_audio(audio_path, progress=Progress()):
"""
Classifies the accent in an audio file using the pre-loaded Hugging Face pipeline.
Args:
audio_path (str): Path to the input audio file.
Returns:
list: A list of dictionaries containing accent labels and confidence scores.
"""
try:
        progress(0, desc="🔍 Analyzing accent - please be patient...")
result = pipe(audio_path)
progress(1.0) # Mark completion
return result
except Exception as e:
raise Exception(f"Classification failed: {str(e)}")
def process_video_unified(video_source, analysis_duration, progress=Progress()):
"""
Processes either a video URL or an uploaded video file to classify accent.
Includes language detection before accent classification.
Args:
video_source (str or gr.File): The input, either a URL string or a Gradio File object.
analysis_duration (int): The duration of audio to analyze for accent classification in seconds.
progress (gradio.Progress): Gradio progress tracker for UI updates.
Returns:
tuple: (language_status_html, html_output, audio_path, error_flag)
language_status_html (str): HTML string displaying language detection status.
html_output (str): HTML string displaying accent results or error.
audio_path (str or None): Path to extracted full audio if successful, else None.
error_flag (bool): True if an error occurred, False otherwise.
"""
temp_dir = None
full_audio_path = None # Initialize to None
try:
temp_dir = tempfile.mkdtemp() # Create temp dir for intermediate files (video, clipped audio)
video_path = os.path.join(temp_dir, "video.mp4")
# Determine if input is a URL string or an uploaded Gradio File object
if isinstance(video_source, str) and video_source.startswith(('http://', 'https://')):
if not is_valid_url(video_source):
raise ValueError("Invalid URL. Only .mp4 files or Loom videos are accepted.")
download_file(video_source, video_path, progress)
elif hasattr(video_source, 'name'):
if not is_valid_file(video_source):
raise ValueError("Invalid file format. Please upload a video file (MP4)")
            # Copy the uploaded file into the temporary working directory
            shutil.copyfile(video_source.name, video_path)
else:
raise ValueError("Unsupported input type. Please provide a video URL or upload a file.")
# Verify that the video file exists after download/upload
if not os.path.exists(video_path):
raise Exception("Video processing failed: Video file not found after download/upload.")
# Extract full audio for playback using tempfile.NamedTemporaryFile
full_audio_path = extract_audio_full(video_path, progress)
# Extract a short clip for transcription and language detection (e.g., first 15 seconds)
transcription_clip_duration = 15
audio_for_transcription_path = os.path.join(temp_dir, "audio_for_transcription.wav")
extract_audio_clip(video_path, audio_for_transcription_path, transcription_clip_duration, progress)
if not os.path.exists(full_audio_path):
raise Exception("Audio extraction failed: Full audio file not found.")
if not os.path.exists(audio_for_transcription_path):
raise Exception("Audio extraction failed: Clipped audio for transcription not found.")
# Transcribe the short audio clip
transcribed_text = transcribe_audio(audio_for_transcription_path, progress)
if not transcribed_text.strip():
language_status_html = "<p style='color: orange; font-weight: bold;'>β οΈ Could not transcribe audio for language detection. Please ensure audio is clear.</p>"
# If transcription fails, we can't detect language, so we'll proceed with accent classification
# but provide a warning. Or, you could choose to stop here. For now, let's proceed.
else:
# Perform language detection
lang_detection_result = language_detector(transcribed_text)
detected_language = lang_detection_result[0]['label']
lang_confidence = lang_detection_result[0]['score']
# Check if detected language is English or eng_Latn with a reasonable confidence
            # The detector may return either a plain 'English' label or a BCP-47-style 'eng_Latn' label
            if detected_language.lower() in ('english', 'eng_latn') and lang_confidence > 0.7:
language_status_html = f"<p style='color: green; font-weight: bold;'>β
Verified English Language (Confidence: {lang_confidence*100:.2f}%)</p>"
else:
language_status_html = f"<p style='color: red; font-weight: bold;'>β οΈ Detected language: {detected_language.capitalize()} (Confidence: {lang_confidence*100:.2f}%). Please provide English audio for accent classification.</p>"
# If not English, return early with an error message and skip accent classification
return language_status_html, "", full_audio_path, True # Set error flag to True
# Extract audio clip for accent classification (based on analysis_duration slider)
audio_for_classification_path = os.path.join(temp_dir, "audio_for_classification.wav")
extract_audio_clip(video_path, audio_for_classification_path, analysis_duration, progress)
if not os.path.exists(audio_for_classification_path):
raise Exception("Audio extraction failed: Clipped audio for classification not found.")
# Classify the extracted audio for accent
result = classify_audio(audio_for_classification_path, progress)
if not result:
return language_status_html, "<p style='color: red; font-weight: bold;'>β οΈ No accent prediction returned</p>", full_audio_path, True
# Build results table for display
# Adjusted table width to 'fit-content' and individual column widths
table = """
<table style='width: fit-content; max-width: 100%; border-collapse: collapse; font-family: Arial, sans-serif; margin-top: 1em;'>
<thead>
<tr style='border-bottom: 2px solid #4CAF50; background-color: #f2f2f2;'>
<th style='text-align:left; padding: 8px; font-size: 1.1em; color: #333; width: auto; min-width: 50px;'>Rank</th>
<th style='text-align:left; padding: 8px; font-size: 1.1em; color: #333; width: auto; min-width: 100px;'>Accent</th>
<th style='text-align:left; padding: 8px; font-size: 1.1em; color: #333; width: auto; min-width: 180px;'>Confidence (%)</th>
<th style='text-align:left; padding: 8px; font-size: 1.1em; color: #333; width: auto; min-width: 80px;'>Score</th>
</tr>
</thead>
<tbody>
"""
for i, r in enumerate(result):
label = r['label'].capitalize()
score = r['score']
score_formatted_percent = f"{score * 100:.2f}%"
score_formatted_raw = f"{score:.4f}"
if i == 0:
row = f"""
<tr style='background-color:#d4edda; font-weight: bold; color: #155724;'>
<td style='padding: 8px; border-bottom: 1px solid #c3e6cb; width: auto; min-width: 50px;'>#{i+1}</td>
<td style='padding: 8px; border-bottom: 1px solid #c3e6cb; width: auto; min-width: 100px;'>{label}</td>
<td style='padding: 8px; border-bottom: 1px solid #c3e6cb; width: auto; min-width: 180px;'>
<div style='display: flex; align-items: center;'>
<span style='width: auto; display: inline-block;'>{score_formatted_percent}</span>
<progress value='{score * 100}' max='100' style='width: 100%; margin-left: 10px;'></progress>
</div>
</td>
<td style='padding: 8px; border-bottom: 1px solid #c3e6cb; width: auto; min-width: 80px;'>
<span style='width: auto; display: inline-block;'>{score_formatted_raw}</span>
</td>
</tr>
"""
else:
row = f"""
<tr style='color: #333;'>
<td style='padding: 8px; border-bottom: 1px solid #ddd; width: auto; min-width: 50px;'>#{i+1}</td>
<td style='padding: 8px; border-bottom: 1px solid #ddd; width: auto; min-width: 100px;'>{label}</td>
<td style='padding: 8px; border-bottom: 1px solid #ddd; width: auto; min-width: 180px;'>
<div style='display: flex; align-items: center;'>
<span style='width: auto; display: inline-block;'>{score_formatted_percent}</span>
<progress value='{score * 100}' max='100' style='width: 100%; margin-left: 10px;'></progress>
</div>
</td>
<td style='padding: 8px; border-bottom: 1px solid #ddd; width: auto; min-width: 80px;'>
<span style='display: inline-block;'>{score_formatted_raw}</span>
</td>
</tr>
"""
table += row
table += "</tbody></table>"
top_result = result[0]
html_output = f"""
<div style='font-family: Arial, sans-serif;'>
<h2 style='color: #2E7D32; margin-bottom: 0.5em;'>
🎤 Predicted Accent: <span style='font-weight:bold'>{top_result['label'].capitalize()}</span>
<span style='font-size: 0.8em; color: #555; font-weight: normal;'>
(Confidence: {top_result['score']*100:.2f}%)
</span>
</h2>
{table}
</div>
"""
# Return language status, accent results HTML, full audio path, and no error flag
return language_status_html, html_output, full_audio_path, False
except Exception as e:
# If any error occurs, return an error message and set the error flag
return "", f"<p style='color: red; font-weight: bold;'>β οΈ Error: {str(e)}</p>", None, True
finally:
# Explicitly clean up the temporary directory created for intermediate files.
# The full_audio_path is now managed by NamedTemporaryFile and Gradio.
if temp_dir and os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
# Define a custom Gradio theme for improved aesthetics
# This theme inherits from the default theme and overrides specific properties.
my_theme = gr.themes.Default().set(
# Background colors: A light grey for the primary background, white for inner blocks
background_fill_primary="#f0f2f5",
background_fill_secondary="#ffffff",
# Border for a cleaner look
border_color_primary="#e0e0e0",
    # Button styling: green primary and secondary buttons with white text for a consistent look
    button_primary_background_fill="#4CAF50",
    button_primary_background_fill_hover="#66BB6A",
    button_primary_text_color="#ffffff",
    button_secondary_background_fill="#4CAF50",
    button_secondary_background_fill_hover="#66BB6A",
    button_secondary_text_color="#ffffff",
# Accent color for sliders and other accent elements
color_accent="#2196F3", # Blue for accent elements like sliders
color_accent_soft="#BBDEFB", # Lighter blue for soft accent elements
)
# Gradio app interface definition
with gr.Blocks(theme=my_theme) as app: # Apply the custom theme here
gr.Markdown("""
<div style='font-family: Arial, sans-serif;'>
<h1 style='color: #2E7D32;'>🎤 English Accent Classifier</h1>
<p>Analyze English accents from either:</p>
<ul>
<li>A video URL (direct MP4 link, Loom, Dropbox, or Google Drive)</li>
<li>Or upload a video file from your computer</li>
</ul>
<p>The accent analysis will be performed on the first <strong>60 seconds</strong> of audio by default, after language detection.</p>
<p>The analysis may take some time depending on the video size and your chosen analysis duration. Please be patient while we process your video.</p>
<p><strong>Supported upload formats:</strong> MP4, MOV, AVI, MKV</p>
<p style='font-size: 0.9em; color: #666;'>
<strong>Note:</strong> This application requires <a href='https://ffmpeg.org/download.html' target='_blank' style='color: #2E7D32;'>FFmpeg</a> to be installed on your system to process video and audio files.
</p>
</div>
""")
with gr.Row():
with gr.Column(scale=1):
url_input = gr.Textbox(
label="π Video URL (MP4 or Loom)",
placeholder="Paste URL here..."
)
video_input = gr.File(
label="π Upload Video File",
file_types=["video"],
interactive=True
)
with gr.Column(scale=1):
analysis_duration = gr.Slider(
minimum=5,
maximum=120,
step=5,
value=60,
label="Accent Analysis Duration (seconds)",
info="Analyze the first N seconds of audio for accent classification."
)
with gr.Row():
submit_btn = gr.Button("Analyze Video", variant="primary")
clear_btn = gr.Button("Clear Input")
status_box = gr.Textbox(
label="Status",
placeholder="Waiting for video input...",
interactive=False,
visible=True
)
progress_bar = gr.Slider(
visible=False,
label="Processing Progress",
interactive=False
)
# Placing outputs in a new row to allow for better vertical stacking on smaller screens
# and horizontal arrangement on larger screens.
with gr.Row():
# Using gr.Column to contain the language status and audio player
with gr.Column(scale=1, min_width=300): # Added min_width for better control
language_status_html = gr.HTML(label="Language Detection Status", visible=True)
audio_player = gr.Audio(label="Extracted Audio (Full Duration)", visible=True)
# Using gr.Column for the main results table and error output
with gr.Column(scale=2, min_width=400): # Added min_width for better control
output_html = gr.HTML()
error_output = gr.HTML(visible=False)
def unified_processing_fn(video_url, video_file, analysis_duration, progress=Progress()):
video_source = video_url if video_url else video_file
yield (
gr.Textbox(value="β³ Processing started - please be patient...", visible=True),
gr.Slider(visible=True, value=0),
gr.HTML(value="", visible=True), # Clear language status
gr.HTML(value="", visible=False), # Hide previous HTML output
gr.Audio(value=None, visible=True, label="Extracted Audio (Full Duration)"),
gr.HTML(value="", visible=False) # Hide previous error output
)
try:
lang_status, html, audio_path, error = process_video_unified(video_source, analysis_duration, progress)
if error:
yield (
gr.Textbox(value="β Processing failed", visible=True),
gr.Slider(visible=False),
gr.HTML(value=lang_status, visible=True),
gr.HTML(value="", visible=False),
gr.Audio(value=audio_path, visible=True, label="Extracted Audio (Full Duration)"),
gr.HTML(value=html, visible=True)
)
else:
yield (
gr.Textbox(value="β
Analysis complete!", visible=True),
gr.Slider(value=1.0, visible=False),
gr.HTML(value=lang_status, visible=True),
gr.HTML(value=html, visible=True),
gr.Audio(value=audio_path, visible=True, label="Extracted Audio (Full Duration)"),
gr.HTML(visible=False)
)
except Exception as e:
yield (
gr.Textbox(value="β An unexpected error occurred!", visible=True),
gr.Slider(visible=False),
gr.HTML(value="", visible=True),
gr.HTML(value="", visible=False),
gr.Audio(value=None, visible=True, label="Extracted Audio (Full Duration)"),
gr.HTML(value=f"<p style='color: red; font-weight: bold;'>β οΈ Unexpected Error: {str(e)}</p>", visible=True)
)
def clear_inputs():
return (
"", # url_input
None, # video_input
60, # analysis_duration (reset to default)
"Waiting for video input...", # status_box
gr.Slider(visible=False, value=0), # progress_bar (hidden and reset)
"", # language_status_html (clear)
"", # output_html (clear)
gr.Audio(visible=True, value=None, label="Extracted Audio (Full Duration)"),
"" # error_output (clear)
)
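    # Note: the order of the tuple returned by clear_inputs() must match the outputs list
    # passed to clear_btn.click() below.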
submit_btn.click(
fn=unified_processing_fn,
inputs=[url_input, video_input, analysis_duration],
outputs=[status_box, progress_bar, language_status_html, output_html, audio_player, error_output],
api_name="classify_video"
)
clear_btn.click(
fn=clear_inputs,
inputs=[],
outputs=[url_input, video_input, analysis_duration, status_box, progress_bar, language_status_html, output_html, audio_player, error_output],
)
if __name__ == "__main__":
app.launch(share=True)
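# A minimal local setup sketch, assuming FFmpeg is installed and discoverable by moviepy
# (package names are the usual PyPI ones; the moviepy<2 constraint matches the import above,
# other versions are suggestions, not tested pins):
#   pip install gradio transformers torch "moviepy<2" soundfile requests
#   python app.py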