import importlib
import os
import subprocess
import sys
import tempfile

import requests
from moviepy.editor import VideoFileClip

# Ensure the official OpenAI Whisper package is installed (it provides load_model);
# other PyPI packages named "whisper" lack that attribute.
try:
    import whisper
    if not hasattr(whisper, 'load_model'):
        raise ImportError("the installed 'whisper' package is not openai-whisper")
except ImportError:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--upgrade", "openai-whisper"],
        check=True,
    )
    sys.modules.pop('whisper', None)   # drop any previously imported 'whisper' module
    importlib.invalidate_caches()      # make the freshly installed package importable
    import whisper

import torch
import librosa
import pandas as pd
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from huggingface_hub import login
import gradio as gr

# Authenticate with Hugging Face (token via HF_TOKEN env var)
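# A minimal sketch: log in only when a token is provided (the models used below
# are public, so authentication is optional).
hf_token = os.environ.get('HF_TOKEN')
if hf_token:
    login(token=hf_token)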



# Device setup (GPU if available)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def load_models():
    # Load Whisper directly on the target device
    whisper_model = whisper.load_model('base', device=device)
    processor = Wav2Vec2Processor.from_pretrained(
        'jonatasgrosman/wav2vec2-large-xlsr-53-english'
    )
    # NOTE: the checkpoint must expose an accent/sequence-classification head;
    # if it only ships ASR weights, transformers initialises the head randomly
    # and the resulting probabilities are not meaningful.
    accent_model = Wav2Vec2ForSequenceClassification.from_pretrained(
        'jonatasgrosman/wav2vec2-large-xlsr-53-english'
    ).to(device)
    accent_model.eval()  # inference only; disable dropout
    return whisper_model, processor, accent_model

whisper_model, processor, accent_model = load_models()
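# The models are loaded once at import time, so every Gradio request reuses the
# same weights instead of reloading them per call.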

# Main analysis function
def analyze(video_url: str):
    # Download video to temp file
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_vid:
        response = requests.get(video_url, stream=True, timeout=60)  # avoid hanging on slow/unreachable hosts
        response.raise_for_status()
        for chunk in response.iter_content(chunk_size=1024 * 1024):
            if chunk:
                tmp_vid.write(chunk)
        video_path = tmp_vid.name

    # Extract the audio track to WAV; fail early if the video has no audio
    audio_path = video_path.replace('.mp4', '.wav')
    clip = VideoFileClip(video_path)
    if clip.audio is None:
        clip.close()
        os.remove(video_path)
        raise gr.Error('The downloaded video has no audio track.')
    clip.audio.write_audiofile(audio_path, verbose=False, logger=None)
    clip.close()

    # Load audio as 16 kHz mono float32 (the rate Whisper and wav2vec2 expect)
    speech, sr = librosa.load(audio_path, sr=16000)

    # Transcribe with Whisper; the model accepts the 16 kHz float32 waveform directly.
    # fp16 only makes sense on GPU, so disable it on CPU to avoid the warning.
    result = whisper_model.transcribe(speech, fp16=(device == 'cuda'))
    transcript = result.get('text', '')
    lang = result.get('language', 'unknown')
    if lang != 'en':
        transcript = f"[Non-English detected: {lang}]\n" + transcript

    # Accent classification
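    # The processor pads the raw waveform, converts it to PyTorch tensors and
    # returns an attention mask distinguishing real samples from padding.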
    inputs = processor(speech, sampling_rate=sr, return_tensors='pt', padding=True)
    input_values = inputs.input_values.to(device)
    attention_mask = inputs.attention_mask.to(device)
    with torch.no_grad():
        logits = accent_model(input_values=input_values, attention_mask=attention_mask).logits
        probs = torch.softmax(logits, dim=-1).squeeze().cpu().tolist()

    # Map class indices to human-readable accents.
    # NOTE: this list is an assumption about the checkpoint's label order; it must
    # match accent_model.config.id2label, and any extra indices fall back to LABEL_i.
    accent_labels = [
        'American', 'Australian', 'British', 'Canadian', 'Indian',
        'Irish', 'New Zealander', 'South African', 'Welsh'
    ]
    accent_probs = [
        (accent_labels[i] if i < len(accent_labels) else f'LABEL_{i}', p * 100)
        for i, p in enumerate(probs)
    ]
    accent_probs.sort(key=lambda x: x[1], reverse=True)
    top_accent, top_conf = accent_probs[0]

    # Prepare DataFrame
    df = pd.DataFrame(accent_probs, columns=['Accent', 'Confidence (%)'])

    # Cleanup temp files
    try:
        os.remove(video_path)
        os.remove(audio_path)
    except OSError:
        pass  # best-effort cleanup; a leftover temp file is not fatal

    return top_accent, f"{top_conf:.2f}%", df

# Gradio interface
interface = gr.Interface(
    fn=analyze,
    inputs=gr.Textbox(label='Video URL', placeholder='Enter public MP4 URL'),
    outputs=[
        # gr.Textbox(label='Transcript'),
        gr.Textbox(label='Predicted Accent'),
        gr.Textbox(label='Accent Confidence'),
        gr.Dataframe(label='All Accent Probabilities')
    ],
    title='English Accent Detector',
    description='Paste a Loom or direct MP4 URL to extract, transcribe, and classify English accents (uses GPU if available).',
    allow_flagging='never'
)
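# Quick smoke test without the UI (the URL below is a hypothetical placeholder):
#   accent, confidence, table = analyze('https://example.com/sample.mp4')
#   print(accent, confidence)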

if __name__ == '__main__':
    interface.launch()