File size: 8,532 Bytes
1100e65
249a3c0
fce37ea
b09f327
 
53bdf99
b09f327
af532e7
53bdf99
 
 
 
 
8af57a0
6575bf4
077c90e
170241f
249a3c0
 
 
a18a113
249a3c0
fd1e8cf
fce37ea
26cf8bb
fce37ea
 
 
26cf8bb
60d3e8d
17ca647
b09f327
 
8af57a0
 
 
 
 
 
 
 
 
b09f327
 
 
 
 
 
 
 
 
8af57a0
 
b09f327
8af57a0
 
 
b09f327
 
 
 
 
 
 
81f702f
8369f51
 
81f702f
 
5adda7d
8369f51
26cf8bb
8369f51
26cf8bb
 
fce37ea
81f702f
 
 
8369f51
 
 
a18a113
fce37ea
26cf8bb
fce37ea
 
60d3e8d
fce37ea
 
 
 
 
 
 
 
 
 
60d3e8d
26cf8bb
 
fce37ea
 
 
 
 
26cf8bb
fce37ea
60d3e8d
81f702f
0cfb05e
 
 
 
 
249a3c0
81f702f
 
0cfb05e
81f702f
fce37ea
 
 
 
 
0cfb05e
 
256795b
 
b09f327
53bdf99
 
 
 
 
fce37ea
53bdf99
 
 
 
 
26cf8bb
 
 
 
53bdf99
 
 
 
 
 
 
 
26cf8bb
53bdf99
 
 
 
df42ab3
 
53bdf99
 
 
836768f
81f702f
836768f
 
26cf8bb
 
53bdf99
 
 
 
26cf8bb
 
 
 
53bdf99
b9e0aa5
53bdf99
836768f
53bdf99
 
fce37ea
26cf8bb
53bdf99
26cf8bb
53bdf99
26cf8bb
81f702f
26cf8bb
 
 
 
 
 
 
 
 
 
 
 
81f702f
6575bf4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import io
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM
import requests
from bs4 import BeautifulSoup
import tempfile    
import os
from pydub import AudioSegment
import dash
from dash import dcc, html, Input, Output, State
import dash_bootstrap_components as dbc
from dash.exceptions import PreventUpdate
import threading
from pytube import YouTube

# Startup banner: everything in this section runs once, when the module is imported.
print("Script started")

# Check if CUDA is available and set the device
# (both models below, and the tensors fed to them, are moved to this device).
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Whisper model and processor (speech-to-text; used by transcribe_audio).
whisper_model_name = "openai/whisper-small.en"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name).to(device)

# Load the Qwen model and tokenizer (text generation; used by separate_speakers).
# NOTE(review): the weights are requested as float16 even when `device` is "cpu" —
# confirm that half precision is acceptable on CPU-only hosts.
# NOTE(review): trust_remote_code=True is passed to both loaders — confirm it is
# actually required for this checkpoint.
qwen_model_name = "Qwen/Qwen2.5-3B-Instruct"
qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_name, trust_remote_code=True, torch_dtype=torch.float16).to(device)

