File size: 8,638 Bytes
1100e65
edaf161
249a3c0
 
b09f327
 
53bdf99
b09f327
 
 
af532e7
7a13c00
1b493d6
170241f
53bdf99
 
 
 
 
 
8af57a0
6575bf4
077c90e
170241f
249a3c0
 
 
a18a113
249a3c0
b09f327
249a3c0
 
17ca647
b09f327
 
 
 
8af57a0
 
 
 
 
 
 
 
 
b09f327
 
 
 
 
 
 
 
 
8af57a0
 
b09f327
8af57a0
 
 
b09f327
 
 
 
 
 
 
 
 
 
 
 
59a6a31
b09f327
 
59a6a31
 
 
 
 
 
 
 
 
 
 
 
b09f327
836768f
8369f51
b9e0aa5
 
 
8369f51
 
 
 
f1f904a
8369f51
8af57a0
 
 
532f762
8369f51
 
 
 
 
 
 
 
 
 
 
 
752e0a6
8369f51
 
f1f904a
8af57a0
 
8369f51
 
 
 
 
a18a113
836768f
0cfb05e
 
 
 
 
af532e7
 
 
256795b
 
af532e7
249a3c0
af532e7
249a3c0
 
0cfb05e
836768f
256795b
0cfb05e
249a3c0
 
 
59a6a31
b09f327
 
0cfb05e
 
 
256795b
 
b09f327
53bdf99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df42ab3
 
53bdf99
 
 
836768f
df42ab3
 
b9e0aa5
 
836768f
 
 
 
 
 
53bdf99
 
 
 
 
 
b9e0aa5
53bdf99
836768f
53bdf99
 
 
 
8af57a0
53bdf99
 
 
 
6575bf4
a216ced
6575bf4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
import io
import re
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import requests
from bs4 import BeautifulSoup
import tempfile    
import os
import soundfile as sf
from spellchecker import SpellChecker
from pydub import AudioSegment
import librosa
import numpy as np
from pyannote.audio import Pipeline
import dash
from dash import dcc, html, Input, Output, State
import dash_bootstrap_components as dbc
from dash.exceptions import PreventUpdate
import base64
import threading
from pytube import YouTube

# Import-time setup: everything below runs once when the module is loaded and
# creates the module-level objects the functions further down rely on.
print("Script started")

# Check if CUDA is available and set the device
# (`device` is later used to place the Whisper model and its input features).
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Whisper model and processor
# `processor` turns raw audio into model input features and decodes token ids
# back to text; `model` is moved to `device` right after loading.
model_name = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)

# Shared spell checker instance used by correct_spelling().
spell = SpellChecker()

