File size: 7,980 Bytes
1100e65
249a3c0
c982392
b09f327
 
53bdf99
b09f327
af532e7
53bdf99
 
 
 
 
8af57a0
5ee7955
a62f407
13f2ba5
5ee7955
 
c982392
5ee7955
6575bf4
077c90e
170241f
249a3c0
 
 
a18a113
249a3c0
261d49a
fce37ea
26cf8bb
fce37ea
9a1f744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ffd5e97
a80c887
 
 
 
 
 
 
9a1f744
ffd5e97
a80c887
 
 
 
ffd5e97
a80c887
 
 
 
 
 
c982392
9a1f744
ffd5e97
9a1f744
 
 
 
aed95f4
9a1f744
aed95f4
9a1f744
 
 
 
 
 
 
 
 
 
aa497bf
 
 
c982392
9a1f744
aa497bf
aed95f4
9a1f744
b09f327
53bdf99
 
 
 
 
c982392
 
53bdf99
 
 
 
 
26cf8bb
 
 
 
53bdf99
 
 
 
 
 
 
 
26cf8bb
53bdf99
 
 
 
df42ab3
 
53bdf99
 
c982392
 
 
aa497bf
c982392
 
 
 
 
 
 
 
 
 
 
aa497bf
c982392
 
 
aa497bf
26cf8bb
aed95f4
aa497bf
aed95f4
 
c982392
aed95f4
 
 
 
aa497bf
aed95f4
53bdf99
26cf8bb
 
 
 
 
 
 
 
 
 
 
 
81f702f
6575bf4
5ee7955
31b9df5
5ee7955
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import io
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import requests
from bs4 import BeautifulSoup
import tempfile    
import os
from pydub import AudioSegment
import dash
from dash import dcc, html, Input, Output, State
import dash_bootstrap_components as dbc
from dash.exceptions import PreventUpdate
import threading
from pytube import YouTube
import logging
import librosa
import numpy as np

# Module-wide logging: INFO and above is emitted through the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

print("Script started")

# Run Whisper on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Whisper checkpoint plus its processor (feature extractor + tokenizer) and
# the model itself, moved onto `device`. Both are created once at import
# time and shared by every transcription call in this module.
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = (
    WhisperForConditionalGeneration
    .from_pretrained(whisper_model_name)
    .to(device)
)

def download_audio_from_url(url, timeout=60):
    """Fetch the raw media bytes behind *url*.

    Three kinds of URL are handled:

    * YouTube links (``youtube.com`` / ``youtu.be``): the first audio-only
      stream reported by pytube is downloaded.
    * URLs containing ``"share"``: the page is fetched and parsed, and the
      ``src`` of its first ``<video>`` tag is downloaded. A relative ``src``
      is resolved against *url*.
    * Anything else: the URL itself is downloaded.

    Args:
        url: Media or page URL supplied by the user.
        timeout: Connect/read timeout in seconds passed to ``requests.get``,
            so a stalled server cannot hang the caller indefinitely.

    Returns:
        The downloaded bytes. The container format depends on the source.

    Raises:
        ValueError: No audio stream (YouTube) or no ``<video src>`` (share
            page) was found.
        requests.HTTPError: An HTTP download returned an error status.
        Any other exception from pytube/requests is logged and re-raised.
    """
    from urllib.parse import urljoin

    def _fetch(target_url):
        # Fail loudly on HTTP errors instead of returning an error page's
        # HTML as if it were media.
        response = requests.get(target_url, timeout=timeout)
        response.raise_for_status()
        return response

    try:
        if "youtube.com" in url or "youtu.be" in url:
            logger.info("Processing YouTube URL...")
            audio_stream = YouTube(url).streams.filter(only_audio=True).first()
            if audio_stream is None:
                raise ValueError("No audio stream found for this YouTube video.")
            # pytube's `output_path` is a directory, not a file name, so
            # download into a private temporary directory. The directory is
            # removed automatically, even if reading fails.
            with tempfile.TemporaryDirectory() as temp_dir:
                file_path = audio_stream.download(output_path=temp_dir, filename="audio.mp4")
                with open(file_path, "rb") as media_file:
                    audio_bytes = media_file.read()
        elif "share" in url:
            logger.info("Processing shareable link...")
            soup = BeautifulSoup(_fetch(url).content, 'html.parser')
            video_tag = soup.find('video')
            if not (video_tag and video_tag.get('src')):
                raise ValueError("Direct video URL not found in the shareable link.")
            # Resolve relative src attributes against the page URL.
            video_url = urljoin(url, video_tag['src'])
            logger.info(f"Extracted video URL: {video_url}")
            audio_bytes = _fetch(video_url).content
        else:
            logger.info(f"Downloading video from URL: {url}")
            audio_bytes = _fetch(url).content

        logger.info(f"Successfully downloaded {len(audio_bytes)} bytes of data")
        return audio_bytes
    except Exception:
        logger.exception("Error in download_audio_from_url")
        raise

