import io
import numpy as np
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM
import requests
from bs4 import BeautifulSoup
import tempfile
import os
from urllib.parse import urljoin
from pydub import AudioSegment
import dash
from dash import dcc, html, Input, Output, State
import dash_bootstrap_components as dbc
from dash.exceptions import PreventUpdate
import threading
from pytube import YouTube
import logging
import librosa

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

print("Script started")

# Check if CUDA is available and set the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Whisper model and processor
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name).to(device)

# Load the Qwen model and tokenizer
qwen_model_name = "Qwen/Qwen2.5-1.5B-Instruct"
qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_name, trust_remote_code=True).to(device)
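# Both models are loaded once at startup so every request shares them; on a
# memory-constrained GPU, passing torch_dtype=torch.float16 to from_pretrained
# above is an optional way to roughly halve the memory footprint.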

def download_audio_from_url(url):
    try:
        if "youtube.com" in url or "youtu.be" in url:
            print("Processing YouTube URL...")
            yt = YouTube(url)
            audio_stream = yt.streams.filter(only_audio=True).first()
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
                temp_path = temp_file.name
            # pytube's output_path is a directory, so pass the full file path via filename
            audio_stream.download(filename=temp_path)
            with open(temp_path, "rb") as f:
                audio_bytes = f.read()
            os.unlink(temp_path)
        elif "share" in url:
            print("Processing shareable link...")
            response = requests.get(url)
            response.raise_for_status()
            soup = BeautifulSoup(response.content, 'html.parser')
            video_tag = soup.find('video')
            if video_tag and 'src' in video_tag.attrs:
                # The src attribute may be relative, so resolve it against the page URL
                video_url = urljoin(url, video_tag['src'])
                print(f"Extracted video URL: {video_url}")
            else:
                raise ValueError("Direct video URL not found in the shareable link.")
            response = requests.get(video_url)
            response.raise_for_status()
            audio_bytes = response.content
        else:
            print(f"Downloading video from URL: {url}")
            response = requests.get(url)
            response.raise_for_status()
            audio_bytes = response.content
        
        print(f"Successfully downloaded {len(audio_bytes)} bytes of data")
        return audio_bytes
    except Exception as e:
        print(f"Error in download_audio_from_url: {str(e)}")
        raise

def transcribe_audio(audio_file):
    try:
        logger.info("Loading audio file...")
        audio_input, sr = librosa.load(audio_file, sr=16000)
        audio_input = audio_input.astype(np.float32)
        logger.info(f"Audio duration: {len(audio_input) / sr:.2f} seconds")

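        # Whisper's encoder processes at most 30 s of audio per pass, so split
        # the input into 30 s chunks with 5 s of overlap to avoid cutting words
        # at chunk boundaries; the naive join below may repeat words that fall
        # inside an overlap region.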
        chunk_length = 30 * sr
        overlap = 5 * sr
        transcriptions = []
        
        logger.info("Starting transcription...")
        for i in range(0, len(audio_input), chunk_length - overlap):
            chunk = audio_input[i:i+chunk_length]
            input_features = whisper_processor(chunk, sampling_rate=16000, return_tensors="pt").input_features.to(device)
            predicted_ids = whisper_model.generate(input_features)
            transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
            transcriptions.extend(transcription)
            logger.info(f"Processed {i / sr:.2f} to {(i + chunk_length) / sr:.2f} seconds")

        full_transcription = " ".join(transcriptions)
        logger.info(f"Transcription complete. Full transcription length: {len(full_transcription)} characters")

        logger.info("Applying speaker separation using Qwen...")
        separated_transcript = separate_speakers(full_transcription)

        return separated_transcript
    except Exception as e:
        logger.error(f"Error in transcribe_audio: {str(e)}")
        raise

def separate_speakers(transcription):
    print("Starting speaker separation...")
    prompt = f"""Analyze the following transcribed text and separate it into different speakers. Identify potential speaker changes based on context, content shifts, or dialogue patterns. Format the output as follows:

1. Label speakers as "Speaker 1", "Speaker 2", etc.
2. Start each speaker's text on a new line beginning with their label.
3. Separate different speakers' contributions with a blank line.
4. If the same speaker continues, do not insert a blank line or repeat the speaker label.

Now, please process the following transcribed text:

{transcription}
"""
    
    # Qwen2.5-Instruct expects its chat template; wrapping the prompt this way
    # gives far more reliable instruction-following than feeding raw text
    messages = [{"role": "user", "content": prompt}]
    chat_text = qwen_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = qwen_tokenizer(chat_text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = qwen_model.generate(**inputs, max_new_tokens=4000)
    
    # Decode only the newly generated tokens so the prompt is not echoed back
    processed_text = qwen_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
    
    print("Speaker separation complete.")
    return processed_text

def transcribe_video(url):
    try:
        print(f"Attempting to download audio from URL: {url}")
        audio_bytes = download_audio_from_url(url)
        print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")
        
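        # pydub delegates decoding to ffmpeg, which must be installed and on
        # PATH to handle non-WAV inputs such as the mp4 audio downloaded above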
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            AudioSegment.from_file(io.BytesIO(audio_bytes)).export(temp_audio.name, format="wav")
            transcript = transcribe_audio(temp_audio.name)
        
        os.unlink(temp_audio.name)
        
        if len(transcript) < 10:
            raise ValueError("Transcription too short, possibly failed")
        
        # transcribe_audio already applies speaker separation, so return its
        # result directly rather than running separate_speakers a second time
        return transcript
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        return error_message

app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

app.layout = dbc.Container([
    dbc.Row([
        dbc.Col([
            html.H1("Video Transcription with Speaker Separation", className="text-center mb-4"),
            html.Div("If you can see this, the app is working!", className="text-center mb-4"),  # Debug element
            dbc.Card([
                dbc.CardBody([
                    dbc.Input(id="video-url", type="text", placeholder="Enter video URL"),
                    dbc.Button("Transcribe", id="transcribe-button", color="primary", className="mt-3"),
                    dbc.Spinner(html.Div(id="transcription-output", className="mt-3")),
                    html.Div([
                        dbc.Button("Download Transcript", id="download-button", color="secondary", className="mt-3", style={'display': 'none'}),
                        dcc.Download(id="download-transcript")
                    ])
                ])
            ])
        ], width=12)
    ])
], fluid=True)

@app.callback(
    Output("transcription-output", "children"),
    Output("download-button", "style"),
    Input("transcribe-button", "n_clicks"),
    State("video-url", "value"),
    prevent_initial_call=True
)
def update_transcription(n_clicks, url):
    if not url:
        raise PreventUpdate

    # threading.Thread discards its target's return value, so collect the
    # result in a shared dict instead
    result_holder = {}

    def transcribe():
        try:
            result_holder['transcript'] = transcribe_video(url)
        except Exception as e:
            logger.exception("Error in transcription:")
            result_holder['transcript'] = f"An error occurred: {str(e)}"

    # Run transcription in a separate thread with a hard timeout
    thread = threading.Thread(target=transcribe)
    thread.start()
    thread.join(timeout=600)  # 10-minute timeout

    if thread.is_alive():
        return "Transcription timed out after 10 minutes", {'display': 'none'}

    transcript = result_holder.get('transcript', "Transcription failed")

    if transcript and not transcript.startswith("An error occurred"):
        return dbc.Card([
            dbc.CardBody([
                html.H5("Transcription Result with Speaker Separation"),
                html.Pre(transcript, style={"white-space": "pre-wrap", "word-wrap": "break-word"})
            ])
        ]), {'display': 'block'}
    else:
        return transcript, {'display': 'none'}

@app.callback(
    Output("download-transcript", "data"),
    Input("download-button", "n_clicks"),
    State("transcription-output", "children"),
    prevent_initial_call=True
)
def download_transcript(n_clicks, transcription_output):
    if not transcription_output:
        raise PreventUpdate
    
    try:
        # Walk the serialized component tree: Card -> CardBody -> [H5, Pre] -> transcript string
        transcript = transcription_output['props']['children'][0]['props']['children'][1]['props']['children']
    except (TypeError, KeyError, IndexError):
        # The output held an error message string rather than the result card
        raise PreventUpdate
    return dict(content=transcript, filename="transcript.txt")

if __name__ == '__main__':
    logger.info("Starting the Dash application...")
    app.run(debug=True, host='0.0.0.0', port=7860)
    logger.info("Dash application has finished running.")