def download_audio_from_url(url):
    """Download the media behind *url* and return its raw bytes.

    The URL is dispatched on simple substring checks, in this order:

    * contains ``youtube.com`` or ``youtu.be``: the first audio-only stream
      is downloaded with pytube;
    * contains ``share``: the page is fetched, and the ``src`` of its first
      ``<video>`` tag is downloaded;
    * anything else: the URL itself is downloaded.

    Args:
        url: Address of the video, share page or media file.

    Returns:
        The downloaded media as ``bytes``.

    Raises:
        ValueError: If a YouTube video has no audio-only stream, or a share
            page has no ``<video>`` tag with a ``src`` attribute.
        requests.HTTPError: If an HTTP request returns an error status.
        Exception: Any other download error is logged and re-raised.
    """
    # Local import keeps this block self-contained (stdlib only).
    from urllib.parse import urljoin

    # Seconds to wait for a connection / for the next bytes; without a
    # timeout a stalled server would block the caller forever.
    request_timeout = 30

    try:
        if "youtube.com" in url or "youtu.be" in url:
            print("Processing YouTube URL...")
            yt = YouTube(url)
            audio_stream = yt.streams.filter(only_audio=True).first()
            if audio_stream is None:
                raise ValueError("No audio-only stream found for the YouTube URL.")
            # Stream.download() interprets output_path as a *directory* and
            # returns the path of the file it wrote, so download into a
            # throw-away directory and read the file back from that path.
            with tempfile.TemporaryDirectory() as temp_dir:
                downloaded_path = audio_stream.download(output_path=temp_dir)
                with open(downloaded_path, "rb") as downloaded_file:
                    audio_bytes = downloaded_file.read()
        elif "share" in url:
            print("Processing shareable link...")
            response = requests.get(url, timeout=request_timeout)
            response.raise_for_status()
            soup = BeautifulSoup(response.content, 'html.parser')
            video_tag = soup.find('video')
            if video_tag is not None and 'src' in video_tag.attrs:
                # Resolve relative sources against the page URL; absolute
                # sources are returned unchanged by urljoin.
                video_url = urljoin(url, video_tag['src'])
                print(f"Extracted video URL: {video_url}")
            else:
                raise ValueError("Direct video URL not found in the shareable link.")
            response = requests.get(video_url, timeout=request_timeout)
            response.raise_for_status()
            audio_bytes = response.content
        else:
            print(f"Downloading video from URL: {url}")
            response = requests.get(url, timeout=request_timeout)
            # Fail loudly instead of handing an HTML error page to the decoder.
            response.raise_for_status()
            audio_bytes = response.content

        print(f"Successfully downloaded {len(audio_bytes)} bytes of data")
        return audio_bytes
    except Exception as e:
        print(f"Error in download_audio_from_url: {str(e)}")
        raise

def correct_spelling(text):
    """Spell-correct each whitespace-delimited token of *text*.

    Every run of non-whitespace characters is passed to the module-level
    spell checker; when the checker has no suggestion the token is kept
    as it is. All whitespace between tokens (including newlines) is
    preserved, so line-oriented layout such as a per-speaker transcript
    survives the correction.

    Args:
        text: The text to correct.

    Returns:
        The corrected text with its original whitespace intact.
    """
    def _correct(match):
        token = match.group()
        # spell.correction() may return None when it has no candidate.
        return spell.correction(token) or token

    # Substituting tokens in place (rather than split + join) is what keeps
    # the original spacing and line breaks.
    return re.sub(r"\S+", _correct, text)

def format_transcript_with_speakers(transcript, diarization):
    """Lay out *transcript* as speaker-labelled paragraphs.

    The transcript is plain text without timestamps, so it cannot be
    indexed by the diarization times directly. The words are instead
    distributed over the speaker turns in proportion to each turn's end
    time relative to the latest end time seen. This alignment is an
    approximation: it assumes speech is spread evenly over the recording.

    Every word is emitted exactly once and in order; the final turn
    receives whatever words remain.

    Args:
        transcript: Full transcription text.
        diarization: Object whose ``itertracks(yield_label=True)`` yields
            ``(segment, track, speaker)`` triples, where ``segment`` has
            numeric ``start`` and ``end`` attributes.

    Returns:
        A string with a ``Speaker <label>:`` header whenever the speaker
        changes, a blank line between different speakers, and one line of
        text per turn that received words. If the diarization yields no
        turns, *transcript* is returned unchanged so the text is not lost.
    """
    turns = [
        (segment.end, speaker)
        for segment, _, speaker in diarization.itertracks(yield_label=True)
    ]
    if not turns:
        return transcript

    words = transcript.split()
    word_count = len(words)
    total_duration = max(end for end, _ in turns)
    last_index = len(turns) - 1

    formatted_transcript = []
    current_speaker = None
    cursor = 0  # index of the first word not yet assigned to a turn
    for index, (end, speaker) in enumerate(turns):
        if speaker != current_speaker:
            if current_speaker is not None:
                formatted_transcript.append("\n")  # Add a blank line between speakers
            formatted_transcript.append(f"Speaker {speaker}:\n")
            current_speaker = speaker

        if index == last_index or total_duration <= 0:
            # The last turn takes the remainder; a zero duration gives no
            # usable time scale, so the first turn takes everything.
            upto = word_count
        else:
            upto = round(end / total_duration * word_count)
        # Keep the cursor monotonic so overlapping turns never repeat words.
        upto = max(cursor, min(upto, word_count))

        segment_text = " ".join(words[cursor:upto])
        cursor = upto
        if segment_text:
            formatted_transcript.append(f"{segment_text}\n")
    return "".join(formatted_transcript)

