bluenevus committed on
Commit
639051f
·
verified ·
1 Parent(s): aa497bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -137
app.py CHANGED
@@ -1,177 +1,130 @@
1
  import io
2
- import torch
3
- from transformers import WhisperProcessor, WhisperForConditionalGeneration
4
- import requests
5
- from bs4 import BeautifulSoup
6
- import tempfile
7
  import os
8
- from pydub import AudioSegment
 
 
 
 
9
  import dash
10
  from dash import dcc, html, Input, Output, State
11
  import dash_bootstrap_components as dbc
12
  from dash.exceptions import PreventUpdate
13
- import threading
14
- from pytube import YouTube
15
- import logging
16
- import librosa
17
- import numpy as np
18
-
19
- # Set up logging
20
- logging.basicConfig(level=logging.INFO)
21
- logger = logging.getLogger(__name__)
22
 
23
- print("Script started")
 
 
 
24
 
25
- # Check if CUDA is available and set the device
26
- device = "cuda" if torch.cuda.is_available() else "cpu"
27
- print(f"Using device: {device}")
28
 
29
- # Load the Whisper model and processor
30
- whisper_model_name = "openai/whisper-small"
31
- whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
32
- whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name).to(device)
33
 
34
- def download_audio_from_url(url):
35
  try:
36
- if "youtube.com" in url or "youtu.be" in url:
37
- print("Processing YouTube URL...")
38
- yt = YouTube(url)
39
- audio_stream = yt.streams.filter(only_audio=True).first()
40
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
41
- audio_stream.download(output_path=temp_file.name)
42
- audio_bytes = open(temp_file.name, "rb").read()
43
- os.unlink(temp_file.name)
44
- elif "share" in url:
45
- print("Processing shareable link...")
46
- response = requests.get(url)
47
- soup = BeautifulSoup(response.content, 'html.parser')
48
- video_tag = soup.find('video')
49
- if video_tag and 'src' in video_tag.attrs:
50
- video_url = video_tag['src']
51
- print(f"Extracted video URL: {video_url}")
52
- else:
53
- raise ValueError("Direct video URL not found in the shareable link.")
54
- response = requests.get(video_url)
55
- audio_bytes = response.content
56
- else:
57
- print(f"Downloading video from URL: {url}")
58
- response = requests.get(url)
59
- audio_bytes = response.content
60
-
61
- print(f"Successfully downloaded {len(audio_bytes)} bytes of data")
62
- return audio_bytes
63
- except Exception as e:
64
- print(f"Error in download_audio_from_url: {str(e)}")
65
- raise
66
-
67
- def transcribe_audio(audio_file):
68
- try:
69
- logger.info("Loading audio file...")
70
- audio_input, sr = librosa.load(audio_file, sr=16000)
71
- audio_input = audio_input.astype(np.float32)
72
- logger.info(f"Audio duration: {len(audio_input) / sr:.2f} seconds")
73
-
74
- chunk_length = 30 * sr
75
- overlap = 5 * sr
76
- transcriptions = []
77
-
78
- logger.info("Starting transcription...")
79
- for i in range(0, len(audio_input), chunk_length - overlap):
80
- chunk = audio_input[i:i+chunk_length]
81
- input_features = whisper_processor(chunk, sampling_rate=16000, return_tensors="pt").input_features.to(device)
82
- predicted_ids = whisper_model.generate(input_features)
83
- transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
84
- transcriptions.extend(transcription)
85
- logger.info(f"Processed {i / sr:.2f} to {(i + chunk_length) / sr:.2f} seconds")
86
-
87
- full_transcription = " ".join(transcriptions)
88
- logger.info(f"Transcription complete. Full transcription length: {len(full_transcription)} characters")
89
-
90
- return full_transcription
91
- except Exception as e:
92
- logger.error(f"Error in transcribe_audio: {str(e)}")
93
- raise
94
-
95
- def transcribe_video(url):
96
  try:
