sdafd committed
Commit 7a3ea68 · verified · 1 Parent(s): 4d0bb63

Update app.py

Files changed (1)
  1. app.py +211 -36
app.py CHANGED
@@ -6,17 +6,158 @@ import logging
 import os
 import time
 import numpy as np
+import requests
+import random
+import string
+import json
+import pathlib
+import tempfile
 
-# Configure logging
+# -------------------------------
+# Vocal Extraction Function
+# -------------------------------
+def get_vocals(input_file):
+    try:
+        session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(11))
+        file_id = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(11))
+        file_len = 0
+
+        file_content = pathlib.Path(input_file).read_bytes()
+        file_len = len(file_content)
+        r = requests.post(
+            f'https://politrees-audio-separator-uvr.hf.space/gradio_api/upload?upload_id={file_id}',
+            files={'files': open(input_file, 'rb')}
+        )
+        json_data = r.json()
+
+        headers = {
+            'accept': '*/*',
+            'accept-language': 'en-US,en;q=0.5',
+            'content-type': 'application/json',
+            'origin': 'https://politrees-audio-separator-uvr.hf.space',
+            'priority': 'u=1, i',
+            'referer': 'https://politrees-audio-separator-uvr.hf.space/?__theme=system',
+            'sec-ch-ua': '"Not(A:Brand";v="99", "Brave";v="133", "Chromium";v="133"',
+            'sec-ch-ua-mobile': '?0',
+            'sec-ch-ua-platform': '"Windows"',
+            'sec-fetch-dest': 'empty',
+            'sec-fetch-mode': 'cors',
+            'sec-fetch-site': 'same-origin',
+            'sec-fetch-storage-access': 'none',
+            'sec-gpc': '1',
+            'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36',
+        }
+
+        params = {
+            '__theme': 'system',
+        }
+
+        json_payload = {
+            'data': [
+                {
+                    'path': json_data[0],
+                    'url': 'https://politrees-audio-separator-uvr.hf.space/gradio_api/file='+json_data[0],
+                    'orig_name': pathlib.Path(input_file).name,
+                    'size': file_len,
+                    'mime_type': 'audio/wav',
+                    'meta': {
+                        '_type': 'gradio.FileData',
+                    },
+                },
+                'MelBand Roformer | Vocals by Kimberley Jensen',
+                256,
+                False,
+                5,
+                0,
+                '/tmp/audio-separator-models/',
+                'output',
+                'wav',
+                0.9,
+                0,
+                1,
+                'NAME_(STEM)_MODEL',
+                'NAME_(STEM)_MODEL',
+                'NAME_(STEM)_MODEL',
+                'NAME_(STEM)_MODEL',
+                'NAME_(STEM)_MODEL',
+                'NAME_(STEM)_MODEL',
+                'NAME_(STEM)_MODEL',
+            ],
+            'event_data': None,
+            'fn_index': 5,
+            'trigger_id': 28,
+            'session_hash': session_hash,
+        }
+
+        response = requests.post(
+            'https://politrees-audio-separator-uvr.hf.space/gradio_api/queue/join',
+            params=params,
+            headers=headers,
+            json=json_payload,
+        )
+
+        max_retries = 5
+        retry_delay = 5
+        retry_count = 0
+        while retry_count < max_retries:
+            try:
+                print(f"Connecting to stream... Attempt {retry_count + 1}")
+                r = requests.get(
+                    f'https://politrees-audio-separator-uvr.hf.space/gradio_api/queue/data?session_hash={session_hash}',
+                    stream=True
+                )
+
+                if r.status_code != 200:
+                    raise Exception(f"Failed to connect: HTTP {r.status_code}")
+
+                print("Connected successfully.")
+                for line in r.iter_lines():
+                    if line:
+                        json_resp = json.loads(line.decode('utf-8').replace('data: ', ''))
+                        print(json_resp)
+                        if 'process_completed' in json_resp['msg']:
+                            print("Process completed.")
+                            output_url = json_resp['output']['data'][1]['url']
+                            print(f"Output URL: {output_url}")
+                            return output_url
+                print("Stream ended prematurely. Reconnecting...")
+
+            except Exception as e:
+                print(f"Error occurred: {e}. Retrying...")
+
+            retry_count += 1
+            time.sleep(retry_delay)
+
+        print("Max retries reached. Exiting.")
+        return None
+    except Exception as ex:
+        print(f"Unexpected error in get_vocals: {ex}")
+        return None
+
+# -------------------------------
+# Normalization Function
+# -------------------------------
+def normalize_audio(audio, threshold_ratio=0.6):
+    """
+    Given an audio signal (numpy array), set to 0 any samples that are below
+    a given ratio of the maximum absolute amplitude. This is a simple way to
+    suppress relatively quieter (background) parts.
+    """
+    max_val = np.max(np.abs(audio))
+    threshold = threshold_ratio * max_val
+    normalized_audio = np.where(np.abs(audio) >= threshold, audio, 0)
+    return normalized_audio
+
+# -------------------------------
+# Logging and Model Setup
+# -------------------------------
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("whisperx_app")
 
-# Device setup (force CPU)
 device = "cpu"
 compute_type = "int8"
 torch.set_num_threads(os.cpu_count())
 
-# Pre-load models
 models = {
     "tiny": whisperx.load_model("tiny", device, compute_type=compute_type, vad_method='silero'),
     "base": whisperx.load_model("base", device, compute_type=compute_type, vad_method='silero'),
@@ -32,7 +173,6 @@ def split_audio_by_pause(audio, sr, pause_threshold, top_db=30):
     Adjacent non-silent intervals are merged if the gap between them is less than the pause_threshold.
     Returns a list of (start_sample, end_sample) tuples.
     """
-    # Get non-silent intervals based on an amplitude threshold (in dB)
     intervals = librosa.effects.split(audio, top_db=top_db)
     if intervals.size == 0:
         return [(0, len(audio))]
@@ -41,10 +181,8 @@ def split_audio_by_pause(audio, sr, pause_threshold, top_db=30):
     current_start, current_end = intervals[0]
 
     for start, end in intervals[1:]:
-        # Compute the gap duration (in seconds) between the current interval and the next one
         gap_duration = (start - current_end) / sr
         if gap_duration < pause_threshold:
-            # Merge intervals if gap is less than the threshold
             current_end = end
         else:
             merged_intervals.append((current_start, current_end))
@@ -52,62 +190,85 @@ def split_audio_by_pause(audio, sr, pause_threshold, top_db=30):
     merged_intervals.append((current_start, current_end))
     return merged_intervals
 
-def transcribe(audio_file, model_size="base", debug=False, pause_threshold=0.0):
+# -------------------------------
+# Main Transcription Function
+# -------------------------------
+def transcribe(audio_file, model_size="base", debug=False, pause_threshold=0.0, vocal_extraction=False, language="en"):
     start_time = time.time()
     final_result = ""
     debug_log = []
 
     try:
+        # If vocal extraction is enabled, process the file first
+        if vocal_extraction:
+            debug_log.append("Vocal extraction enabled; processing input file for vocals...")
+            extracted_url = get_vocals(audio_file)
+            if extracted_url is not None:
+                debug_log.append("Vocal extraction succeeded; downloading extracted audio...")
+                response = requests.get(extracted_url)
+                if response.status_code == 200:
+                    # Write to a temporary file
+                    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
+                        tmp.write(response.content)
+                    audio_file = tmp.name
+                    debug_log.append("Extracted audio downloaded and saved for transcription.")
+                else:
+                    debug_log.append("Failed to download extracted audio; proceeding with original file.")
+            else:
+                debug_log.append("Vocal extraction failed; proceeding with original audio.")
+
         # Load audio file at 16kHz
         audio, sr = librosa.load(audio_file, sr=16000)
         debug_log.append(f"Audio loaded: {len(audio)/sr:.2f} seconds long at {sr} Hz")
 
-        # Get the preloaded model and determine batch size
+        # If we used vocal extraction, apply normalization to remove low-amplitude (background) parts
+        if vocal_extraction:
+            audio = normalize_audio(audio)
+            debug_log.append("Normalization applied to extracted audio to remove low-amplitude segments.")
+
+        # Select the model and set batch size
         model = models[model_size]
         batch_size = 8 if model_size == "tiny" else 4
 
-        # If pause_threshold > 0, split audio into segments based on silence pauses
+        # Use the provided language if set; otherwise, let the model detect the language.
+        if language:
+            transcript = model.transcribe(audio, batch_size=batch_size, language=language)
+        else:
+            transcript = model.transcribe(audio, batch_size=batch_size)
+            language = transcript.get("language", "unknown")
+
+        # Load alignment model using the specified/overridden language
+        model_a, metadata = whisperx.load_align_model(language_code=language, device=device)
+
+        # If pause_threshold > 0, split the audio and process segments individually
         if pause_threshold > 0:
             segments = split_audio_by_pause(audio, sr, pause_threshold)
             debug_log.append(f"Audio split into {len(segments)} segment(s) using a pause threshold of {pause_threshold}s")
-            # Process each audio segment individually
            for seg_idx, (seg_start, seg_end) in enumerate(segments):
                 audio_segment = audio[seg_start:seg_end]
                 seg_duration = (seg_end - seg_start) / sr
                 debug_log.append(f"Segment {seg_idx+1}: start={seg_start/sr:.2f}s, duration={seg_duration:.2f}s")
 
-                # Transcribe this segment
-                transcript = model.transcribe(audio_segment, batch_size=batch_size)
-
-                # Load alignment model for the detected language in this segment
-                model_a, metadata = whisperx.load_align_model(
-                    language_code=transcript["language"], device=device
-                )
-                transcript_aligned = whisperx.align(
-                    transcript["segments"], model_a, metadata, audio_segment, device
+                seg_transcript = model.transcribe(audio_segment, batch_size=batch_size, language=language)
+                seg_aligned = whisperx.align(
+                    seg_transcript["segments"], model_a, metadata, audio_segment, device
                 )
-
-                # Format word-level output with adjusted timestamps (adding segment offset)
-                for segment in transcript_aligned["segments"]:
+                for segment in seg_aligned["segments"]:
                     for word in segment["words"]:
-                        # Adjust start and end times by the segment's start time (in seconds)
                         adjusted_start = word['start'] + seg_start/sr
                         adjusted_end = word['end'] + seg_start/sr
                         final_result += f"[{adjusted_start:5.2f}s-{adjusted_end:5.2f}s] {word['word']}\n"
         else:
             # Process the entire audio without splitting
-            transcript = model.transcribe(audio, batch_size=batch_size)
-            model_a, metadata = whisperx.load_align_model(
-                language_code=transcript["language"], device=device
-            )
-            transcript_aligned = whisperx.align(
+            transcript = model.transcribe(audio, batch_size=batch_size, language=language)
+            aligned = whisperx.align(
                 transcript["segments"], model_a, metadata, audio, device
             )
-            for segment in transcript_aligned["segments"]:
+            for segment in aligned["segments"]:
                 for word in segment["words"]:
                     final_result += f"[{word['start']:5.2f}s-{word['end']:5.2f}s] {word['word']}\n"
 
-        debug_log.append(f"Language detected: {transcript['language']}")
+        debug_log.append(f"Language used: {language}")
         debug_log.append(f"Batch size: {batch_size}")
        debug_log.append(f"Processed in {time.time()-start_time:.2f}s")
 
@@ -120,9 +281,11 @@ def transcribe(audio_file, model_size="base", debug=False, pause_threshold=0.0):
         return final_result, "\n".join(debug_log)
     return final_result
 
+# -------------------------------
 # Gradio Interface
+# -------------------------------
 with gr.Blocks(title="WhisperX CPU Transcription") as demo:
-    gr.Markdown("# WhisperX CPU Transcription with Word-Level Timestamps")
+    gr.Markdown("# WhisperX CPU Transcription with Vocal Extraction Option")
 
     with gr.Row():
         with gr.Column():
@@ -138,13 +301,23 @@ with gr.Blocks(title="WhisperX CPU Transcription") as demo:
                 label="Model Size",
                 interactive=True,
             )
-            # New input: pause threshold in seconds (set to 0 to disable splitting)
             pause_threshold_slider = gr.Slider(
                 minimum=0, maximum=5, step=0.1, value=0,
                 label="Pause Threshold (seconds)",
                 interactive=True,
                 info="Set a pause duration threshold. Audio pauses longer than this will be used to split the audio into segments."
             )
+            # New input for vocal extraction feature
+            vocal_extraction_checkbox = gr.Checkbox(
+                label="Extract Vocals (improves accuracy on noisy audio)",
+                value=False
+            )
+            # New language selection (default English)
+            language_input = gr.Textbox(
+                label="Language Code (e.g., en, es, fr)",
+                placeholder="Enter language code",
+                value="en"
+            )
             debug_checkbox = gr.Checkbox(label="Enable Debug Mode", value=False)
             transcribe_btn = gr.Button("Transcribe", variant="primary")
 
@@ -152,7 +325,7 @@ with gr.Blocks(title="WhisperX CPU Transcription") as demo:
             output_text = gr.Textbox(
                 label="Transcription Output",
                 lines=20,
-                placeholder="Transcription will appear here...",
+                placeholder="Transcription will appear here..."
             )
             debug_output = gr.Textbox(
                 label="Debug Information",
@@ -171,13 +344,15 @@ with gr.Blocks(title="WhisperX CPU Transcription") as demo:
         outputs=[debug_output]
     )
 
-    # Process transcription with the new pause_threshold parameter
+    # Process transcription with all new parameters
    transcribe_btn.click(
         transcribe,
-        inputs=[audio_input, model_selector, debug_checkbox, pause_threshold_slider],
+        inputs=[audio_input, model_selector, debug_checkbox, pause_threshold_slider, vocal_extraction_checkbox, language_input],
         outputs=[output_text, debug_output]
     )
 
-# Launch configuration
+# -------------------------------
+# Launch the App
+# -------------------------------
 if __name__ == "__main__":
     demo.queue(max_size=4).launch()
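
Note on the new normalize_audio helper: it gates the waveform against a fraction of its own peak, zeroing any sample whose absolute value falls below threshold_ratio * max(|audio|), which crudely suppresses residual background after vocal extraction. A minimal, self-contained sketch of that behavior (the toy signal and values below are illustrative, not from the commit):

import numpy as np

def normalize_audio(audio, threshold_ratio=0.6):
    # Zero out samples quieter than threshold_ratio times the peak amplitude,
    # mirroring the helper added in this commit.
    max_val = np.max(np.abs(audio))
    threshold = threshold_ratio * max_val
    return np.where(np.abs(audio) >= threshold, audio, 0)

# Toy signal: a loud word surrounded by quiet background noise.
signal = np.array([0.05, -0.10, 0.80, -0.90, 1.00, 0.07, -0.04])
print(normalize_audio(signal))  # samples with |x| < 0.6 are zeroed; only the loud part survives

Because the threshold is relative to the loudest sample in the file, quiet but valid speech can also be zeroed; in this commit the gate is only applied when vocal extraction is enabled.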