File size: 5,843 Bytes
1100e65
249a3c0
 
b09f327
 
53bdf99
b09f327
af532e7
53bdf99
 
 
 
 
8af57a0
6575bf4
077c90e
170241f
249a3c0
 
 
a18a113
249a3c0
b09f327
249a3c0
 
17ca647
b09f327
 
8af57a0
 
 
 
 
 
 
 
 
b09f327
 
 
 
 
 
 
 
 
8af57a0
 
b09f327
8af57a0
 
 
b09f327
 
 
 
 
 
 
81f702f
8369f51
 
81f702f
 
 
8369f51
 
81f702f
 
 
 
 
 
8369f51
 
 
a18a113
81f702f
0cfb05e
 
 
 
 
249a3c0
81f702f
 
0cfb05e
81f702f
0cfb05e
 
 
256795b
 
b09f327
53bdf99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df42ab3
 
53bdf99
 
 
836768f
81f702f
836768f
 
 
53bdf99
 
 
 
 
 
b9e0aa5
53bdf99
836768f
53bdf99
 
 
 
8af57a0
53bdf99
 
 
 
6575bf4
81f702f
 
 
6575bf4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import io
import os
import tempfile
import threading
from urllib.parse import urljoin

import dash
import dash_bootstrap_components as dbc
import numpy as np
import requests
import torch
from bs4 import BeautifulSoup
from dash import dcc, html, Input, Output, State
from dash.exceptions import PreventUpdate
from pydub import AudioSegment
from pytube import YouTube
from transformers import WhisperProcessor, WhisperForConditionalGeneration

print("Script started")

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Whisper checkpoint shared by every transcription request. The processor
# turns audio into model input features and decodes predicted token ids back
# into text; the model is moved to the selected device once, at import time.
model_name = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)

def _http_get_bytes(url, timeout):
    """Return the body of a GET request to *url*, raising on HTTP error statuses."""
    response = requests.get(url, timeout=timeout)
    # Without this, an error page (HTML) would be passed on to ffmpeg as media.
    response.raise_for_status()
    return response.content


def _find_video_src(page_url, timeout):
    """Return the absolute URL of the first ``<video>`` element on *page_url*.

    The ``src`` attribute of the ``<video>`` tag is used, falling back to the
    ``src`` of its first nested ``<source>`` tag.

    Raises:
        ValueError: If the page has no ``<video>`` with a usable ``src``.
    """
    page_bytes = _http_get_bytes(page_url, timeout)
    soup = BeautifulSoup(page_bytes, 'html.parser')
    video_tag = soup.find('video')
    src = None
    if video_tag is not None:
        src = video_tag.get('src')
        if not src:
            source_tag = video_tag.find('source')
            if source_tag is not None:
                src = source_tag.get('src')
    if not src:
        raise ValueError("Direct video URL not found in the shareable link.")
    # Pages frequently use relative paths such as "/media/clip.mp4".
    return urljoin(page_url, src)


def _download_youtube_audio(url):
    """Download the first audio-only stream of a YouTube video and return its bytes.

    Raises:
        ValueError: If the video exposes no audio-only stream.
    """
    audio_stream = YouTube(url).streams.filter(only_audio=True).first()
    if audio_stream is None:
        raise ValueError("No audio-only stream found for this YouTube video.")
    # pytube's download() treats output_path as a *directory* and returns the
    # path of the file it wrote, so download into a throwaway directory that
    # is removed (with the file) when the block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        file_path = audio_stream.download(output_path=tmp_dir)
        with open(file_path, "rb") as audio_file:
            return audio_file.read()


def download_audio_from_url(url, *, timeout=60):
    """Download the media behind *url* and return its raw bytes.

    Three kinds of URL are handled:

    * YouTube links (``youtube.com`` / ``youtu.be``): the first audio-only
      stream is fetched with pytube.
    * URLs containing ``"share"``: treated as an HTML page whose first
      ``<video>`` element points at the media file; relative URLs are
      resolved against the page URL.
    * Anything else: downloaded directly with an HTTP GET.

    Args:
        url: Page or media URL supplied by the user.
        timeout: Seconds passed to ``requests.get`` (connect/read timeout,
            not a limit on the whole transfer). Not used for YouTube.

    Returns:
        The downloaded bytes, to be decoded later by pydub/ffmpeg.

    Raises:
        ValueError: If no audio stream or ``<video>`` source can be found.
        requests.HTTPError: If an HTTP request returns an error status.
        Any other pytube/requests error is logged and re-raised unchanged.
    """
    try:
        if "youtube.com" in url or "youtu.be" in url:
            print("Processing YouTube URL...")
            audio_bytes = _download_youtube_audio(url)
        elif "share" in url:
            print("Processing shareable link...")
            video_url = _find_video_src(url, timeout)
            print(f"Extracted video URL: {video_url}")
            audio_bytes = _http_get_bytes(video_url, timeout)
        else:
            print(f"Downloading video from URL: {url}")
            audio_bytes = _http_get_bytes(url, timeout)

        print(f"Successfully downloaded {len(audio_bytes)} bytes of data")
        return audio_bytes
    except Exception as e:
        print(f"Error in download_audio_from_url: {str(e)}")
        raise

