# app.py — Hugging Face Space file by bluenevus
# Commit da7b836 (verified): "Update app.py"; file size 4.04 kB.
import io
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM
import requests
from bs4 import BeautifulSoup
import tempfile
import os
from pydub import AudioSegment
import dash
from dash import dcc, html, Input, Output, State
import dash_bootstrap_components as dbc
from dash.exceptions import PreventUpdate
import threading
from pytube import YouTube
print("Script started")

# Run on the GPU when PyTorch reports CUDA support; otherwise use the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# Speech recognition: Whisper processor and model, with the model moved to `device`.
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)
whisper_model = whisper_model.to(device)  # nn.Module.to returns the module itself

# Text generation: Qwen instruct tokenizer and causal LM, with the model moved to `device`.
qwen_model_name = "Qwen/Qwen2.5-3B-Instruct"
qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_name, trust_remote_code=True)
qwen_model = qwen_model.to(device)
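# ... (existing helper functions go here, unchanged. They must define
# transcribe_video(url), which the transcription callback below calls;
# it is not defined in this excerpt.)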
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])


def _build_layout():
    """Assemble the single-page layout.

    A title, then one card holding the URL input, the Transcribe button, a
    spinner-wrapped output area, and a download button that starts hidden
    (``display: none``) alongside its ``dcc.Download`` target.
    """
    title = html.H1("Video Transcription with Speaker Separation", className="text-center mb-4")
    url_box = dbc.Input(id="video-url", type="text", placeholder="Enter video URL")
    run_button = dbc.Button("Transcribe", id="transcribe-button", color="primary", className="mt-3")
    output_area = dbc.Spinner(html.Div(id="transcription-output", className="mt-3"))
    download_button = dbc.Button(
        "Download Transcript",
        id="download-button",
        color="secondary",
        className="mt-3",
        style={'display': 'none'},
    )
    download_area = html.Div([download_button, dcc.Download(id="download-transcript")])

    card = dbc.Card([dbc.CardBody([url_box, run_button, output_area, download_area])])
    column = dbc.Col([title, card], width=12)
    return dbc.Container([dbc.Row([column])], fluid=True)


app.layout = _build_layout()
@app.callback(
    Output("transcription-output", "children"),
    Output("download-button", "style"),
    Input("transcribe-button", "n_clicks"),
    State("video-url", "value"),
    prevent_initial_call=True,
)
def update_transcription(n_clicks, url):
    """Transcribe the video at *url* and render the result.

    Args:
        n_clicks: Click count of the Transcribe button; only acts as the trigger.
        url: Text entered in the URL input box.

    Returns:
        A ``(children, style)`` pair for ``transcription-output`` and the
        download button. On success, ``children`` is a card whose
        ``html.Pre`` holds the transcript and the button is shown. On error,
        timeout, or an empty result, ``children`` is a plain message and the
        button is hidden.

    Raises:
        PreventUpdate: if the URL box is empty.
    """
    if not url:
        raise PreventUpdate

    # threading.Thread discards its target's return value, so the worker
    # stores its outcome in this dict for the caller to read after join().
    outcome = {}

    def transcribe():
        try:
            outcome["transcript"] = transcribe_video(url)
        except Exception as e:
            import traceback
            outcome["error"] = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"

    # Run transcription in a separate thread so the wait can be bounded.
    # A worker that outlives the timeout cannot be cancelled; daemon=True at
    # least keeps it from blocking interpreter shutdown.
    thread = threading.Thread(target=transcribe, daemon=True)
    thread.start()
    thread.join(timeout=600)  # 10 minutes timeout
    if thread.is_alive():
        return "Transcription timed out after 10 minutes", {'display': 'none'}

    if "error" in outcome:
        return outcome["error"], {'display': 'none'}

    transcript = outcome.get("transcript")
    if not transcript:
        return "Transcription failed", {'display': 'none'}

    # download_transcript reads the transcript from CardBody children[1]
    # (the Pre), so keep the H5-then-Pre order.
    return dbc.Card([
        dbc.CardBody([
            html.H5("Transcription Result with Speaker Separation"),
            html.Pre(transcript, style={"white-space": "pre-wrap", "word-wrap": "break-word"}),
        ])
    ]), {'display': 'block'}
@app.callback(
    Output("download-transcript", "data"),
    Input("download-button", "n_clicks"),
    State("transcription-output", "children"),
    prevent_initial_call=True,
)
def download_transcript(n_clicks, transcription_output):
    """Send the displayed transcript to the browser as ``transcript.txt``.

    Args:
        n_clicks: Click count of the download button; only acts as the trigger.
        transcription_output: Serialized ``children`` of the output div. For a
            successful run this is the result card, whose dict form nests as
            Card -> CardBody (children[0]) -> Pre (children[1]) -> transcript.

    Returns:
        A ``dcc.Download`` data dict with the transcript text and filename.

    Raises:
        PreventUpdate: if there is no output yet, or the output is not the
            result card (e.g. an error/timeout message string).
    """
    if not transcription_output:
        raise PreventUpdate
    try:
        card_body = transcription_output['props']['children'][0]
        pre = card_body['props']['children'][1]
        transcript = pre['props']['children']
    except (KeyError, IndexError, TypeError):
        # A plain string (error/timeout message) or any other shape has no
        # transcript to download; skip the update instead of raising.
        raise PreventUpdate from None
    return dict(content=transcript, filename="transcript.txt")
if __name__ == '__main__':
    print("Starting the Dash application...")
    # Uses app.run because app.run_server is deprecated and was removed in Dash 3.0.
    # use_reloader=False: with debug=True, Flask enables the Werkzeug reloader
    # by default, which re-executes this script in a child process and so
    # would load the Whisper and Qwen models a second time.
    # NOTE(review): debug=True while listening on 0.0.0.0 exposes the debugger
    # to the network; disable debug for any non-local deployment.
    app.run(debug=True, host='0.0.0.0', port=7860, use_reloader=False)
    print("Dash application has finished running.")