def transcribe_audio(audio_file, pipeline):
    """Transcribe *audio_file* with Whisper and label it by speaker.

    The audio is loaded at 16 kHz, diarized with *pipeline*, and then
    transcribed in 30-second windows that start every 25 seconds. The
    window texts are joined with spaces and passed, together with the
    diarization, to ``format_transcript_with_speakers``.

    Args:
        audio_file: Path of the audio file to transcribe.
        pipeline: Speaker diarization pipeline, called as
            ``pipeline(audio_file)``.

    Returns:
        The speaker-formatted transcription.

    Raises:
        ValueError: If *pipeline* is ``None``.
        Exception: Any other error is logged and re-raised.
    """
    try:
        if pipeline is None:
            raise ValueError("Speaker diarization pipeline is not initialized")

        print("Loading audio file...")
        audio_input, sr = librosa.load(audio_file, sr=16000)
        audio_input = audio_input.astype(np.float32)
        total_samples = len(audio_input)
        print(f"Audio duration: {total_samples / sr:.2f} seconds")

        # Apply speaker diarization
        print("Applying speaker diarization...")
        diarization = pipeline(audio_file)
        print("Speaker diarization complete.")

        chunk_length = 30 * sr
        overlap = 5 * sr
        step = chunk_length - overlap
        transcriptions = []

        print("Starting transcription...")
        for start in range(0, total_samples, step):
            # The previous window covered audio up to start + overlap. If
            # that already reaches the end, this window would only repeat
            # audio that has been transcribed, so skip it.
            if start > 0 and start + overlap >= total_samples:
                break
            chunk = audio_input[start:start + chunk_length]
            input_features = processor(chunk, sampling_rate=sr, return_tensors="pt").input_features.to(device)
            predicted_ids = model.generate(input_features)
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            transcriptions.extend(transcription)
            # Clamp so the log never reports a time beyond the audio's end.
            chunk_end = min(start + chunk_length, total_samples)
            print(f"Processed {start / sr:.2f} to {chunk_end / sr:.2f} seconds")

        # NOTE: neighbouring windows still share 5 seconds of audio, so the
        # joined text may repeat a few words around window boundaries.
        full_transcription = " ".join(transcriptions)
        print(f"Transcription complete. Full transcription length: {len(full_transcription)} characters")

        print("Applying formatting with speaker diarization...")
        formatted_transcription = format_transcript_with_speakers(full_transcription, diarization)

        return formatted_transcription
    except Exception as e:
        print(f"Error in transcribe_audio: {str(e)}")
        raise

def transcribe_video(url, pipeline):
    """Download the media at *url* and return its spell-checked transcript.

    The downloaded bytes are decoded with pydub, exported to a temporary
    WAV file, transcribed with ``transcribe_audio`` and finally passed
    through ``correct_spelling``. The temporary file is removed whether
    or not transcription succeeds.

    Args:
        url: Address of the video or audio to transcribe.
        pipeline: Speaker diarization pipeline forwarded to
            ``transcribe_audio``.

    Returns:
        The transcript on success. On any failure the exception is not
        propagated; a string starting with ``"An error occurred: "`` is
        returned instead.
    """
    temp_audio_path = None
    try:
        print(f"Attempting to download audio from URL: {url}")
        audio_bytes = download_audio_from_url(url)
        print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")

        # Convert audio bytes to AudioSegment
        audio = AudioSegment.from_file(io.BytesIO(audio_bytes))

        print(f"Audio duration: {len(audio) / 1000} seconds")

        # Save as WAV file. Only reserve the name here and export after the
        # handle is closed, so the file is never opened twice at once.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            temp_audio_path = temp_audio.name
        audio.export(temp_audio_path, format="wav")

        print("Starting audio transcription...")
        transcript = transcribe_audio(temp_audio_path, pipeline)
        print(f"Transcription completed. Transcript length: {len(transcript)} characters")

        # Apply spelling correction
        transcript = correct_spelling(transcript)

        return transcript
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        return error_message
    finally:
        # Clean up the temporary file even when transcription failed.
        if temp_audio_path is not None:
            try:
                os.unlink(temp_audio_path)
            except OSError as cleanup_error:
                print(f"Could not remove temporary file {temp_audio_path}: {cleanup_error}")

