sdafd committed on
Commit 20b9e25 · verified · 1 Parent(s): 0a82387

Create app.py

Files changed (1)
app.py +309 -0
app.py ADDED
@@ -0,0 +1,309 @@
import os
import time
import json
import random
import string
import pathlib
import tempfile
import logging

import torch
import whisperx
import librosa
import numpy as np
import requests

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse

app = FastAPI(title="WhisperX API")

# -------------------------------
# Logging and Model Setup
# -------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("whisperx_api")

device = "cpu"
compute_type = "int8"
torch.set_num_threads(os.cpu_count())

# Pre-load models for different sizes
models = {
    "tiny": whisperx.load_model("tiny", device, compute_type=compute_type, vad_method='silero'),
    "base": whisperx.load_model("base", device, compute_type=compute_type, vad_method='silero'),
    "small": whisperx.load_model("small", device, compute_type=compute_type, vad_method='silero'),
    "large": whisperx.load_model("large", device, compute_type=compute_type, vad_method='silero'),
    "large-v2": whisperx.load_model("large-v2", device, compute_type=compute_type, vad_method='silero'),
    "large-v3": whisperx.load_model("large-v3", device, compute_type=compute_type, vad_method='silero'),
}

def seconds_to_srt_time(seconds: float) -> str:
    """Convert seconds (float) into SRT timestamp format (HH:MM:SS,mmm)."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    millis = int((seconds - int(seconds)) * 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

# -------------------------------
# Vocal Extraction Function
# -------------------------------
def get_vocals(input_file):
    try:
        session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(11))
        file_id = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(11))
        file_content = pathlib.Path(input_file).read_bytes()
        file_len = len(file_content)
        # Upload the file; reuse the bytes already read so no file handle is left open
        r = requests.post(
            f'https://politrees-audio-separator-uvr.hf.space/gradio_api/upload?upload_id={file_id}',
            files={'files': (pathlib.Path(input_file).name, file_content)}
        )
        json_data = r.json()

        headers = {
            'accept': '*/*',
            'accept-language': 'en-US,en;q=0.5',
            'content-type': 'application/json',
            'origin': 'https://politrees-audio-separator-uvr.hf.space',
            'priority': 'u=1, i',
            'referer': 'https://politrees-audio-separator-uvr.hf.space/?__theme=system',
            'sec-ch-ua': '"Not(A:Brand";v="99", "Brave";v="133", "Chromium";v="133"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"Windows"',
            'sec-fetch-dest': 'empty',
            'sec-fetch-mode': 'cors',
            'sec-fetch-site': 'same-origin',
            'sec-fetch-storage-access': 'none',
            'sec-gpc': '1',
            'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36',
        }

        params = {
            '__theme': 'system',
        }

        json_payload = {
            'data': [
                {
                    'path': json_data[0],
                    'url': 'https://politrees-audio-separator-uvr.hf.space/gradio_api/file=' + json_data[0],
                    'orig_name': pathlib.Path(input_file).name,
                    'size': file_len,
                    'mime_type': 'audio/wav',
                    'meta': {'_type': 'gradio.FileData'},
                },
                'MelBand Roformer | Vocals by Kimberley Jensen',
                256,
                False,
                5,
                0,
                '/tmp/audio-separator-models/',
                'output',
                'wav',
                0.9,
                0,
                1,
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
            ],
            'event_data': None,
            'fn_index': 5,
            'trigger_id': 28,
            'session_hash': session_hash,
        }

        response = requests.post(
            'https://politrees-audio-separator-uvr.hf.space/gradio_api/queue/join',
            params=params,
            headers=headers,
            json=json_payload,
        )

        max_retries = 5
        retry_delay = 5
        retry_count = 0
        while retry_count < max_retries:
            try:
                logger.info(f"Connecting to stream... Attempt {retry_count + 1}")
                r = requests.get(
                    f'https://politrees-audio-separator-uvr.hf.space/gradio_api/queue/data?session_hash={session_hash}',
                    stream=True
                )
                if r.status_code != 200:
                    raise Exception(f"Failed to connect: HTTP {r.status_code}")
                logger.info("Connected successfully.")
                for line in r.iter_lines():
                    if line:
                        json_resp = json.loads(line.decode('utf-8').replace('data: ', ''))
                        logger.info(json_resp)
                        if 'process_completed' in json_resp['msg']:
                            logger.info("Process completed.")
                            output_url = json_resp['output']['data'][1]['url']
                            logger.info(f"Output URL: {output_url}")
                            return output_url
                logger.info("Stream ended prematurely. Reconnecting...")
            except Exception as e:
                logger.error(f"Error occurred: {e}. Retrying...")
            # Count the attempt and back off whether the request raised or the
            # stream simply ended, so a silent disconnect cannot loop forever
            retry_count += 1
            time.sleep(retry_delay)
        logger.error("Max retries reached. Exiting.")
        return None
    except Exception as ex:
        logger.error(f"Unexpected error in get_vocals: {ex}")
        return None

def split_audio_by_pause(audio, sr, pause_threshold, top_db=30, energy_threshold=0.03):
    intervals = librosa.effects.split(audio, top_db=top_db)
    if len(intervals) == 0:
        return []  # nothing but silence detected
    merged_intervals = []
    current_start, current_end = intervals[0]
    for start, end in intervals[1:]:
        gap_duration = (start - current_end) / sr
        if gap_duration < pause_threshold:
            current_end = end
        else:
            merged_intervals.append((current_start, current_end))
            current_start, current_end = start, end
    merged_intervals.append((current_start, current_end))
    # Filter out segments with low average RMS energy
    filtered_intervals = []
    for start, end in merged_intervals:
        segment = audio[start:end]
        rms = np.mean(librosa.feature.rms(y=segment))
        if rms >= energy_threshold:
            filtered_intervals.append((start, end))
    return filtered_intervals

# -------------------------------
# Main Transcription Function
# -------------------------------
def transcribe(audio_file, model_size="base", debug=False, pause_threshold=0.0, vocal_extraction=False, language="en"):
    start_time = time.time()
    srt_output = ""
    debug_log = []
    subtitle_index = 1

    try:
        # Optionally extract vocals first
        if vocal_extraction:
            debug_log.append("Vocal extraction enabled; processing input file for vocals...")
            extracted_url = get_vocals(audio_file)
            if extracted_url is not None:
                debug_log.append("Vocal extraction succeeded; downloading extracted audio...")
                response = requests.get(extracted_url)
                if response.status_code == 200:
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
                        tmp.write(response.content)
                        audio_file = tmp.name
                    debug_log.append("Extracted audio downloaded and saved for transcription.")
                else:
                    debug_log.append("Failed to download extracted audio; proceeding with original file.")
            else:
                debug_log.append("Vocal extraction failed; proceeding with original audio.")

        # Load audio file (resampled to 16 kHz)
        audio, sr = librosa.load(audio_file, sr=16000)
        debug_log.append(f"Audio loaded: {len(audio)/sr:.2f} seconds at {sr} Hz")

        # Select model and set batch size
        model = models[model_size]
        batch_size = 8 if model_size == "tiny" else 4

        # Transcribe using the specified language (or auto-detect when none is given)
        if language:
            transcript = model.transcribe(audio, batch_size=batch_size, language=language)
        else:
            transcript = model.transcribe(audio, batch_size=batch_size)
            language = transcript.get("language", "unknown")

        # Load alignment model for the given language
        model_a, metadata = whisperx.load_align_model(language_code=language, device=device)

        if pause_threshold > 0:
            segments = split_audio_by_pause(audio, sr, pause_threshold)
            debug_log.append(f"Audio split into {len(segments)} segment(s) using pause threshold of {pause_threshold}s")
            for seg_idx, (seg_start, seg_end) in enumerate(segments):
                audio_segment = audio[seg_start:seg_end]
                seg_duration = (seg_end - seg_start) / sr
                debug_log.append(f"Segment {seg_idx+1}: start={seg_start/sr:.2f}s, duration={seg_duration:.2f}s")
                seg_transcript = model.transcribe(audio_segment, batch_size=batch_size, language=language)
                seg_aligned = whisperx.align(
                    seg_transcript["segments"], model_a, metadata, audio_segment, device
                )
                for segment in seg_aligned["segments"]:
                    for word in segment["words"]:
                        if 'start' not in word or 'end' not in word:
                            continue  # skip words the aligner could not timestamp
                        adjusted_start = word['start'] + seg_start/sr
                        adjusted_end = word['end'] + seg_start/sr
                        start_timestamp = seconds_to_srt_time(adjusted_start)
                        end_timestamp = seconds_to_srt_time(adjusted_end)
                        srt_output += f"{subtitle_index}\n{start_timestamp} --> {end_timestamp}\n{word['word']}\n\n"
                        subtitle_index += 1
        else:
            # Process the entire audio without splitting; reuse the transcript
            # computed above instead of transcribing a second time
            aligned = whisperx.align(
                transcript["segments"], model_a, metadata, audio, device
            )
            for segment in aligned["segments"]:
                for word in segment["words"]:
                    if 'start' not in word or 'end' not in word:
                        continue  # skip words the aligner could not timestamp
                    start_timestamp = seconds_to_srt_time(word['start'])
                    end_timestamp = seconds_to_srt_time(word['end'])
                    srt_output += f"{subtitle_index}\n{start_timestamp} --> {end_timestamp}\n{word['word']}\n\n"
                    subtitle_index += 1

        debug_log.append(f"Language used: {language}")
        debug_log.append(f"Batch size: {batch_size}")
        debug_log.append(f"Processed in {time.time()-start_time:.2f}s")

    except Exception as e:
        logger.error("Error during transcription:", exc_info=True)
        srt_output = "Error occurred during transcription"
        debug_log.append(f"ERROR: {str(e)}")

    if debug:
        return srt_output, "\n".join(debug_log)
    return srt_output

# -------------------------------
# FastAPI Endpoints
# -------------------------------
@app.post("/transcribe")
async def transcribe_endpoint(
    audio_file: UploadFile = File(...),
    model_size: str = Form("base"),
    debug: bool = Form(False),
    pause_threshold: float = Form(0.0),
    vocal_extraction: bool = Form(False),
    language: str = Form("en")
):
    try:
        # Save the uploaded file to a temporary location
        suffix = pathlib.Path(audio_file.filename).suffix
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(await audio_file.read())
            tmp_path = tmp.name

        result = transcribe(tmp_path, model_size=model_size, debug=debug,
                            pause_threshold=pause_threshold,
                            vocal_extraction=vocal_extraction,
                            language=language)

        os.remove(tmp_path)

        if debug:
            srt_text, debug_info = result
            return JSONResponse(content={"srt": srt_text, "debug": debug_info})
        else:
            return JSONResponse(content={"srt": result})
    except Exception as e:
        logger.error(f"Error in transcribe_endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")

@app.get("/")
async def root():
    return {"message": "WhisperX API is running."}
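
For reference, a minimal client sketch for the /transcribe endpoint above. It assumes the app is served locally with uvicorn on port 8000 and that a file named sample.wav exists; the host, port, and file name are assumptions, not part of this commit. The field names and defaults match the Form parameters defined in the endpoint.

# client_example.py -- minimal sketch of calling the /transcribe endpoint.
# Assumes the API is running locally, e.g.: uvicorn app:app --port 8000
import requests

url = "http://localhost:8000/transcribe"  # assumed host/port
with open("sample.wav", "rb") as f:       # hypothetical input file
    resp = requests.post(
        url,
        files={"audio_file": ("sample.wav", f, "audio/wav")},
        data={
            "model_size": "base",        # one of the pre-loaded sizes
            "debug": "true",             # form booleans are sent as strings
            "pause_threshold": "0.5",    # split on pauses longer than 0.5 s
            "vocal_extraction": "false",
            "language": "en",
        },
    )
resp.raise_for_status()
payload = resp.json()
print(payload["srt"])                    # word-level SRT text
if "debug" in payload:
    print(payload["debug"])              # debug log, present when debug=true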