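"""Gradio Space: accent classification plus transcription.

Takes an uploaded audio file, an uploaded video file, or a direct video URL;
extracts and normalizes the audio, classifies the speaker's accent with a
fine-tuned Wav2Vec2 model, and transcribes the speech with Whisper.
"""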

import os
import shutil
import subprocess

import gradio as gr
import requests
import soundfile as sf
import torch
from scipy.signal import resample
from moviepy.editor import VideoFileClip
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor, pipeline
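
# Non-pip dependency: the `ffmpeg` binary must be on PATH, since both moviepy
# and the subprocess calls below shell out to it. Note that `moviepy.editor`
# is the moviepy 1.x import path; moviepy 2.x moved it to
# `from moviepy import VideoFileClip`, so pin moviepy<2 if this import fails.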

# === Constants ===
TEMP_VIDEO = "temp_video.mp4"
RAW_AUDIO = "raw_audio_input"
CONVERTED_AUDIO = "converted_audio.wav"
EXTRACTED_AUDIO = "extracted_audio.wav"
MODEL_REPO = "ylacombe/accent-classifier"

# === Load local model ===
MODEL_DIR = "model"
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_DIR, local_files_only=True)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_DIR)

# Alternative: pull the same weights straight from the Hub instead of a local copy.
# model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_REPO)
# feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_REPO)

whisper = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

LABELS = [model.config.id2label[i] for i in range(len(model.config.id2label))]
model.eval()  # inference mode: disables dropout

# === Helpers ===
def convert_to_wav(input_path, output_path=CONVERTED_AUDIO):
    """Re-encode any audio/video input to WAV via ffmpeg."""
    command = ["ffmpeg", "-y", "-i", input_path, output_path]
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg failed to convert {input_path} to WAV.")
    return output_path
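
# Design note: the output format here is whatever ffmpeg infers from the .wav
# extension. Adding "-ar", "16000", "-ac", "1" to the command would do the
# 16 kHz mono conversion up front and make the Python-side resampling in
# classify_accent() unnecessary.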

def extract_audio_from_video(video_path, output_path=EXTRACTED_AUDIO):
    """Pull the audio track out of a video file with moviepy."""
    clip = VideoFileClip(video_path)
    try:
        if clip.audio is None:
            raise ValueError("No audio stream found in video.")
        clip.audio.write_audiofile(output_path)
    finally:
        clip.close()  # release the ffmpeg reader and file handles
    return output_path

def download_video(url, filename=TEMP_VIDEO):
    """Download a direct video URL, then remux it into a clean MP4."""
    temp_download = "raw_download.mp4"
    headers = {"User-Agent": "Mozilla/5.0"}
    r = requests.get(url, headers=headers, stream=True, timeout=15)
    r.raise_for_status()
    if not r.headers.get("Content-Type", "").startswith("video/"):
        raise RuntimeError(f"URL is not a video. Content-Type: {r.headers.get('Content-Type')}")
    with open(temp_download, "wb") as f:
        for chunk in r.iter_content(chunk_size=8192):
            f.write(chunk)
    ffmpeg_cmd = [
        "ffmpeg", "-y", "-i", temp_download,
        "-c", "copy", "-movflags", "+faststart", filename,
    ]
    result = subprocess.run(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    try:
        if result.returncode != 0 or not os.path.exists(filename) or os.path.getsize(filename) == 0:
            raise RuntimeError("FFmpeg failed to process the video.")
    finally:
        os.remove(temp_download)  # drop the raw download whether or not remuxing worked
    return filename
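
# Design note: "-c copy" only remuxes (no re-encoding), which is fast but
# fails for streams the MP4 container cannot carry; dropping "-c copy" and
# letting ffmpeg transcode is the slower, more tolerant alternative.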

def classify_accent(audio_path):
    """Return the top accent label, its confidence, and a top-5 summary string."""
    waveform, sr = sf.read(audio_path)
    if waveform.ndim > 1:
        waveform = waveform.mean(axis=1)  # downmix multi-channel audio to mono
    if sr != 16000:
        # Wav2Vec2 models expect 16 kHz input
        num_samples = int(len(waveform) * 16000 / sr)
        waveform = resample(waveform, num_samples)
        sr = 16000
    inputs = feature_extractor(waveform, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits[0]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_idx = torch.argmax(probs).item()
    top_label = LABELS[top_idx]
    top_conf = round(probs[top_idx].item(), 4)
    top5 = torch.topk(probs, k=min(5, len(LABELS)))
    top5_labels = [LABELS[i] for i in top5.indices.tolist()]
    top5_scores = [round(p, 4) for p in top5.values.tolist()]
    top5_text = "\n".join(f"{label}: {score}" for label, score in zip(top5_labels, top5_scores))
    return top_label, top_conf, top5_text
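
# Performance note: scipy.signal.resample is FFT-based, which can be slow on
# long clips; scipy.signal.resample_poly(waveform, 16000, sr) is a common
# drop-in alternative for integer sample rates.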

def transcribe_audio(audio_path):
    # return_timestamps=True lets the Whisper pipeline process clips longer
    # than its 30-second window instead of raising an error.
    result = whisper(audio_path, return_timestamps=True)
    return result.get("text", "").strip()

# === Main Handler ===
def process_input(audio_file, video_file, video_url):
    try:
        audio_path = None
        if audio_file:
            shutil.copy(audio_file, RAW_AUDIO)
            audio_path = convert_to_wav(RAW_AUDIO)
        elif video_file:
            shutil.copy(video_file, TEMP_VIDEO)
            extracted = extract_audio_from_video(TEMP_VIDEO)
            audio_path = convert_to_wav(extracted)
        elif video_url and video_url.strip():
            if "loom.com" in video_url:
                return "Loom links are not supported. Please upload the file or use a direct .mp4 URL.", None, None, None, None, None
            downloaded = download_video(video_url)
            extracted = extract_audio_from_video(downloaded)
            audio_path = convert_to_wav(extracted)
        else:
            return "Please provide an audio file, a video file, or a direct video URL.", None, None, None, None, None
        label, confidence, top5 = classify_accent(audio_path)
        transcription = transcribe_audio(audio_path)
        return f"Top prediction: {label}", confidence, label, audio_path, top5, transcription
    except Exception as e:
        return f"Error: {e}", None, None, None, None, None
    finally:
        # Clean up intermediates. CONVERTED_AUDIO is deliberately kept: it is
        # returned as the "Processed Audio" output, and Gradio still has to
        # read it after this function exits.
        for f in [TEMP_VIDEO, RAW_AUDIO, EXTRACTED_AUDIO]:
            if os.path.exists(f):
                os.remove(f)
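
# Caveat: the temp filenames are module-level constants, so concurrent
# requests would overwrite each other's files; per-request paths from
# tempfile.mkdtemp() would make this safe under concurrency.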

# === Gradio Interface ===
interface = gr.Interface(
    fn=process_input,
    inputs=[
        gr.Audio(label="Upload MP3 or WAV", type="filepath"),
        gr.File(label="Upload MP4 Video", type="filepath"),
        gr.Textbox(label="Paste Direct .mp4 Video URL"),
    ],
    outputs=[
        gr.Text(label="Prediction"),
        gr.Number(label="Confidence Score"),
        gr.Text(label="Accent"),
        gr.Audio(label="Processed Audio", type="filepath"),
        gr.Text(label="Top 5 Predictions"),
        gr.Text(label="Transcription"),
    ],
    title="Accent Classifier + Transcriber",
    description="Upload an audio or video file OR paste a direct video URL to classify the accent and transcribe the speech.",
)
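
# Run locally with `python app.py`; Gradio serves on http://127.0.0.1:7860 by
# default. On Hugging Face Spaces, app.py is the default entrypoint.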

if __name__ == "__main__":
    interface.launch()