# Dash application styled with the Bootstrap theme.
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Individual widgets, in the order they appear inside the card.
_url_field = dbc.Input(id="video-url", type="text", placeholder="Enter video URL")
_transcribe_button = dbc.Button(
    "Transcribe", id="transcribe-button", color="primary", className="mt-3"
)
# The spinner wraps the output area that the transcription callback fills.
_output_area = dbc.Spinner(html.Div(id="transcription-output", className="mt-3"))
_download_target = dcc.Download(id="download-transcript")

_form_card = dbc.Card([
    dbc.CardBody([_url_field, _transcribe_button, _output_area, _download_target])
])
_page_title = html.H1("Video Transcription", className="text-center mb-4")

# Page: a single full-width column holding the title and the form card.
app.layout = dbc.Container(
    [dbc.Row([dbc.Col([_page_title, _form_card], width=12)])],
    fluid=True,
)

@app.callback(
    Output("transcription-output", "children"),
    Output("download-transcript", "data"),
    Input("transcribe-button", "n_clicks"),
    State("video-url", "value"),
    prevent_initial_call=True
)
def update_transcription(n_clicks, url):
    """Transcribe the entered URL when the Transcribe button is clicked.

    Args:
        n_clicks: Click count of the Transcribe button (only used as the
            trigger).
        url: Current value of the URL input.

    Returns:
        A ``(children, data)`` pair for the output area and the download
        component: on success a result card and a ``transcript.txt``
        payload, otherwise the error message and ``None``.

    Raises:
        PreventUpdate: If no URL (or only whitespace) was entered.
    """
    if not url or not url.strip():
        raise PreventUpdate
    url = url.strip()

    # The work runs inline. A threading.Thread does not keep its target's
    # return value, and the callback has to wait for the result anyway, so
    # a start()/join() pair would add nothing.
    try:
        # Initialize the speaker diarization pipeline without token
        pipeline = Pipeline.from_pretrained("collinbarnwell/pyannote-speaker-diarization-31")
        if pipeline is None:
            raise ValueError("Failed to initialize the speaker diarization pipeline")
        print("Speaker diarization pipeline initialized successfully")

        transcript = transcribe_video(url, pipeline)
    except Exception as e:
        transcript = f"An error occurred: {str(e)}"

    # transcribe_video reports failures as a message with this same prefix.
    if transcript and not transcript.startswith("An error occurred"):
        download_data = dict(content=transcript, filename="transcript.txt")
        return dbc.Card([
            dbc.CardBody([
                html.H5("Transcription Result"),
                html.Pre(transcript, style={"white-space": "pre-wrap", "word-wrap": "break-word"}),
                dbc.Button("Download Transcript", id="btn-download", color="secondary", className="mt-3")
            ])
        ]), download_data
    else:
        return transcript, None
        
if __name__ == '__main__':
    print("Starting the Dash application...")
    # The server listens on all interfaces, so debug mode (interactive
    # debugger and auto-reloader) must not be on by default. Opt in for
    # local development with DASH_DEBUG=1.
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    app.run(debug=debug_enabled, host='0.0.0.0', port=7860)
    print("Dash application has finished running.")