# bluenevus's picture
# Update app.py
# aa497bf verified
# raw
# history blame
# 7.98 kB
import io
import logging
import os
import tempfile
import threading
from urllib.parse import urljoin

import dash
import dash_bootstrap_components as dbc
import librosa
import numpy as np
import requests
import torch
from bs4 import BeautifulSoup
from dash import dcc, html, Input, Output, State
from dash.exceptions import PreventUpdate
from pydub import AudioSegment
from pytube import YouTube
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Configure logging for the whole process and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
print("Script started")

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Load the Whisper checkpoint and its processor once, at import time; the
# transcription functions below read these module-level objects.
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = (
    WhisperForConditionalGeneration
    .from_pretrained(whisper_model_name)
    .to(device)
)
def download_audio_from_url(url, timeout=60):
    """Download the media behind ``url`` and return it as raw bytes.

    Three kinds of URL are recognised:

    * YouTube links (``youtube.com`` / ``youtu.be``): pytube fetches the
      first audio-only stream.
    * URLs containing ``"share"``: the page is fetched, and the ``src`` of the
      first ``<video>`` tag (or of its first ``<source>`` child) is resolved
      against the page URL and downloaded.
    * Anything else is fetched directly with an HTTP GET.

    Args:
        url: Address of the video or audio to fetch.
        timeout: Connect/read timeout in seconds for each ``requests`` call
            (not a cap on the total download time).

    Returns:
        bytes: The downloaded payload, in whatever container format the
        source served.

    Raises:
        ValueError: If no audio stream (YouTube) or no video URL (share page)
            can be found.
        requests.HTTPError: If an HTTP request returns an error status.
        Any other exception from pytube/requests propagates; every failure is
        logged before being re-raised.
    """
    try:
        if "youtube.com" in url or "youtu.be" in url:
            logger.info("Processing YouTube URL...")
            yt = YouTube(url)
            audio_stream = yt.streams.filter(only_audio=True).first()
            if audio_stream is None:
                raise ValueError("No audio-only stream found for this YouTube video.")
            # Stream.download() treats output_path as a *directory* and returns
            # the path of the file it wrote, so it must not be handed a file
            # path. The temporary directory is removed even if reading fails.
            with tempfile.TemporaryDirectory() as temp_dir:
                downloaded_path = audio_stream.download(output_path=temp_dir)
                with open(downloaded_path, "rb") as audio_file:
                    audio_bytes = audio_file.read()
        elif "share" in url:
            logger.info("Processing shareable link...")
            page = requests.get(url, timeout=timeout)
            page.raise_for_status()
            soup = BeautifulSoup(page.content, 'html.parser')
            video_tag = soup.find('video')
            video_src = video_tag.get('src') if video_tag is not None else None
            if not video_src and video_tag is not None:
                # Pages often put the media URL on a nested <source> element.
                source_tag = video_tag.find('source')
                if source_tag is not None:
                    video_src = source_tag.get('src')
            if not video_src:
                raise ValueError("Direct video URL not found in the shareable link.")
            # src may be relative to the page (e.g. "/media/clip.mp4");
            # urljoin leaves absolute URLs unchanged.
            video_url = urljoin(url, video_src)
            logger.info(f"Extracted video URL: {video_url}")
            response = requests.get(video_url, timeout=timeout)
            response.raise_for_status()
            audio_bytes = response.content
        else:
            logger.info(f"Downloading video from URL: {url}")
            response = requests.get(url, timeout=timeout)
            # Without this, an HTML error page would be returned as "audio".
            response.raise_for_status()
            audio_bytes = response.content
        logger.info(f"Successfully downloaded {len(audio_bytes)} bytes of data")
        return audio_bytes
    except Exception as e:
        logger.exception(f"Error in download_audio_from_url: {str(e)}")
        raise