97
- logger.info(f"Attempting to download audio from URL: {url}")
98
- audio_bytes = download_audio_from_url(url)
99
- logger.info(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")
100
-
101
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
102
- AudioSegment.from_file(io.BytesIO(audio_bytes)).export(temp_audio.name, format="wav")
103
- transcript = transcribe_audio(temp_audio.name)
104
-
105
- os.unlink(temp_audio.name)
106
-
107
- if len(transcript) < 10:
108
- raise ValueError("Transcription too short, possibly failed")
109
-
110
- logger.info(f"Transcription successful. Length: {len(transcript)} characters")
111
- logger.info(f"First 100 characters of transcript: {transcript[:100]}...")
112
-
113
- return transcript
114
- except Exception as e:
115
- error_message = f"An error occurred in transcribe_video: {str(e)}"
116
- logger.error(error_message)
117
- return error_message
118
 
119
- app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
120
 
121
  app.layout = dbc.Container([
122
- dbc.Row([
123
- dbc.Col([
124
- html.H1("Video Transcription", className="text-center mb-4"),
125
- html.Div("If you can see this, the app is working!", className="text-center mb-4"),
126
- dbc.Card([
127
- dbc.CardBody([
128
- dbc.Input(id="video-url", type="text", placeholder="Enter video URL"),
129
- dbc.Button("Transcribe", id="transcribe-button", color="primary", className="mt-3"),
130
- dbc.Spinner(html.Div(id="transcription-output", className="mt-3")),
131
- html.Div([
132
- dbc.Button("Download Transcript", id="download-button", color="secondary", className="mt-3", style={'display': 'none'}),
133
- dcc.Download(id="download-transcript")
134
- ])
135
- ])
136
- ])
137
- ], width=12)
 
 
 
 
 
 
 
 
 
 
 
138
  ])
139
- ], fluid=True)
140
 
141
  @app.callback(
142
  Output("transcription-output", "children"),
143
  Output("download-button", "style"),
144
  Input("transcribe-button", "n_clicks"),
145
- State("video-url", "value"),
 
 
146
  prevent_initial_call=True
147
  )
148
- def update_transcription(n_clicks, url):
149
- if not url:
150
  raise PreventUpdate
151
 
152
  def transcribe():
153
  try:
154
- transcript = transcribe_video(url)
155
- logger.info(f"Transcription completed. Result length: {len(transcript)} characters")
156
- return transcript
157
  except Exception as e:
158
- logger.exception("Error in transcription:")
159
  return f"An error occurred: {str(e)}"
160
 
161
- # Run transcription in a separate thread
162
  thread = threading.Thread(target=transcribe)
163
  thread.start()
164
  thread.join(timeout=600) # 10 minutes timeout
165
 
166
  if thread.is_alive():
167
- logger.warning("Transcription timed out after 10 minutes")
168
  return "Transcription timed out after 10 minutes", {'display': 'none'}
169
 
170
  transcript = getattr(thread, 'result', "Transcription failed")
171
- logger.info(f"Final transcript length: {len(transcript)} characters")
172
 
173
  if transcript and not transcript.startswith("An error occurred"):
174
- logger.info("Transcription successful, returning result")
175
  return dbc.Card([
176
  dbc.CardBody([
177
  html.H5("Transcription Result"),
@@ -179,7 +132,6 @@ def update_transcription(n_clicks, url):
179
  ])
180
  ]), {'display': 'block'}
181
  else:
182
- logger.error(f"Transcription failed: {transcript}")
183
  return transcript, {'display': 'none'}
184
 
185
  @app.callback(
@@ -196,6 +148,6 @@ def download_transcript(n_clicks, transcription_output):
196
  return dict(content=transcript, filename="transcript.txt")
197
 
198
  if __name__ == '__main__':
199
- logger.info("Starting the Dash application...")
200
  app.run(debug=True, host='0.0.0.0', port=7860)
201
- logger.info("Dash application has finished running.")
 
1
  import io
 
 
 
 
 
2
  import os
3
+ import tempfile
4
+ import threading
5
+ import base64
6
+ from urllib.parse import urlparse
7
+
8
  import dash
9
  from dash import dcc, html, Input, Output, State
10
  import dash_bootstrap_components as dbc
11
  from dash.exceptions import PreventUpdate
 
 
 
 
 
 
 
 
 
12
 
13
+ import requests
14
+ from pytube import YouTube
15
+ from pydub import AudioSegment
16
+ import openai
17
 
18
# Dash application object, styled with the Bootstrap theme from
# dash-bootstrap-components.
app = dash.Dash(
    __name__,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
)

# The OpenAI API key comes from the OPENAI_API_KEY environment variable;
# it is None when the variable is not set.
openai.api_key = os.getenv("OPENAI_API_KEY")
 
 
23
 
24
def is_valid_url(url):
    """Report whether *url* parses with both a scheme and a network location.

    Returns:
        True when ``urlparse`` yields a non-empty scheme and a non-empty
        netloc; False otherwise, including when ``urlparse`` raises
        ``ValueError``.
    """
    try:
        parts = urlparse(url)
    except ValueError:
        return False
    return bool(parts.scheme and parts.netloc)
30
+
31
def download_audio(url, timeout=60):
    """Download the audio behind *url* into a temporary file and return its path.

    URLs containing "youtube.com" or "youtu.be" are fetched through pytube's
    first audio-only stream and stored with an ``.mp4`` suffix. Any other URL
    is fetched with ``requests`` and stored with an ``.mp3`` suffix.

    Args:
        url: Address of the audio resource.
        timeout: Seconds ``requests`` waits for the server before giving up
            (plain-URL branch only).

    Returns:
        Path of the temporary file. It is created with ``delete=False``, so
        the caller is responsible for removing it.

    Raises:
        ValueError: The YouTube video exposes no audio-only stream.
        requests.RequestException: The HTTP request failed, timed out, or
            returned an error status.
    """
    is_youtube = "youtube.com" in url or "youtu.be" in url
    suffix = ".mp4" if is_youtube else ".mp3"

    # Reserve a unique path and close the handle immediately, so the
    # downloader never writes to a file this function still holds open.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_path = temp_file.name

    try:
        if is_youtube:
            audio_stream = YouTube(url).streams.filter(only_audio=True).first()
            if audio_stream is None:
                raise ValueError("No audio stream found for the given YouTube URL")
            # pytube joins output_path and filename, so pass them separately
            # instead of handing it an absolute path as the filename.
            audio_stream.download(
                output_path=os.path.dirname(temp_path),
                filename=os.path.basename(temp_path),
            )
        else:
            response = requests.get(url, timeout=timeout)
            # Without this check an HTTP error page would be saved and later
            # submitted for transcription as if it were audio.
            response.raise_for_status()
            with open(temp_path, "wb") as out_file:
                out_file.write(response.content)
    except Exception:
        # Best-effort cleanup so a failed download does not leak the file.
        try:
            os.unlink(temp_path)
        except OSError:
            pass
        raise

    return temp_path
43
+
44
def transcribe_audio(file_path):
    """Transcribe the audio file at *file_path* and return the text.

    The file is opened in binary mode and passed to
    ``openai.Audio.transcribe`` with the "whisper-1" model. The "text" entry
    of the response is returned.
    """
    model_name = "whisper-1"
    with open(file_path, "rb") as stream:
        result = openai.Audio.transcribe(model_name, stream)
    text = result["text"]
    return text
48
+
49
def process_audio(contents, filename, url):
    """Transcribe an uploaded audio file or, failing that, the audio at *url*.

    An upload takes precedence over the URL when both are supplied. The audio
    is staged in a temporary file, which is always removed once
    transcription finishes or fails.

    Args:
        contents: Upload payload of the form ``"<header>,<base64 data>"``, or
            a falsy value when nothing was uploaded.
        filename: Name of the uploaded file. Only its extension is used, as
            the temporary file's suffix; may be None.
        url: Address passed to ``download_audio`` when there is no upload.

    Returns:
        The value returned by ``transcribe_audio`` for the staged file.

    Raises:
        ValueError: Neither *contents* nor *url* was provided, or *contents*
            has no comma separating the header from the data.
    """
    if contents:
        # Base64 data never contains a comma, so split on the LAST one; this
        # tolerates a header that itself contains commas.
        _header, separator, content_string = contents.rpartition(',')
        if not separator:
            raise ValueError("Malformed upload contents")
        decoded = base64.b64decode(content_string)
        # A missing filename must not crash splitext; fall back to no suffix.
        suffix = os.path.splitext(filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file.write(decoded)
            temp_file_path = temp_file.name
    elif url:
        temp_file_path = download_audio(url)
    else:
        raise ValueError("No input provided")

    try:
        return transcribe_audio(temp_file_path)
    finally:
        os.unlink(temp_file_path)
67
 
68
  app.layout = dbc.Container([
69
+ html.H1("Audio Transcription App", className="text-center my-4"),
70
+ dbc.Card([
71
+ dbc.CardBody([
72
+ dcc.Upload(
73
+ id='upload-audio',
74
+ children=html.Div([
75
+ 'Drag and Drop or ',
76
+ html.A('Select Audio File')
77
+ ]),
78
+ style={
79
+ 'width': '100%',
80
+ 'height': '60px',
81
+ 'lineHeight': '60px',
82
+ 'borderWidth': '1px',
83
+ 'borderStyle': 'dashed',
84
+ 'borderRadius': '5px',
85
+ 'textAlign': 'center',
86
+ 'margin': '10px'
87
+ },
88
+ multiple=False
89
+ ),
90
+ dbc.Input(id="audio-url", type="text", placeholder="Enter audio URL or YouTube link", className="my-3"),
91
+ dbc.Button("Transcribe", id="transcribe-button", color="primary", className="w-100 mb-3"),
92
+ dbc.Spinner(html.Div(id="transcription-output", className="mt-3")),
93
+ dbc.Button("Download Transcript", id="download-button", color="secondary", className="w-100 mt-3", style={'display': 'none'}),
94
+ dcc.Download(id="download-transcript")
95
+ ])
96
  ])
97
+ ])
98
 
99
  @app.callback(
100
  Output("transcription-output", "children"),
101
  Output("download-button", "style"),
102
  Input("transcribe-button", "n_clicks"),
103
+ State("upload-audio", "contents"),
104
+ State("upload-audio", "filename"),
105
+ State("audio-url", "value"),
106
  prevent_initial_call=True
107
  )
108
+ def update_transcription(n_clicks, contents, filename, url):
109
+ if not contents and not url:
110
  raise PreventUpdate
111
 
112
  def transcribe():
113
  try:
114
+ return process_audio(contents, filename, url)
 
 
115
  except Exception as e:
 
116
  return f"An error occurred: {str(e)}"
117
 
 
118
  thread = threading.Thread(target=transcribe)
119
  thread.start()
120
  thread.join(timeout=600) # 10 minutes timeout
121
 
122
  if thread.is_alive():
 
123
  return "Transcription timed out after 10 minutes", {'display': 'none'}
124
 
125
  transcript = getattr(thread, 'result', "Transcription failed")
 
126
 
127
  if transcript and not transcript.startswith("An error occurred"):
 
128
  return dbc.Card([
129
  dbc.CardBody([
130
  html.H5("Transcription Result"),
 
132
  ])
133
  ]), {'display': 'block'}
134
  else:
 
135
  return transcript, {'display': 'none'}
136
 
137
  @app.callback(
 
148
  return dict(content=transcript, filename="transcript.txt")
149
 
150
  if __name__ == '__main__':
151
+ print("Starting the Dash application...")
152
  app.run(debug=True, host='0.0.0.0', port=7860)
153
+ print("Dash application has finished running.")