File size: 4,035 Bytes
1100e65
249a3c0
fce37ea
b09f327
 
53bdf99
b09f327
af532e7
53bdf99
 
 
 
 
8af57a0
6575bf4
077c90e
170241f
249a3c0
 
 
a18a113
249a3c0
261d49a
fce37ea
26cf8bb
fce37ea
 
 
26cf8bb
54c226c
17ca647
da7b836
b09f327
53bdf99
 
 
 
 
fce37ea
53bdf99
 
 
 
 
26cf8bb
 
 
 
53bdf99
 
 
 
 
 
 
 
26cf8bb
53bdf99
 
 
 
df42ab3
 
53bdf99
 
 
836768f
81f702f
836768f
 
26cf8bb
 
53bdf99
 
 
 
26cf8bb
 
 
 
53bdf99
da7b836
53bdf99
836768f
53bdf99
 
fce37ea
26cf8bb
53bdf99
26cf8bb
53bdf99
26cf8bb
81f702f
26cf8bb
 
 
 
 
 
 
 
 
 
 
 
81f702f
6575bf4
 
da7b836
6575bf4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import io
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM
import requests
from bs4 import BeautifulSoup
import tempfile    
import os
from pydub import AudioSegment
import dash
from dash import dcc, html, Input, Output, State
import dash_bootstrap_components as dbc
from dash.exceptions import PreventUpdate
import threading
from pytube import YouTube

print("Script started")

# Pick the GPU when PyTorch can see one; every model below is moved there.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Speech-to-text: Whisper "small" checkpoint (processor + model).
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)
whisper_model = whisper_model.to(device)

# Text generation: Qwen2.5 3B Instruct (tokenizer + causal LM).
# NOTE(review): trust_remote_code=True lets from_pretrained execute Python code
# shipped in the Hub repo; it is likely unnecessary for Qwen2.5 on a recent
# transformers release -- confirm before removing.
qwen_model_name = "Qwen/Qwen2.5-3B-Instruct"
qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_name, trust_remote_code=True)
qwen_model = qwen_model.to(device)

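# NOTE(review): the helper functions that belonged here (including
# transcribe_video(url), which update_transcription calls below) do not appear
# to be defined in this file; restore them, or every transcription will fail.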

app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])


def _build_layout():
    """Return the page: a title and one card with the URL form, results and download control."""
    download_controls = html.Div([
        # Starts hidden; the transcription callback toggles its display style.
        dbc.Button(
            "Download Transcript",
            id="download-button",
            color="secondary",
            className="mt-3",
            style={'display': 'none'},
        ),
        dcc.Download(id="download-transcript"),
    ])

    card_body = dbc.CardBody([
        dbc.Input(id="video-url", type="text", placeholder="Enter video URL"),
        dbc.Button("Transcribe", id="transcribe-button", color="primary", className="mt-3"),
        # Wrapping the output in a Spinner shows a loading indicator while it updates.
        dbc.Spinner(html.Div(id="transcription-output", className="mt-3")),
        download_controls,
    ])

    title = html.H1("Video Transcription with Speaker Separation", className="text-center mb-4")
    main_column = dbc.Col([title, dbc.Card([card_body])], width=12)
    return dbc.Container([dbc.Row([main_column])], fluid=True)


app.layout = _build_layout()

@app.callback(
    Output("transcription-output", "children"),
    Output("download-button", "style"),
    Input("transcribe-button", "n_clicks"),
    State("video-url", "value"),
    prevent_initial_call=True
)
def update_transcription(n_clicks, url):
    """Transcribe the video at *url* and render the result.

    ``transcribe_video`` runs in a worker thread so this callback can give up
    after 10 minutes.

    Args:
        n_clicks: Click count of the "Transcribe" button (used only as the trigger).
        url: Text from the "video-url" input.

    Returns:
        A ``(children, style)`` pair: the content for "transcription-output"
        and the style dict for "download-button", which is shown only when a
        transcript was produced.

    Raises:
        PreventUpdate: if *url* is empty or None.
    """
    if not url:
        raise PreventUpdate

    # threading.Thread discards its target's return value, so the worker
    # records its outcome here and the callback reads it after join().
    outcome = {}

    def transcribe():
        try:
            outcome["transcript"] = transcribe_video(url)
        except Exception as e:
            import traceback
            outcome["error"] = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"

    # Run transcription in a separate thread. daemon=True so a job that
    # outlives the timeout cannot keep the process alive at shutdown.
    thread = threading.Thread(target=transcribe, daemon=True)
    thread.start()
    thread.join(timeout=600)  # 10 minutes timeout

    if thread.is_alive():
        # Python threads cannot be cancelled; the worker keeps running in the background.
        return "Transcription timed out after 10 minutes", {'display': 'none'}

    if "error" in outcome:
        return outcome["error"], {'display': 'none'}

    transcript = outcome.get("transcript")
    # Also treat error text returned (rather than raised) by transcribe_video as a failure.
    if not transcript or transcript.startswith("An error occurred"):
        return transcript or "Transcription failed", {'display': 'none'}

    return dbc.Card([
        dbc.CardBody([
            html.H5("Transcription Result with Speaker Separation"),
            html.Pre(transcript, style={"white-space": "pre-wrap", "word-wrap": "break-word"})
        ])
    ]), {'display': 'block'}

@app.callback(
    Output("download-transcript", "data"),
    Input("download-button", "n_clicks"),
    State("transcription-output", "children"),
    prevent_initial_call=True
)
def download_transcript(n_clicks, transcription_output):
    """Send the displayed transcript to the browser as ``transcript.txt``.

    Args:
        n_clicks: Click count of the download button (used only as the trigger).
        transcription_output: Serialized children of "transcription-output".
            On success, update_transcription renders Card > CardBody > [H5, Pre],
            and the transcript text is the Pre's children.

    Returns:
        A ``dcc.Download`` payload dict with ``content`` and ``filename``.

    Raises:
        PreventUpdate: when nothing is displayed, or the output is not a
            rendered transcript card (e.g. a plain timeout/error message string).
    """
    if not transcription_output:
        raise PreventUpdate

    # Walk Card -> CardBody (first child) -> Pre (second child) -> its text.
    # Any other shape means there is no transcript to download, so skip the
    # update instead of crashing the callback.
    try:
        card_body = transcription_output['props']['children'][0]
        pre = card_body['props']['children'][1]
        transcript = pre['props']['children']
    except (KeyError, IndexError, TypeError):
        raise PreventUpdate from None

    return dict(content=transcript, filename="transcript.txt")

if __name__ == '__main__':
    print("Starting the Dash application...")
    # Debug mode enables the Flask/Werkzeug interactive debugger, which allows
    # arbitrary code execution from the browser. Never enable it by default
    # while listening on 0.0.0.0; opt in with DASH_DEBUG=1 for local work.
    # Debug mode also starts the reloader, which re-runs this script (loading
    # both models a second time) in a child process.
    debug_mode = os.environ.get("DASH_DEBUG", "0") == "1"
    # Newer Dash releases provide app.run and drop run_server; fall back for older ones.
    run_app = getattr(app, "run", None) or app.run_server
    run_app(debug=debug_mode, host='0.0.0.0', port=7860)
    print("Dash application has finished running.")