def transcribe_audio(audio_file, *, chunk_seconds=30, batch_size=8):
    """Transcribe the speech in *audio_file* with the module-level Whisper model.

    The audio is decoded with pydub, down-mixed to mono 16 kHz 16-bit PCM and
    scaled to a float32 waveform in [-1.0, 1.0). Whisper's processor pads or
    truncates every input to 30 seconds, so the waveform is split into
    ``chunk_seconds`` windows that are transcribed in batches and joined;
    feeding the whole file at once would drop everything after 30 s.

    Args:
        audio_file: Path or file-like object accepted by ``AudioSegment.from_file``.
        chunk_seconds: Window length in seconds; must be > 0. Values above 30
            are still truncated to 30 s per window by the processor.
        batch_size: Number of windows per ``model.generate`` call; must be > 0.

    Returns:
        The text of all windows, stripped and joined with single spaces
        ("" when the audio is empty or nothing was recognized).

    Raises:
        ValueError: If ``chunk_seconds`` or ``batch_size`` is not positive.
        Any decoding/model error is logged and re-raised unchanged.
    """
    if chunk_seconds <= 0 or batch_size <= 0:
        raise ValueError("chunk_seconds and batch_size must be positive")
    sample_rate = 16000
    try:
        print("Loading audio file...")
        audio = AudioSegment.from_file(audio_file)
        # Forcing 2-byte samples pins the integer range used for scaling below.
        audio = audio.set_channels(1).set_frame_rate(sample_rate).set_sample_width(2)
        # Whisper expects a float waveform in [-1, 1], not raw int16 values.
        waveform = np.asarray(audio.get_array_of_samples(), dtype=np.float32) / 32768.0

        chunk_len = max(1, int(sample_rate * chunk_seconds))
        chunks = [waveform[i:i + chunk_len] for i in range(0, len(waveform), chunk_len)]

        print("Starting transcription...")
        pieces = []
        for start in range(0, len(chunks), batch_size):
            batch = chunks[start:start + batch_size]
            input_features = processor(
                batch, sampling_rate=sample_rate, return_tensors="pt"
            ).input_features.to(device)
            predicted_ids = model.generate(input_features)
            pieces.extend(processor.batch_decode(predicted_ids, skip_special_tokens=True))

        transcription = " ".join(piece.strip() for piece in pieces if piece.strip())
        print(f"Transcription complete. Length: {len(transcription)} characters")
        return transcription
    except Exception as e:
        print(f"Error in transcribe_audio: {str(e)}")
        raise

def transcribe_video(url):
    """Download the media at *url* and return its transcription.

    This function never raises: any exception is logged and reported as a
    string starting with ``"An error occurred: "``, a prefix the Dash
    callback checks to tell errors from transcripts.

    Args:
        url: Video/audio URL supplied by the user.

    Returns:
        The transcript, or an ``"An error occurred: ..."`` message.
    """
    try:
        print(f"Attempting to download audio from URL: {url}")
        audio_bytes = download_audio_from_url(url)
        print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")

        # pydub decodes file-like objects directly, so the bytes are passed in
        # memory: no temporary file to export, reopen (which fails on Windows
        # while it is still open) or leak when transcription raises.
        return transcribe_audio(io.BytesIO(audio_bytes))
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        return error_message

app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Style that hides the download button until a transcript is available.
_HIDDEN = {"display": "none"}

app.layout = dbc.Container([
    dbc.Row([
        dbc.Col([
            html.H1("Video Transcription", className="text-center mb-4"),
            dbc.Card([
                dbc.CardBody([
                    dbc.Input(id="video-url", type="text", placeholder="Enter video URL"),
                    dbc.Button("Transcribe", id="transcribe-button", color="primary", className="mt-3"),
                    dbc.Spinner(html.Div(id="transcription-output", className="mt-3")),
                    # Kept in the initial layout (hidden) so the download callback
                    # can reference it without suppress_callback_exceptions.
                    dbc.Button("Download Transcript", id="btn-download", color="secondary",
                               className="mt-3", style=_HIDDEN),
                    # Holds the latest successful transcript for the download button.
                    dcc.Store(id="transcript-store"),
                    dcc.Download(id="download-transcript"),
                ])
            ])
        ], width=12)
    ])
], fluid=True)


@app.callback(
    Output("transcription-output", "children"),
    Output("transcript-store", "data"),
    Output("btn-download", "style"),
    Input("transcribe-button", "n_clicks"),
    State("video-url", "value"),
    prevent_initial_call=True,
)
def update_transcription(n_clicks, url):
    """Transcribe *url* and render the result.

    Returns:
        ``(children, stored_transcript, button_style)``: a result card, the
        transcript and a visible button on success; otherwise the error
        text (or empty transcript), ``None`` and a hidden button.
    """
    if not url:
        raise PreventUpdate

    # transcribe_video blocks until it finishes and the callback has to wait
    # for its result anyway, so it is called directly rather than on a thread.
    transcript = transcribe_video(url)

    if not transcript or transcript.startswith("An error occurred"):
        return transcript, None, _HIDDEN

    result_card = dbc.Card([
        dbc.CardBody([
            html.H5("Transcription Result"),
            # React style objects require camelCase property names.
            html.Pre(transcript, style={"whiteSpace": "pre-wrap", "wordWrap": "break-word"}),
        ])
    ])
    return result_card, transcript, {}


@app.callback(
    Output("download-transcript", "data"),
    Input("btn-download", "n_clicks"),
    State("transcript-store", "data"),
    prevent_initial_call=True,
)
def download_transcript(n_clicks, transcript):
    """Send the stored transcript to the browser as ``transcript.txt``."""
    if not transcript:
        raise PreventUpdate
    return dict(content=transcript, filename="transcript.txt")

print("Reached end of script definitions")

if __name__ == '__main__':
    print("Starting the Dash application...")
    # debug is deliberately not hard-coded. When it is omitted, Dash reads the
    # DASH_DEBUG environment variable (default: off). Debug mode on a server
    # bound to 0.0.0.0 exposes the dev tools/debugger to the network, and its
    # reloader re-imports this module, loading the Whisper model a second time.
    app.run(host='0.0.0.0', port=7860)
    print("Dash application has finished running.")