def download_audio_from_url(url):
    """Download the media behind *url* and return its raw bytes.

    Three kinds of URL are handled:

    * YouTube links (``youtube.com`` / ``youtu.be``): the first audio-only
      stream is downloaded with pytube.
    * Links containing ``"share"``: the page is fetched and the media URL is
      taken from its first ``<video>`` tag (its ``src`` attribute, or the
      ``src`` of a nested ``<source>`` tag).
    * Anything else is treated as a direct link to the media file.

    Args:
        url: Address of the video or audio to download.

    Returns:
        The downloaded file content as ``bytes``.

    Raises:
        ValueError: If no audio stream or no video URL can be found.
        Exception: Any download error (including HTTP error statuses) is
            logged and re-raised unchanged.
    """
    from urllib.parse import urljoin

    # Socket timeout in seconds, so a stalled server cannot hang the caller forever.
    request_timeout = 60

    try:
        if "youtube.com" in url or "youtu.be" in url:
            print("Processing YouTube URL...")
            yt = YouTube(url)
            audio_stream = yt.streams.filter(only_audio=True).first()
            if audio_stream is None:
                raise ValueError("No audio-only stream found for the YouTube URL.")
            # pytube interprets output_path as a *directory* and returns the
            # path of the file it wrote, so download into a scratch directory
            # and read the file back from the returned path.
            with tempfile.TemporaryDirectory() as temp_dir:
                downloaded_path = audio_stream.download(output_path=temp_dir)
                with open(downloaded_path, "rb") as downloaded_file:
                    audio_bytes = downloaded_file.read()
        elif "share" in url:
            print("Processing shareable link...")
            response = requests.get(url, timeout=request_timeout)
            response.raise_for_status()
            soup = BeautifulSoup(response.content, 'html.parser')
            video_tag = soup.find('video')
            video_src = None
            if video_tag is not None:
                video_src = video_tag.get('src')
                if not video_src:
                    # Many pages put the URL on a nested <source> element.
                    source_tag = video_tag.find('source')
                    if source_tag is not None:
                        video_src = source_tag.get('src')
            if not video_src:
                raise ValueError("Direct video URL not found in the shareable link.")
            # The src attribute may be relative to the page it was found on.
            video_url = urljoin(url, video_src)
            print(f"Extracted video URL: {video_url}")
            response = requests.get(video_url, timeout=request_timeout)
            response.raise_for_status()
            audio_bytes = response.content
        else:
            print(f"Downloading video from URL: {url}")
            response = requests.get(url, timeout=request_timeout)
            # Fail loudly instead of handing an HTML error page to the decoder.
            response.raise_for_status()
            audio_bytes = response.content

        print(f"Successfully downloaded {len(audio_bytes)} bytes of data")
        return audio_bytes
    except Exception as e:
        print(f"Error in download_audio_from_url: {str(e)}")
        raise

def transcribe_audio(audio_file):
    """Transcribe an audio file to text with the module-level Whisper model.

    The audio is converted to mono, 16 kHz, 16-bit samples, scaled to the
    [-1, 1] float range and transcribed in consecutive 30-second chunks,
    because the Whisper feature extractor only looks at a 30-second window per
    call. The chunk transcriptions are joined with single spaces.

    Args:
        audio_file: Path (or file-like object) readable by
            ``AudioSegment.from_file``.

    Returns:
        The transcription as a single string (empty if the file has no samples).

    Raises:
        Exception: Any decoding or inference error is logged and re-raised.
    """
    sample_rate = 16000
    chunk_seconds = 30  # length of the window Whisper processes per call

    try:
        print("Loading audio file...")
        audio = AudioSegment.from_file(audio_file)
        # Force 16-bit samples so the normalisation constant below is correct.
        audio = audio.set_channels(1).set_frame_rate(sample_rate).set_sample_width(2)
        # Scale int16 PCM to floats in [-1, 1], the range Whisper expects.
        audio_array = torch.tensor(audio.get_array_of_samples()).float() / 32768.0

        print(f"Audio duration: {len(audio) / 1000:.2f} seconds")
        print("Starting transcription...")

        samples_per_chunk = chunk_seconds * sample_rate
        pieces = []
        for start in range(0, audio_array.numel(), samples_per_chunk):
            chunk = audio_array[start:start + samples_per_chunk].numpy()
            input_features = whisper_processor(
                chunk, sampling_rate=sample_rate, return_tensors="pt"
            ).input_features.to(device)
            with torch.no_grad():
                predicted_ids = whisper_model.generate(input_features)
            decoded = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
            text = decoded[0].strip()
            if text:
                pieces.append(text)

        transcription = " ".join(pieces)
        print(f"Transcription complete. Length: {len(transcription)} characters")
        return transcription
    except Exception as e:
        print(f"Error in transcribe_audio: {str(e)}")
        raise