def transcribe_audio(audio_file, chunk_seconds=30, overlap_seconds=5):
    """Transcribe an audio file with the module-level Whisper model.

    How it works:

    * Loading: the file is loaded and resampled to 16 kHz with librosa.
    * Windowing: the audio is cut into windows of ``chunk_seconds``. A new
      window starts every ``chunk_seconds - overlap_seconds`` seconds.
    * Output: each window is transcribed independently, and the texts are
      joined with single spaces.

    Windowing stops after the first window that reaches the end of the
    audio, so no trailing window lying entirely inside the previous one is
    transcribed. Audio inside an overlap is still transcribed by both
    neighbouring windows and is NOT de-duplicated, so words near window
    boundaries may appear twice. Pass ``overlap_seconds=0`` to avoid that,
    at the cost of possibly splitting words at boundaries.

    Args:
        audio_file: Path to the audio file; passed straight to
            ``librosa.load``.
        chunk_seconds: Window length in seconds. NOTE(review): Whisper's
            feature extractor presumably pads/truncates each window to 30 s,
            so values above 30 would drop audio. Confirm before raising it.
        overlap_seconds: Seconds shared by consecutive windows. Must be
            >= 0 and smaller than ``chunk_seconds``.

    Returns:
        The joined transcription (``""`` for empty audio).

    Raises:
        ValueError: If the window/overlap combination is invalid.
        Any exception from loading or inference is logged and re-raised.
    """
    sample_rate = 16000  # used for both resampling and the feature extractor

    chunk_length = int(chunk_seconds * sample_rate)
    overlap = int(overlap_seconds * sample_rate)
    if chunk_length <= 0 or not 0 <= overlap < chunk_length:
        raise ValueError(
            "chunk_seconds must be positive and overlap_seconds must be "
            ">= 0 and smaller than chunk_seconds"
        )
    step = chunk_length - overlap

    try:
        logger.info("Loading audio file...")
        audio_input, sr = librosa.load(audio_file, sr=sample_rate)
        audio_input = audio_input.astype(np.float32)
        total = len(audio_input)
        logger.info(f"Audio duration: {total / sr:.2f} seconds")

        transcriptions = []
        logger.info("Starting transcription...")
        for start in range(0, total, step):
            end = min(start + chunk_length, total)
            chunk = audio_input[start:end]
            input_features = whisper_processor(
                chunk, sampling_rate=sr, return_tensors="pt"
            ).input_features.to(device)
            predicted_ids = whisper_model.generate(input_features)
            transcriptions.extend(
                whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
            )
            logger.info(f"Processed {start / sr:.2f} to {end / sr:.2f} seconds")
            if end >= total:
                # This window already covers the tail; any later start would
                # only re-transcribe audio that lies inside it.
                break

        full_transcription = " ".join(transcriptions)
        logger.info(f"Transcription complete. Full transcription length: {len(full_transcription)} characters")

        return full_transcription
    except Exception:
        logger.exception("Error in transcribe_audio")
        raise