def transcribe_audio(audio_file):
    """Transcribe an audio file with the module-level Whisper model.

    The file is loaded with librosa, resampled to 16 kHz (mono, librosa's
    default), and fed to Whisper in consecutive 30-second windows. The
    per-window texts are joined with single spaces.

    Args:
        audio_file: Path (or file-like object) accepted by ``librosa.load``.

    Returns:
        str: The full transcription; empty if the audio has no samples.

    Raises:
        Any exception from loading or inference, after it has been logged.
    """
    try:
        logger.info("Loading audio file...")
        target_sr = 16000  # Whisper's feature extractor expects 16 kHz audio.
        audio_input, sr = librosa.load(audio_file, sr=target_sr)
        audio_input = audio_input.astype(np.float32)
        total_samples = len(audio_input)
        logger.info(f"Audio duration: {total_samples / sr:.2f} seconds")
        # Whisper's input window is 30 s; its feature extractor pads or
        # truncates every chunk to that length by default.
        chunk_length = 30 * sr
        transcriptions = []
        logger.info("Starting transcription...")
        # Windows are back-to-back with no overlap. The chunk texts are simply
        # concatenated, so any overlapping audio would be transcribed, and
        # appear in the output, twice.
        for start in range(0, total_samples, chunk_length):
            end = min(start + chunk_length, total_samples)
            chunk = audio_input[start:end]
            input_features = whisper_processor(
                chunk, sampling_rate=sr, return_tensors="pt"
            ).input_features.to(device)
            predicted_ids = whisper_model.generate(input_features)
            decoded = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
            # Strip surrounding whitespace and drop empty segments (e.g. from
            # silence) so the join below does not produce doubled spaces.
            transcriptions.extend(text.strip() for text in decoded if text.strip())
            logger.info(f"Processed {start / sr:.2f} to {end / sr:.2f} seconds")
        full_transcription = " ".join(transcriptions)
        logger.info(f"Transcription complete. Full transcription length: {len(full_transcription)} characters")
        return full_transcription
    except Exception as e:
        logger.exception(f"Error in transcribe_audio: {str(e)}")
        raise
