monitkorn committed on
Commit 353a46c · Parent: 07c3b9c

update model

Files changed (1): app.py (+10 -22)
app.py CHANGED
@@ -5,7 +5,6 @@ import tempfile
 import requests
 from moviepy.editor import VideoFileClip
 
-# Ensure the official OpenAI Whisper package is installed (supports load_model)
 try:
     import whisper
     if not hasattr(whisper, 'load_model'):
@@ -21,29 +20,26 @@ from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
 from huggingface_hub import login
 import gradio as gr
 
-# Authenticate with Hugging Face (token via HF_TOKEN env var)
-
-
-# Device setup (GPU if available)
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
+# device = 'cuda' if torch.cuda.is_available() else 'cpu'
+device = 'cpu'
 
 def load_models():
-    # Load Whisper directly on the target device
-    whisper_model = whisper.load_model('base', device=device)
+    whisper_model = whisper.load_model('tiny', device=device)
     processor = Wav2Vec2Processor.from_pretrained(
-        'jonatasgrosman/wav2vec2-large-xlsr-53-english'
+        'jonatasgrosman/wav2vec2-large-english'
     )
     accent_model = Wav2Vec2ForSequenceClassification.from_pretrained(
-        'jonatasgrosman/wav2vec2-large-xlsr-53-english'
+        'jonatasgrosman/wav2vec2-large-english'
     ).to(device)
+    accent_model = torch.quantization.quantize_dynamic(
+        accent_model, {torch.nn.Linear}, dtype=torch.qint8
+    )
     return whisper_model, processor, accent_model
 
 whisper_model, processor, accent_model = load_models()
 
-# Main analysis function
 def analyze(video_url: str):
-    # Download video to temp file
     with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_vid:
         response = requests.get(video_url, stream=True)
         response.raise_for_status()
@@ -52,23 +48,19 @@ def analyze(video_url: str):
             tmp_vid.write(chunk)
         video_path = tmp_vid.name
 
-    # Extract audio
     audio_path = video_path.replace('.mp4', '.wav')
     clip = VideoFileClip(video_path)
     clip.audio.write_audiofile(audio_path, verbose=False, logger=None)
     clip.close()
 
-    # Load audio waveform
     speech, sr = librosa.load(audio_path, sr=16000)
 
-    # Transcribe with Whisper (model on correct device)
     result = whisper_model.transcribe(speech)
    transcript = result.get('text', '')
     lang = result.get('language', 'unknown')
     if lang != 'en':
         transcript = f"[Non-English detected: {lang}]\n" + transcript
 
-    # Accent classification
     inputs = processor(speech, sampling_rate=sr, return_tensors='pt', padding=True)
     input_values = inputs.input_values.to(device)
     attention_mask = inputs.attention_mask.to(device)
@@ -76,20 +68,17 @@ def analyze(video_url: str):
     logits = accent_model(input_values=input_values, attention_mask=attention_mask).logits
     probs = torch.softmax(logits, dim=-1).squeeze().cpu().tolist()
 
-    # Map default LABEL_x to human-readable accents
     accent_labels = [
         'American', 'Australian', 'British', 'Canadian', 'Indian',
         'Irish', 'New Zealander', 'South African', 'Welsh'
-    ]  # ensure this matches model output order
+    ]
     accent_probs = [(accent_labels[i], probs[i] * 100) for i in range(len(probs))]
     accent_probs.sort(key=lambda x: x[1], reverse=True)
     top_accent, top_conf = accent_probs[0]
 
-    # Prepare DataFrame
     df = pd.DataFrame(accent_probs, columns=['Accent', 'Confidence (%)'])
     df = pd.DataFrame(accent_probs, columns=['Accent', 'Confidence (%)'])
 
-    # Cleanup temp files
     try:
         os.remove(video_path)
         os.remove(audio_path)
@@ -98,7 +87,6 @@ def analyze(video_url: str):
 
     return top_accent, f"{top_conf:.2f}%", df
 
-# Gradio interface
 interface = gr.Interface(
     fn=analyze,
     inputs=gr.Textbox(label='Video URL', placeholder='Enter public MP4 URL'),
@@ -109,7 +97,7 @@ interface = gr.Interface(
         gr.Dataframe(label='All Accent Probabilities')
     ],
     title='English Accent Detector',
-    description='Paste a Loom or direct MP4 URL to extract, transcribe, and classify English accents (uses GPU if available).',
+    description='Paste a direct MP4 URL to extract, transcribe, and classify English accents. It is a bit slow since we run Whisper and Wav2Vec2 models on CPU. Please test with short videos.',
     allow_flagging='never'
 )
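
Note on the quantization step added in this commit: torch.quantization.quantize_dynamic replaces the listed module types (here torch.nn.Linear) with int8 counterparts that store their weights quantized and quantize activations on the fly, which typically shrinks the model and speeds up CPU inference at a small accuracy cost. Below is a minimal standalone sketch of the same call against the checkpoint used above; the dummy waveform and the shape check are illustrative additions, not part of app.py.

import torch
from transformers import Wav2Vec2ForSequenceClassification

# Load the classifier in full precision on CPU; PyTorch's dynamic
# int8 quantization targets CPU inference.
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    'jonatasgrosman/wav2vec2-large-english'
).eval()

# Swap every nn.Linear for a dynamically quantized int8 module.
quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# Illustrative check (not in app.py): run one second of fake
# 16 kHz audio through both models and compare output shapes.
dummy = torch.randn(1, 16000)
with torch.no_grad():
    fp32_logits = model(input_values=dummy).logits
    int8_logits = quantized(input_values=dummy).logits
print(fp32_logits.shape, int8_logits.shape)  # same shape, slightly different values

Since only the Linear weights are quantized, those layers take roughly a quarter of their fp32 memory, which presumably motivates pairing this change with device = 'cpu' and the smaller 'tiny' Whisper checkpoint.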