import os
import time
import json
import random
import string
import pathlib
import tempfile
import logging

import torch
import whisperx
import librosa
import numpy as np
import requests
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse

app = FastAPI(title="WhisperX API")
# -------------------------------
# Logging and Model Setup
# -------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("whisperx_api")

device = "cpu"
compute_type = "int8"
torch.set_num_threads(os.cpu_count())
# Pre-load models for different sizes
models = {
    "tiny": whisperx.load_model("tiny", device, compute_type=compute_type, vad_method='silero'),
    "base": whisperx.load_model("base", device, compute_type=compute_type, vad_method='silero'),
    "small": whisperx.load_model("small", device, compute_type=compute_type, vad_method='silero'),
    "large": whisperx.load_model("large", device, compute_type=compute_type, vad_method='silero'),
    "large-v2": whisperx.load_model("large-v2", device, compute_type=compute_type, vad_method='silero'),
    "large-v3": whisperx.load_model("large-v3", device, compute_type=compute_type, vad_method='silero'),
}

def seconds_to_srt_time(seconds: float) -> str:
    """Convert seconds (float) into SRT timestamp format (HH:MM:SS,mmm)."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    millis = int((seconds - int(seconds)) * 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

# -------------------------------
# Vocal Extraction Function
# -------------------------------
def get_vocals(input_file):
    """Send the input file to the politrees audio-separator Space and return a URL to the extracted vocals."""
    try:
        # Random identifiers for the Gradio upload and queue session.
        session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(11))
        file_id = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(11))
        file_content = pathlib.Path(input_file).read_bytes()
        file_len = len(file_content)
        # Upload the raw audio bytes (reusing file_content avoids a leaked file handle).
        r = requests.post(
            f'https://politrees-audio-separator-uvr.hf.space/gradio_api/upload?upload_id={file_id}',
            files={'files': (pathlib.Path(input_file).name, file_content)},
        )
        json_data = r.json()
        headers = {
            'accept': '*/*',
            'accept-language': 'en-US,en;q=0.5',
            'content-type': 'application/json',
            'origin': 'https://politrees-audio-separator-uvr.hf.space',
            'priority': 'u=1, i',
            'referer': 'https://politrees-audio-separator-uvr.hf.space/?__theme=system',
            'sec-ch-ua': '"Not(A:Brand";v="99", "Brave";v="133", "Chromium";v="133"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"Windows"',
            'sec-fetch-dest': 'empty',
            'sec-fetch-mode': 'cors',
            'sec-fetch-site': 'same-origin',
            'sec-fetch-storage-access': 'none',
            'sec-gpc': '1',
            'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36',
        }
        params = {
            '__theme': 'system',
        }
        # Positional arguments expected by the Space's separation endpoint:
        # the uploaded file, the separation model name, then the remaining settings.
        json_payload = {
            'data': [
                {
                    'path': json_data[0],
                    'url': 'https://politrees-audio-separator-uvr.hf.space/gradio_api/file=' + json_data[0],
                    'orig_name': pathlib.Path(input_file).name,
                    'size': file_len,
                    'mime_type': 'audio/wav',
                    'meta': {'_type': 'gradio.FileData'},
                },
                'MelBand Roformer | Vocals by Kimberley Jensen',
                256,
                False,
                5,
                0,
                '/tmp/audio-separator-models/',
                'output',
                'wav',
                0.9,
                0,
                1,
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
            ],
            'event_data': None,
            'fn_index': 5,
            'trigger_id': 28,
            'session_hash': session_hash,
        }
        # Join the Space's processing queue.
        response = requests.post(
            'https://politrees-audio-separator-uvr.hf.space/gradio_api/queue/join',
            params=params,
            headers=headers,
            json=json_payload,
        )
        # Poll the queue's event stream until the job reports completion.
        max_retries = 5
        retry_delay = 5
        retry_count = 0
        while retry_count < max_retries:
            try:
                logger.info(f"Connecting to stream... Attempt {retry_count + 1}")
                r = requests.get(
                    f'https://politrees-audio-separator-uvr.hf.space/gradio_api/queue/data?session_hash={session_hash}',
                    stream=True
                )
                if r.status_code != 200:
                    raise Exception(f"Failed to connect: HTTP {r.status_code}")
                logger.info("Connected successfully.")
                for line in r.iter_lines():
                    # The endpoint streams server-sent events; only 'data:' lines carry JSON payloads.
                    if line and line.startswith(b'data: '):
                        json_resp = json.loads(line.decode('utf-8').removeprefix('data: '))
                        logger.info(json_resp)
                        if 'process_completed' in json_resp['msg']:
                            logger.info("Process completed.")
                            output_url = json_resp['output']['data'][1]['url']
                            logger.info(f"Output URL: {output_url}")
                            return output_url
                logger.info("Stream ended prematurely. Reconnecting...")
            except Exception as e:
                logger.error(f"Error occurred: {e}. Retrying...")
            # Count premature stream endings as retries too, so the loop always terminates.
            retry_count += 1
            time.sleep(retry_delay)
        logger.error("Max retries reached. Exiting.")
        return None
    except Exception as ex:
        logger.error(f"Unexpected error in get_vocals: {ex}")
        return None

def split_audio_by_pause(audio, sr, pause_threshold, top_db=30, energy_threshold=0.03):
    """Split audio into non-silent intervals, merging intervals separated by gaps shorter than pause_threshold seconds."""
    intervals = librosa.effects.split(audio, top_db=top_db)
    if len(intervals) == 0:
        return []
    merged_intervals = []
    current_start, current_end = intervals[0]
    for start, end in intervals[1:]:
        gap_duration = (start - current_end) / sr
        if gap_duration < pause_threshold:
            current_end = end
        else:
            merged_intervals.append((current_start, current_end))
            current_start, current_end = start, end
    merged_intervals.append((current_start, current_end))
    # Filter out segments with low average RMS energy
    filtered_intervals = []
    for start, end in merged_intervals:
        segment = audio[start:end]
        rms = np.mean(librosa.feature.rms(y=segment))
        if rms >= energy_threshold:
            filtered_intervals.append((start, end))
    return filtered_intervals
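# Minimal usage sketch (hypothetical file name and threshold): split on pauses
# longer than half a second, then report each kept interval in seconds.
#   audio, sr = librosa.load("speech.wav", sr=16000)
#   for start, end in split_audio_by_pause(audio, sr, pause_threshold=0.5):
#       print(f"{start / sr:.2f}s -> {end / sr:.2f}s")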

# -------------------------------
# Main Transcription Function
# -------------------------------
def transcribe(audio_file, model_size="base", debug=False, pause_threshold=0.0, vocal_extraction=False, language="en"):
    """Transcribe an audio file into word-level SRT subtitles; returns (srt, debug_log) when debug=True."""
    start_time = time.time()
    srt_output = ""
    debug_log = []
    subtitle_index = 1
    try:
        # Optionally extract vocals first
        if vocal_extraction:
            debug_log.append("Vocal extraction enabled; processing input file for vocals...")
            extracted_url = get_vocals(audio_file)
            if extracted_url is not None:
                debug_log.append("Vocal extraction succeeded; downloading extracted audio...")
                response = requests.get(extracted_url)
                if response.status_code == 200:
                    # The separator is asked for wav output above, so save with a .wav suffix.
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                        tmp.write(response.content)
                        audio_file = tmp.name
                    debug_log.append("Extracted audio downloaded and saved for transcription.")
                else:
                    debug_log.append("Failed to download extracted audio; proceeding with original file.")
            else:
                debug_log.append("Vocal extraction failed; proceeding with original audio.")
        # Load audio file (resampled to 16 kHz)
        audio, sr = librosa.load(audio_file, sr=16000)
        debug_log.append(f"Audio loaded: {len(audio)/sr:.2f} seconds at {sr} Hz")
        # Select model and set batch size
        model = models[model_size]
        batch_size = 8 if model_size == "tiny" else 4
        # Transcribe using the specified language (or auto-detect). When the audio
        # will be re-transcribed per segment below, a full pass is only needed
        # here for language detection.
        transcript = None
        if language:
            if pause_threshold <= 0:
                transcript = model.transcribe(audio, batch_size=batch_size, language=language)
        else:
            transcript = model.transcribe(audio, batch_size=batch_size)
            language = transcript.get("language", "unknown")
        # Load alignment model for the given language
        model_a, metadata = whisperx.load_align_model(language_code=language, device=device)
        if pause_threshold > 0:
            segments = split_audio_by_pause(audio, sr, pause_threshold)
            debug_log.append(f"Audio split into {len(segments)} segment(s) using pause threshold of {pause_threshold}s")
            for seg_idx, (seg_start, seg_end) in enumerate(segments):
                audio_segment = audio[seg_start:seg_end]
                seg_duration = (seg_end - seg_start) / sr
                debug_log.append(f"Segment {seg_idx+1}: start={seg_start/sr:.2f}s, duration={seg_duration:.2f}s")
                seg_transcript = model.transcribe(audio_segment, batch_size=batch_size, language=language)
                seg_aligned = whisperx.align(
                    seg_transcript["segments"], model_a, metadata, audio_segment, device
                )
                for segment in seg_aligned["segments"]:
                    for word in segment["words"]:
                        # Skip words the aligner could not time-stamp.
                        if 'start' not in word or 'end' not in word:
                            continue
                        # Shift word times from segment-relative to absolute positions.
                        adjusted_start = word['start'] + seg_start / sr
                        adjusted_end = word['end'] + seg_start / sr
                        start_timestamp = seconds_to_srt_time(adjusted_start)
                        end_timestamp = seconds_to_srt_time(adjusted_end)
                        srt_output += f"{subtitle_index}\n{start_timestamp} --> {end_timestamp}\n{word['word']}\n\n"
                        subtitle_index += 1
        else:
            # Align the full-audio transcript from the pass above (no need to transcribe twice)
            aligned = whisperx.align(
                transcript["segments"], model_a, metadata, audio, device
            )
            for segment in aligned["segments"]:
                for word in segment["words"]:
                    if 'start' not in word or 'end' not in word:
                        continue
                    start_timestamp = seconds_to_srt_time(word['start'])
                    end_timestamp = seconds_to_srt_time(word['end'])
                    srt_output += f"{subtitle_index}\n{start_timestamp} --> {end_timestamp}\n{word['word']}\n\n"
                    subtitle_index += 1
debug_log.append(f"Language used: {language}") | |
debug_log.append(f"Batch size: {batch_size}") | |
debug_log.append(f"Processed in {time.time()-start_time:.2f}s") | |
except Exception as e: | |
logger.error("Error during transcription:", exc_info=True) | |
srt_output = "Error occurred during transcription" | |
debug_log.append(f"ERROR: {str(e)}") | |
if debug: | |
return srt_output, "\n".join(debug_log) | |
return srt_output | |

# -------------------------------
# FastAPI Endpoints
# -------------------------------
@app.post("/transcribe")
async def transcribe_endpoint(
    audio_file: UploadFile = File(...),
    model_size: str = Form("base"),
    debug: bool = Form(False),
    pause_threshold: float = Form(0.0),
    vocal_extraction: bool = Form(False),
    language: str = Form("en")
):
    try:
        # Save the uploaded file to a temporary location
        suffix = pathlib.Path(audio_file.filename).suffix
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(await audio_file.read())
            tmp_path = tmp.name
        result = transcribe(tmp_path, model_size=model_size, debug=debug,
                            pause_threshold=pause_threshold,
                            vocal_extraction=vocal_extraction,
                            language=language)
        os.remove(tmp_path)
        if debug:
            srt_text, debug_info = result
            return JSONResponse(content={"srt": srt_text, "debug": debug_info})
        else:
            return JSONResponse(content={"srt": result})
    except Exception as e:
        logger.error(f"Error in transcribe_endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")

@app.get("/")
async def root():
    return {"message": "WhisperX API is running."}