def transcribe_video(url):
    """Download the media at ``url`` and return its transcript.

    The downloaded bytes are decoded with pydub, written to a temporary WAV
    file, and passed to ``transcribe_audio``. This function does not raise:
    on any failure it logs the error and returns a message beginning with
    ``"An error occurred in transcribe_video:"`` instead of a transcript.
    A transcript shorter than 10 characters is treated as a failure.

    Args:
        url: Video/audio URL understood by ``download_audio_from_url``.

    Returns:
        str: The transcript, or the error message described above.
    """
    temp_path = None
    try:
        logger.info(f"Attempting to download audio from URL: {url}")
        audio_bytes = download_audio_from_url(url)
        logger.info(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")
        # Only reserve a unique file name here. The handle is closed before
        # pydub writes to the path, because reopening a still-open
        # NamedTemporaryFile is not possible on every platform (Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            temp_path = temp_audio.name
        AudioSegment.from_file(io.BytesIO(audio_bytes)).export(temp_path, format="wav")
        transcript = transcribe_audio(temp_path)
        if len(transcript) < 10:
            raise ValueError("Transcription too short, possibly failed")
        logger.info(f"Transcription successful. Length: {len(transcript)} characters")
        logger.info(f"First 100 characters of transcript: {transcript[:100]}...")
        return transcript
    except Exception as e:
        error_message = f"An error occurred in transcribe_video: {str(e)}"
        logger.exception(error_message)
        return error_message
    finally:
        # Remove the intermediate WAV whether or not decoding/transcription
        # succeeded.
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not delete temporary file {temp_path}: {cleanup_error}")
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Controls: the URL box, the trigger button, and the spinner-wrapped area in
# which the transcription callback renders its result.
_url_input = dbc.Input(id="video-url", type="text", placeholder="Enter video URL")
_transcribe_button = dbc.Button(
    "Transcribe", id="transcribe-button", color="primary", className="mt-3"
)
_output_area = dbc.Spinner(html.Div(id="transcription-output", className="mt-3"))

# The download button starts hidden (display: none); the transcription
# callback outputs to its style to reveal it once a transcript exists.
_download_section = html.Div(
    [
        dbc.Button(
            "Download Transcript",
            id="download-button",
            color="secondary",
            className="mt-3",
            style={"display": "none"},
        ),
        dcc.Download(id="download-transcript"),
    ]
)

_main_column = dbc.Col(
    [
        html.H1("Video Transcription", className="text-center mb-4"),
        html.Div("If you can see this, the app is working!", className="text-center mb-4"),
        dbc.Card(
            [
                dbc.CardBody(
                    [_url_input, _transcribe_button, _output_area, _download_section]
                )
            ]
        ),
    ],
    width=12,
)

app.layout = dbc.Container([dbc.Row([_main_column])], fluid=True)
@app.callback(
    Output("transcription-output", "children"),
    Output("download-button", "style"),
    Input("transcribe-button", "n_clicks"),
    State("video-url", "value"),
    prevent_initial_call=True
)
def update_transcription(n_clicks, url):
    """Transcribe the video at ``url`` and render the result.

    Triggered by the "Transcribe" button. The work runs on a worker thread
    so that the callback can stop waiting after a fixed timeout.

    Args:
        n_clicks: Click count of the transcribe button (used only as trigger).
        url: Current value of the URL input box.

    Returns:
        A ``(children, style)`` pair for the output area and the download
        button. On success this is a result card and ``{'display': 'block'}``.
        On failure or timeout it is a plain message and ``{'display': 'none'}``.

    Raises:
        PreventUpdate: If the URL box is empty.
    """
    if not url:
        raise PreventUpdate

    timeout_seconds = 600  # 10 minutes
    # threading.Thread discards its target's return value, so the worker
    # hands its result back through this dict instead.
    outcome = {}

    def transcribe():
        try:
            transcript = transcribe_video(url)
            logger.info(f"Transcription completed. Result length: {len(transcript)} characters")
            outcome["transcript"] = transcript
        except Exception as e:
            logger.exception("Error in transcription:")
            outcome["transcript"] = f"An error occurred: {str(e)}"

    # Python threads cannot be killed, so a run that overruns the timeout
    # keeps going in the background; daemon=True keeps it from blocking
    # interpreter shutdown.
    thread = threading.Thread(target=transcribe, daemon=True)
    thread.start()
    thread.join(timeout=timeout_seconds)
    if thread.is_alive():
        logger.warning("Transcription timed out after 10 minutes")
        return "Transcription timed out after 10 minutes", {'display': 'none'}

    transcript = outcome.get("transcript", "Transcription failed")
    logger.info(f"Final transcript length: {len(transcript)} characters")
    # Both this callback's own error text and transcribe_video's error
    # messages start with "An error occurred".
    if transcript and not transcript.startswith("An error occurred"):
        logger.info("Transcription successful, returning result")
        return dbc.Card([
            dbc.CardBody([
                html.H5("Transcription Result"),
                # Dash style dictionaries use camelCase CSS property names.
                html.Pre(transcript, style={"whiteSpace": "pre-wrap", "wordWrap": "break-word"})
            ])
        ]), {'display': 'block'}
    logger.error(f"Transcription failed: {transcript}")
    return transcript, {'display': 'none'}
@app.callback(
    Output("download-transcript", "data"),
    Input("download-button", "n_clicks"),
    State("transcription-output", "children"),
    prevent_initial_call=True
)
def download_transcript(n_clicks, transcription_output):
    """Send the displayed transcript to the browser as ``transcript.txt``.

    ``transcription_output`` is the serialized children of the
    ``transcription-output`` div. After a successful run that is the result
    card, Card -> CardBody -> [H5, Pre], and the transcript is the Pre's
    text. After a failed run it is a plain message string.

    Args:
        n_clicks: Click count of the download button (used only as trigger).
        transcription_output: Serialized children of the output div.

    Returns:
        dict: ``content``/``filename`` payload for ``dcc.Download``.

    Raises:
        PreventUpdate: If there is no output yet, or it is not a result card
            (e.g. an error or timeout message), so nothing is downloaded.
    """
    if not transcription_output:
        raise PreventUpdate
    try:
        # Card -> CardBody (children[0]) -> Pre (children[1]) -> text.
        transcript = transcription_output['props']['children'][0]['props']['children'][1]['props']['children']
    except (KeyError, IndexError, TypeError):
        # Output is not the result card (e.g. a plain message string).
        raise PreventUpdate from None
    if not isinstance(transcript, str):
        raise PreventUpdate
    return dict(content=transcript, filename="transcript.txt")
if __name__ == '__main__':
    # Debug mode (dev tools, Werkzeug debugger, auto-reloader) is opt-in via
    # DASH_DEBUG=1/true/yes. The server listens on all interfaces, where an
    # exposed debugger is a security risk. The reloader also re-runs this
    # script in a child process, which would load the Whisper model twice.
    debug = os.environ.get("DASH_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    logger.info("Starting the Dash application...")
    app.run(debug=debug, host='0.0.0.0', port=7860)
    logger.info("Dash application has finished running.")