bluenevus's picture
Update app.py
b09f327 verified
raw
history blame
4.79 kB
import os
import re
import tempfile
from urllib.parse import urljoin

import gradio as gr
import numpy as np
import requests
import soundfile as sf
import torch
from bs4 import BeautifulSoup
from spellchecker import SpellChecker
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Pick the compute device once at import time: a CUDA GPU when PyTorch can
# see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Whisper checkpoint shared by every request. The processor turns raw audio
# into model input features and decodes generated token ids back into text.
model_name = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)

# Single spell checker instance, reused by correct_spelling().
spell = SpellChecker()
def download_audio_from_url(url, *, timeout=60):
    """Download the media referenced by *url* and return its raw bytes.

    If *url* contains the substring ``"share"`` it is treated as an HTML
    share page: the page is fetched and the ``src`` of its first ``<video>``
    tag is downloaded instead. Otherwise *url* itself is downloaded.

    Args:
        url: Direct media URL, or a share-page URL containing "share".
        timeout: Seconds ``requests`` may wait on connect/read for each HTTP
            request before giving up.

    Returns:
        The body (bytes) of the media response.

    Raises:
        ValueError: If a share page has no ``<video src=...>`` tag.
        requests.HTTPError: If either request returns a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
    """
    try:
        # NOTE(review): substring heuristic; any URL containing "share"
        # (including direct media URLs) takes the HTML-scraping path.
        if "share" in url:
            print("Processing shareable link...")
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            soup = BeautifulSoup(response.content, 'html.parser')
            video_tag = soup.find('video')
            if video_tag and 'src' in video_tag.attrs:
                # The src may be relative to the share page; urljoin leaves
                # absolute URLs unchanged.
                video_url = urljoin(url, video_tag['src'])
                print(f"Extracted video URL: {video_url}")
            else:
                raise ValueError("Direct video URL not found in the shareable link.")
        else:
            video_url = url
        print(f"Downloading video from URL: {video_url}")
        response = requests.get(video_url, timeout=timeout)
        # Without this, an HTML error page would be returned as media bytes.
        response.raise_for_status()
        audio_bytes = response.content
        print(f"Successfully downloaded {len(audio_bytes)} bytes of data")
        return audio_bytes
    except Exception as e:
        print(f"Error in download_audio_from_url: {str(e)}")
        raise
# Splits one whitespace-free token into (leading punctuation, core, trailing
# punctuation) so only the word itself is sent to the spell checker.
_TOKEN_RE = re.compile(r"(\W*)(.*?)(\W*)", re.DOTALL)


def _match_case(original, corrected):
    """Re-apply *original*'s capitalisation pattern to *corrected*."""
    if corrected == original:
        return corrected
    if len(original) > 1 and original.isupper():
        return corrected.upper()
    if original[:1].isupper():
        return corrected[:1].upper() + corrected[1:]
    return corrected


def correct_spelling(text, checker=None):
    """Spell-correct each whitespace-separated word of *text*.

    Leading/trailing punctuation is split off before the lookup and
    re-attached afterwards, so sentence punctuation survives (format_transcript
    splits on '.'). The original word's capitalisation is re-applied to the
    correction. When the checker returns a falsy value the word is kept as-is.
    Runs of whitespace in *text* are collapsed to single spaces.

    Args:
        text: Text to correct.
        checker: Object with a ``correction(word)`` method; defaults to the
            module-level ``spell`` checker.

    Returns:
        The corrected text.
    """
    if checker is None:
        checker = spell
    corrected_words = []
    for token in text.split():
        prefix, core, suffix = _TOKEN_RE.fullmatch(token).groups()
        if not core:
            # Pure punctuation (e.g. "...", "-"): nothing to correct.
            corrected_words.append(token)
            continue
        fixed = checker.correction(core) or core
        corrected_words.append(prefix + _match_case(core, fixed) + suffix)
    return ' '.join(corrected_words)
def format_transcript(transcript):
    """Split *transcript* into '.'-terminated sentences and group speaker turns.

    A sentence of the form ``"Speaker: text"`` is emitted as
    ``"\\n\\nSpeaker:text."`` when the speaker differs from the last labelled
    sentence; later sentences by the same speaker are emitted without the
    label. Sentences without a colon are emitted unchanged plus a '.'.
    Empty or whitespace-only fragments (e.g. after a trailing '.') are
    skipped. Output sentences are joined with single spaces.

    Caveats: every '.' is treated as a sentence end (so "3.5" or "Dr." are
    split), and any text before the first ':' is taken as a speaker label.
    """
    formatted_transcript = []
    current_speaker = None
    for sentence in transcript.split('.'):
        sentence = sentence.strip()
        if not sentence:
            # Avoid emitting a lone "." for empty fragments.
            continue
        if ':' in sentence:
            speaker, content = sentence.split(':', 1)
            # Compare stripped labels so " A" and "A" count as the same speaker.
            speaker = speaker.strip()
            content = content.strip()
            if speaker != current_speaker:
                formatted_transcript.append(f"\n\n{speaker}:{content}.")
                current_speaker = speaker
            else:
                formatted_transcript.append(f"{content}.")
        else:
            formatted_transcript.append(f"{sentence}.")
    return ' '.join(formatted_transcript)