def separate_speakers(transcription):
    """Ask the module-level Qwen model to split a transcript by speaker.

    The instructions and the transcript are sent as a single user message
    formatted with the tokenizer's chat template. Only the tokens generated
    after the prompt are decoded, so the returned text contains the model's
    answer and neither the instructions nor the original transcript.

    Args:
        transcription: Plain transcript text without speaker labels.

    Returns:
        The model's speaker-labelled version of the transcript, stripped of
        surrounding whitespace.
    """
    print("Starting speaker separation...")
    prompt = f"""Analyze the following transcribed text and separate it into different speakers. Identify potential speaker changes based on context, content shifts, or dialogue patterns. Format the output as follows:

1. Label speakers as "Speaker 1", "Speaker 2", etc.  You will have to use dialog context to asume which speaker is saying their dialog as that isn't in the text.
2. Start each speaker's text on a new line beginning with their label.
3. Separate different speakers' contributions with a blank line.
4. If the same speaker continues, do not insert a blank line or repeat the speaker label.

Now, please process the following transcribed text:

{transcription}
"""

    # Instruct-tuned models expect their chat format; a raw prompt tends to be
    # continued rather than answered.
    messages = [{"role": "user", "content": prompt}]
    chat_text = qwen_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Token ids and the attention mask must stay integer tensors: casting them
    # to float16 breaks the embedding lookup inside the model.
    inputs = qwen_tokenizer(chat_text, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = qwen_model.generate(**inputs, max_new_tokens=4000)

    # generate() returns prompt + completion; keep only the completion.
    generated_ids = outputs[0][prompt_length:]
    processed_text = qwen_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    print("Speaker separation complete.")
    return processed_text
    
def transcribe_video(url):
    """Run the full pipeline for one URL: download, transcribe, split speakers.

    The downloaded media is converted to a temporary WAV file, transcribed,
    and the transcript is passed to ``separate_speakers``. The temporary file
    is always removed, whether or not a step fails.

    Args:
        url: Address of the video or audio to process.

    Returns:
        The speaker-separated transcript, or a message starting with
        ``"An error occurred: "`` if any step raised. Errors are reported via
        the return value rather than raised.
    """
    temp_audio_path = None
    try:
        print(f"Attempting to download audio from URL: {url}")
        audio_bytes = download_audio_from_url(url)
        print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")

        # Reserve a unique file name, then close the handle before writing to
        # it by name (an open handle blocks re-opening on some platforms).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            temp_audio_path = temp_audio.name
        AudioSegment.from_file(io.BytesIO(audio_bytes)).export(temp_audio_path, format="wav")
        transcript = transcribe_audio(temp_audio_path)

        print("Separating speakers...")
        separated_transcript = separate_speakers(transcript)

        return separated_transcript
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        return error_message
    finally:
        # Clean up even when export or transcription failed.
        if temp_audio_path is not None and os.path.exists(temp_audio_path):
            os.unlink(temp_audio_path)

# Dash application instance, styled with the Bootstrap theme.
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# The page is assembled bottom-up from named pieces: a title above a single
# card that holds the URL field, the action button, the (spinner-wrapped)
# result area and the initially hidden download controls.
_page_title = html.H1("Video Transcription with Speaker Separation", className="text-center mb-4")
_url_field = dbc.Input(id="video-url", type="text", placeholder="Enter video URL")
_transcribe_control = dbc.Button("Transcribe", id="transcribe-button", color="primary", className="mt-3")
_result_area = dbc.Spinner(html.Div(id="transcription-output", className="mt-3"))
_download_controls = html.Div([
    dbc.Button("Download Transcript", id="download-button", color="secondary", className="mt-3", style={'display': 'none'}),
    dcc.Download(id="download-transcript"),
])
_controls_card = dbc.Card([
    dbc.CardBody([_url_field, _transcribe_control, _result_area, _download_controls]),
])

app.layout = dbc.Container(
    [dbc.Row([dbc.Col([_page_title, _controls_card], width=12)])],
    fluid=True,
)

@app.callback(
    Output("transcription-output", "children"),
    Output("download-button", "style"),
    Input("transcribe-button", "n_clicks"),
    State("video-url", "value"),
    prevent_initial_call=True
)
def update_transcription(n_clicks, url):
    """Transcribe the entered URL and render the result.

    The pipeline runs in a worker thread so it can be abandoned after ten
    minutes. On success the transcript is shown in a card and the download
    button is made visible; on error, timeout or an empty result a plain
    message is shown and the button stays hidden.

    Args:
        n_clicks: Click count of the "Transcribe" button (only the trigger).
        url: Current value of the URL input field.

    Returns:
        A ``(children, style)`` tuple for the output area and the download
        button.

    Raises:
        PreventUpdate: If no URL was entered.
    """
    url = url.strip() if url else ""
    if not url:
        raise PreventUpdate

    hidden = {'display': 'none'}

    # threading.Thread discards its target's return value, so the worker
    # hands its result back through this shared dict instead.
    outcome = {}

    def transcribe():
        try:
            outcome["transcript"] = transcribe_video(url)
        except Exception as e:
            import traceback
            outcome["transcript"] = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"

    # Run transcription in a separate thread; daemon so that a worker which
    # outlives the timeout cannot keep the process alive at shutdown.
    thread = threading.Thread(target=transcribe, daemon=True)
    thread.start()
    thread.join(timeout=600)  # 10 minutes timeout

    if thread.is_alive():
        return "Transcription timed out after 10 minutes", hidden

    transcript = outcome.get("transcript")
    if not transcript:
        return "Transcription failed", hidden
    if transcript.startswith("An error occurred"):
        return transcript, hidden

    # Dash style dictionaries use camelCase CSS property names.
    return dbc.Card([
        dbc.CardBody([
            html.H5("Transcription Result with Speaker Separation"),
            html.Pre(transcript, style={"whiteSpace": "pre-wrap", "wordWrap": "break-word"})
        ])
    ]), {'display': 'block'}

@app.callback(
    Output("download-transcript", "data"),
    Input("download-button", "n_clicks"),
    State("transcription-output", "children"),
    prevent_initial_call=True
)
def download_transcript(n_clicks, transcription_output):
    """Offer the displayed transcript as a ``transcript.txt`` download.

    The transcript text is read back from the serialized result card produced
    by ``update_transcription`` (Card -> CardBody -> [H5, Pre] -> text).

    Args:
        n_clicks: Click count of the download button (only the trigger).
        transcription_output: Serialized children of the output area.

    Returns:
        A dict with ``content`` and ``filename`` for ``dcc.Download``.

    Raises:
        PreventUpdate: If there is no output, or the output is not a result
            card (for example a plain error or timeout message).
    """
    if not transcription_output:
        raise PreventUpdate

    try:
        card_body = transcription_output['props']['children'][0]
        transcript = card_body['props']['children'][1]['props']['children']
    except (KeyError, IndexError, TypeError):
        # Output is a plain string or has an unexpected shape: nothing to download.
        raise PreventUpdate from None

    if not isinstance(transcript, str) or not transcript:
        raise PreventUpdate

    return dict(content=transcript, filename="transcript.txt")

if __name__ == '__main__':
    print("Starting the Dash application...")
    # Debug mode is opt-in (DASH_DEBUG=1/true/yes): the server listens on all
    # interfaces, so the debugger and dev tools must not be exposed by default,
    # and the debug reloader would re-import this module and load both models
    # a second time.
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(debug=debug_enabled, host='0.0.0.0', port=7860)
    print("Dash application has finished running.")