|
import io |
|
import torch |
|
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM |
|
import requests |
|
from bs4 import BeautifulSoup |
|
import tempfile |
|
import os |
|
from pydub import AudioSegment |
|
import dash |
|
from dash import dcc, html, Input, Output, State |
|
import dash_bootstrap_components as dbc |
|
from dash.exceptions import PreventUpdate |
|
import threading |
|
from pytube import YouTube |
|
|
|
print("Script started")

# Pick the compute device once; both models below are moved onto it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# --- Speech-to-text model (Whisper) -------------------------------------
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)
whisper_model = whisper_model.to(device)

# --- Instruction-tuned language model (Qwen) -----------------------------
# NOTE(review): trust_remote_code=True allows the hub repo to execute Python code
# at load time; Qwen2.5 is supported natively by recent transformers releases,
# so this flag may be droppable — confirm against the installed version.
qwen_model_name = "Qwen/Qwen2.5-3B-Instruct"
qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_name, trust_remote_code=True)
qwen_model = qwen_model.to(device)
|
|
|
|
|
|
|
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])


def _build_layout():
    """Assemble the single-page UI.

    The page contains a URL box, a Transcribe button, a spinner-wrapped output
    area, and a download button that starts hidden. The component ids are the
    ones the app's callbacks read from and write to, so keep them in sync.
    """
    # Download button starts hidden; a callback updates its ``style`` to show it.
    download_controls = html.Div([
        dbc.Button(
            "Download Transcript",
            id="download-button",
            color="secondary",
            className="mt-3",
            style={'display': 'none'},
        ),
        dcc.Download(id="download-transcript"),
    ])

    form_body = dbc.CardBody([
        dbc.Input(id="video-url", type="text", placeholder="Enter video URL"),
        dbc.Button("Transcribe", id="transcribe-button", color="primary", className="mt-3"),
        # Spinner shows a loading indicator while the wrapped output div is being updated.
        dbc.Spinner(html.Div(id="transcription-output", className="mt-3")),
        download_controls,
    ])

    heading = html.H1("Video Transcription with Speaker Separation", className="text-center mb-4")
    main_column = dbc.Col([heading, dbc.Card([form_body])], width=12)

    return dbc.Container([dbc.Row([main_column])], fluid=True)


app.layout = _build_layout()
|
|
|
@app.callback(
    Output("transcription-output", "children"),
    Output("download-button", "style"),
    Input("transcribe-button", "n_clicks"),
    State("video-url", "value"),
    prevent_initial_call=True,
)
def update_transcription(n_clicks, url):
    """Transcribe the video at *url* and render the result.

    ``transcribe_video`` runs on a worker thread, and the callback waits up to
    600 seconds (10 minutes) for it to finish.

    Args:
        n_clicks: Click count of the Transcribe button; used only as the trigger.
        url: Current text of the ``video-url`` input.

    Returns:
        A ``(children, style)`` pair for ``transcription-output`` and the
        download button. On success the transcript is wrapped in a card and the
        button is shown. On error or timeout a plain message is returned and the
        button is hidden.

    Raises:
        PreventUpdate: If the URL box is empty.
    """
    if not url:
        raise PreventUpdate

    # threading.Thread discards its target's return value, so the worker records
    # its outcome here. A fresh dict per call keeps concurrent callbacks separate.
    outcome = {}

    def transcribe():
        try:
            outcome["result"] = transcribe_video(url)
        except Exception as e:
            import traceback

            # NOTE(review): this sends the full server traceback to the browser,
            # which is useful while debugging but leaks internals in production.
            outcome["result"] = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"

    # daemon=True: a worker still running after the timeout cannot be cancelled,
    # but at least it will not keep the interpreter alive at shutdown.
    thread = threading.Thread(target=transcribe, daemon=True)
    thread.start()
    thread.join(timeout=600)

    if thread.is_alive():
        return "Transcription timed out after 10 minutes", {'display': 'none'}

    # Only reached without a result if the worker died from a BaseException.
    transcript = outcome.get("result", "Transcription failed")

    is_error = isinstance(transcript, str) and transcript.startswith("An error occurred")
    if transcript and not is_error:
        return dbc.Card([
            dbc.CardBody([
                html.H5("Transcription Result with Speaker Separation"),
                html.Pre(transcript, style={"white-space": "pre-wrap", "word-wrap": "break-word"}),
            ])
        ]), {'display': 'block'}

    return transcript, {'display': 'none'}
|
|
|
@app.callback(
    Output("download-transcript", "data"),
    Input("download-button", "n_clicks"),
    State("transcription-output", "children"),
    prevent_initial_call=True,
)
def download_transcript(n_clicks, transcription_output):
    """Send the currently displayed transcript to the browser as ``transcript.txt``.

    Args:
        n_clicks: Click count of the download button; used only as the trigger.
        transcription_output: Serialized ``children`` of ``transcription-output``.
            On success this is the result card built by the transcription callback
            (Card > CardBody > [H5, Pre(transcript)]), delivered as nested dicts.

    Returns:
        A ``dcc.Download`` payload dict with ``content`` and ``filename`` keys.

    Raises:
        PreventUpdate: If there is no output, or the output is not the result
            card (e.g. a plain error/timeout message string).
    """
    if not transcription_output:
        raise PreventUpdate

    try:
        # Card.children[0] is the CardBody; CardBody.children[1] is the Pre
        # whose children hold the transcript text.
        transcript = transcription_output['props']['children'][0]['props']['children'][1]['props']['children']
    except (KeyError, IndexError, TypeError):
        # Output area holds something other than the result card (a string
        # message indexed by 'props' raises TypeError); nothing to download.
        raise PreventUpdate

    return dict(content=transcript, filename="transcript.txt")
|
|
|
if __name__ == '__main__':
    print("Starting the Dash application...")
    # ``run_server`` is deprecated in Dash 2.x and removed in Dash 3, so prefer
    # ``run``. Fall back to ``run_server`` for older releases that lack ``run``.
    run = getattr(app, "run", None) or app.run_server
    # NOTE(review): debug=True while binding 0.0.0.0 exposes the interactive
    # Werkzeug debugger to the network, and debug mode likely enables the
    # reloader, which re-executes this script (reloading both models). Consider
    # making debug opt-in via an environment variable.
    run(debug=True, host='0.0.0.0', port=7860)
    print("Dash application has finished running.")