# Whisper's feature extractor raises ValueError for any other sampling rate.
WHISPER_SAMPLE_RATE = 16000
# Whisper's feature extractor pads/truncates each input to this many seconds,
# so longer recordings must be split into windows of at most this length.
WHISPER_CHUNK_SECONDS = 30


def _to_mono_16k(audio, sample_rate):
    """Return *audio* as a 1-D float32 array resampled to WHISPER_SAMPLE_RATE."""
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel files;
        # average the channels into a single mono track.
        audio = audio.mean(axis=1)
    if sample_rate != WHISPER_SAMPLE_RATE and audio.size > 0:
        # Linear-interpolation resampling (numpy only, no anti-aliasing filter).
        # TODO: a polyphase resampler would give better quality when downsampling.
        n_out = max(1, int(round(audio.shape[0] * WHISPER_SAMPLE_RATE / sample_rate)))
        src_times = np.arange(audio.shape[0]) / sample_rate
        dst_times = np.arange(n_out) / WHISPER_SAMPLE_RATE
        audio = np.interp(dst_times, src_times, audio).astype(np.float32)
    return audio


def transcribe_audio(audio_file, chunk_seconds=WHISPER_CHUNK_SECONDS):
    """Transcribe the audio file at *audio_file* with the global Whisper model.

    The audio is down-mixed to mono, resampled to 16 kHz and transcribed in
    consecutive windows of *chunk_seconds* (values above 30 are truncated by
    the feature extractor). The per-window texts are stripped and joined with
    single spaces. Words straddling a window boundary may be mis-transcribed.

    Args:
        audio_file: Path (or file-like object) readable by ``soundfile.read``.
        chunk_seconds: Window length in seconds.

    Returns:
        The transcription text.

    Raises:
        Any exception from reading, feature extraction or generation is
        printed and re-raised.
    """
    try:
        audio_input, sample_rate = sf.read(audio_file)
        audio_input = _to_mono_16k(audio_input, sample_rate)
        chunk_len = max(1, int(chunk_seconds * WHISPER_SAMPLE_RATE))
        # Keep at least one (possibly empty) chunk so empty files are still
        # passed through the model as before.
        chunks = [
            audio_input[start:start + chunk_len]
            for start in range(0, len(audio_input), chunk_len)
        ] or [audio_input]
        pieces = []
        for chunk in chunks:
            input_features = processor(
                chunk, sampling_rate=WHISPER_SAMPLE_RATE, return_tensors="pt"
            ).input_features.to(device)
            predicted_ids = model.generate(input_features)
            text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            pieces.append(text.strip())
        return ' '.join(piece for piece in pieces if piece)
    except Exception as e:
        print(f"Error in transcribe_audio: {str(e)}")
        raise
def transcribe_video(url):
    """Download, transcribe, spell-correct and format the media at *url*.

    Returns:
        The formatted transcript, or, if any step raises, a string starting
        with "An error occurred: " (shown in the UI instead of raising).
    """
    try:
        print(f"Attempting to download audio from URL: {url}")
        audio_bytes = download_audio_from_url(url)
        print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")
        temp_audio_path = None
        try:
            # NOTE(review): the payload is saved with a .wav suffix whatever its
            # real format; if the URL serves a video container (e.g. MP4),
            # soundfile may be unable to read it. Confirm the expected input,
            # or extract the audio track first.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
                temp_audio_path = temp_audio.name
                temp_audio.write(audio_bytes)
            print("Starting audio transcription...")
            transcript = transcribe_audio(temp_audio_path)
            print("Transcription completed successfully")
        finally:
            # delete=False means nobody else removes the file; clean it up
            # even when writing or transcription fails.
            if temp_audio_path is not None:
                os.unlink(temp_audio_path)
        # Apply spelling correction and formatting
        transcript = correct_spelling(transcript)
        transcript = format_transcript(transcript)
        return transcript
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        return error_message
def download_transcript(transcript):
    """Write *transcript* to a new UTF-8 ``.txt`` temp file and return its path.

    The file is created with ``delete=False`` and is not removed here.
    A ``None`` transcript is written as an empty file.
    """
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # cannot represent every character a transcript may contain.
    with tempfile.NamedTemporaryFile(
        mode='w', encoding='utf-8', delete=False, suffix='.txt'
    ) as temp_file:
        temp_file.write(transcript or "")
        temp_file_path = temp_file.name
    return temp_file_path
# ---------------------------------------------------------------------------
# Gradio user interface
# ---------------------------------------------------------------------------
# Components are created in display order inside the Blocks context, so the
# order of the statements below determines the page layout.
with gr.Blocks(title="Video Transcription") as demo:
    gr.Markdown("# Video Transcription")

    # Input: the URL of the video (or share page) to transcribe.
    video_url = gr.Textbox(label="Video URL")
    transcribe_button = gr.Button("Transcribe")

    # Output: the formatted transcript (or an error message).
    transcript_output = gr.Textbox(label="Transcript", lines=20)

    # Export: save the transcript box's current text to a .txt file.
    download_button = gr.Button("Download Transcript")
    download_link = gr.File(label="Download Transcript")

    # Event wiring.
    transcribe_button.click(
        fn=transcribe_video,
        inputs=video_url,
        outputs=transcript_output,
    )
    download_button.click(
        fn=download_transcript,
        inputs=transcript_output,
        outputs=download_link,
    )

# NOTE: this starts the server whenever the module is executed, including on import.
demo.launch()