def transcribe_video(url):
    """Download the media at *url*, convert it to WAV and transcribe it.

    This function never raises. Any failure is logged and reported by
    returning a string that starts with
    ``"An error occurred in transcribe_video: "``. Failures include:

    * the download,
    * audio decoding,
    * transcription,
    * a transcript shorter than 10 characters.

    Callers detect failure by that prefix.

    Args:
        url: Video/audio URL as accepted by ``download_audio_from_url``.

    Returns:
        The transcript, or the error string described above.
    """
    temp_path = None
    try:
        logger.info(f"Attempting to download audio from URL: {url}")
        audio_bytes = download_audio_from_url(url)
        logger.info(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")

        # Reserve a unique .wav path and close our handle right away. pydub
        # re-opens the path itself for writing, and on Windows a named
        # temporary file cannot be opened a second time while still open.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            temp_path = temp_audio.name
        exported = AudioSegment.from_file(io.BytesIO(audio_bytes)).export(temp_path, format="wav")
        exported.close()  # export() hands back the still-open output file
        transcript = transcribe_audio(temp_path)

        if len(transcript) < 10:
            raise ValueError("Transcription too short, possibly failed")

        logger.info(f"Transcription successful. Length: {len(transcript)} characters")
        logger.info(f"First 100 characters of transcript: {transcript[:100]}...")

        return transcript
    except Exception as e:
        error_message = f"An error occurred in transcribe_video: {str(e)}"
        logger.error(error_message)
        return error_message
    finally:
        # Remove the temporary WAV on every path, not only on success.
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                logger.warning(f"Could not remove temporary file {temp_path}")

app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Download controls: the button starts hidden (display: none), and the
# dcc.Download component receives the file payload when it is clicked.
_download_controls = html.Div(
    children=[
        dbc.Button(
            children="Download Transcript",
            id="download-button",
            color="secondary",
            className="mt-3",
            style={'display': 'none'},
        ),
        dcc.Download(id="download-transcript"),
    ]
)

# Main card: URL box, trigger button, spinner-wrapped output area and the
# download controls defined just before it.
_transcription_card = dbc.Card(
    children=[
        dbc.CardBody(
            children=[
                dbc.Input(id="video-url", type="text", placeholder="Enter video URL"),
                dbc.Button(
                    children="Transcribe",
                    id="transcribe-button",
                    color="primary",
                    className="mt-3",
                ),
                dbc.Spinner(children=html.Div(id="transcription-output", className="mt-3")),
                _download_controls,
            ]
        )
    ]
)

# Page: one full-width column holding the title, a status line and the card.
app.layout = dbc.Container(
    children=[
        dbc.Row(
            children=[
                dbc.Col(
                    children=[
                        html.H1(children="Video Transcription", className="text-center mb-4"),
                        html.Div(
                            children="If you can see this, the app is working!",
                            className="text-center mb-4",
                        ),
                        _transcription_card,
                    ],
                    width=12,
                )
            ]
        )
    ],
    fluid=True,
)

@app.callback(
    Output("transcription-output", "children"),
    Output("download-button", "style"),
    Input("transcribe-button", "n_clicks"),
    State("video-url", "value"),
    prevent_initial_call=True
)
def update_transcription(n_clicks, url):
    """Run a transcription for *url* and render the result.

    The work runs in a worker thread so it can be bounded by a 10-minute
    timeout.

    Args:
        n_clicks: Click count of the Transcribe button. Only its change
            matters, since it triggers the callback.
        url: Current value of the URL input.

    Returns:
        A ``(children, style)`` tuple for the output div and the download
        button. Outcomes:

        * Success: a Card (Card > CardBody > [H5, Pre(transcript)]) and a
          visible button.
        * Failure or timeout: the message text and a hidden button.

    Raises:
        PreventUpdate: When the URL is empty or whitespace only.
    """
    url = url.strip() if url else ""
    if not url:
        raise PreventUpdate

    # threading.Thread discards its target's return value, so the worker
    # stores its result here instead.
    outcome = {}

    def transcribe():
        try:
            transcript = transcribe_video(url)
            logger.info(f"Transcription completed. Result length: {len(transcript)} characters")
            outcome["transcript"] = transcript
        except Exception as e:
            logger.exception("Error in transcription:")
            outcome["transcript"] = f"An error occurred: {str(e)}"

    # Python threads cannot be cancelled, so a timed-out worker keeps running.
    # daemon=True at least stops it from blocking interpreter shutdown.
    thread = threading.Thread(target=transcribe, daemon=True)
    thread.start()
    thread.join(timeout=600)  # 10 minutes timeout

    if thread.is_alive():
        logger.warning("Transcription timed out after 10 minutes")
        return "Transcription timed out after 10 minutes", {'display': 'none'}

    transcript = outcome.get("transcript")
    if transcript is None:
        logger.error("Transcription failed: worker produced no result")
        return "Transcription failed", {'display': 'none'}

    logger.info(f"Final transcript length: {len(transcript)} characters")

    if transcript and not transcript.startswith("An error occurred"):
        logger.info("Transcription successful, returning result")
        return dbc.Card([
            dbc.CardBody([
                html.H5("Transcription Result"),
                # React/Dash style dicts use camelCase CSS property names.
                html.Pre(transcript, style={"whiteSpace": "pre-wrap", "wordWrap": "break-word"})
            ])
        ]), {'display': 'block'}

    logger.error(f"Transcription failed: {transcript}")
    return transcript, {'display': 'none'}

@app.callback(
    Output("download-transcript", "data"),
    Input("download-button", "n_clicks"),
    State("transcription-output", "children"),
    prevent_initial_call=True
)
def download_transcript(n_clicks, transcription_output):
    """Send the currently displayed transcript as ``transcript.txt``.

    ``transcription_output`` is the serialized children of the
    ``transcription-output`` div. After a successful transcription it is a
    Card laid out as Card > CardBody > [H5, Pre(transcript)], and the
    transcript is the Pre's children. Any other content means there is
    nothing to download, so the callback does not update. This includes:

    * an error or timeout string,
    * ``None``,
    * an unexpected structure.

    Args:
        n_clicks: Click count of the download button (trigger only).
        transcription_output: Serialized component tree, or a plain string.

    Returns:
        A ``dcc.Download`` payload dict with ``content`` and ``filename``.

    Raises:
        PreventUpdate: When no transcript can be extracted.
    """
    if not transcription_output:
        raise PreventUpdate

    try:
        card_body = transcription_output['props']['children'][0]
        pre = card_body['props']['children'][1]
        transcript = pre['props']['children']
    except (KeyError, IndexError, TypeError):
        # Not the success Card (e.g. an error message string).
        raise PreventUpdate from None

    if not isinstance(transcript, str) or not transcript:
        raise PreventUpdate

    return dict(content=transcript, filename="transcript.txt")

if __name__ == '__main__':
    # Debug mode is opt-in (DASH_DEBUG=1/true/yes/on). The server binds to
    # 0.0.0.0, and Flask's debug mode enables the interactive Werkzeug
    # debugger, which must not be reachable from the network. Its reloader
    # also re-runs this module in a child process, which loads the Whisper
    # model a second time.
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    logger.info("Starting the Dash application...")
    app.run(debug=debug_enabled, host='0.0.0.0', port=7860)
    logger.info("Dash application has finished running.")