import gradio as gr
import whisperx
import torch
import librosa
import logging
import os
import time
import numpy as np
import requests
import random
import string
import json
import pathlib
import tempfile
# -------------------------------
# Vocal Extraction Function
# -------------------------------
def get_vocals(input_file):
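    """Separate vocals from a mixed audio file by calling the public audio-separator
    Space at politrees-audio-separator-uvr.hf.space through its Gradio queue API.

    The flow mirrors what the Space's web UI does: upload the file, POST to
    /gradio_api/queue/join, then listen on the /gradio_api/queue/data event stream.
    Returns a URL to the separated vocal stem, or None if extraction fails.
    """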
    try:
        session_hash = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(11))
        file_id = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(11))
        file_content = pathlib.Path(input_file).read_bytes()
        file_len = len(file_content)
        with open(input_file, 'rb') as f:
            r = requests.post(
                f'https://politrees-audio-separator-uvr.hf.space/gradio_api/upload?upload_id={file_id}',
                files={'files': f}
            )
        json_data = r.json()
        headers = {
            'accept': '*/*',
            'accept-language': 'en-US,en;q=0.5',
            'content-type': 'application/json',
            'origin': 'https://politrees-audio-separator-uvr.hf.space',
            'priority': 'u=1, i',
            'referer': 'https://politrees-audio-separator-uvr.hf.space/?__theme=system',
            'sec-ch-ua': '"Not(A:Brand";v="99", "Brave";v="133", "Chromium";v="133"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"Windows"',
            'sec-fetch-dest': 'empty',
            'sec-fetch-mode': 'cors',
            'sec-fetch-site': 'same-origin',
            'sec-fetch-storage-access': 'none',
            'sec-gpc': '1',
            'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36',
        }
        params = {
            '__theme': 'system',
        }
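        # Positional arguments for the Space's separation endpoint (fn_index=5).
        # The values below (separation model name, segment size, output format,
        # naming templates, etc.) appear to mirror the payload the Space's web UI
        # submits; their exact meaning is defined by that Space and may change if
        # the Space is updated.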
        json_payload = {
            'data': [
                {
                    'path': json_data[0],
                    'url': 'https://politrees-audio-separator-uvr.hf.space/gradio_api/file=' + json_data[0],
                    'orig_name': pathlib.Path(input_file).name,
                    'size': file_len,
                    'mime_type': 'audio/wav',
                    'meta': {
                        '_type': 'gradio.FileData',
                    },
                },
                'MelBand Roformer | Vocals by Kimberley Jensen',
                256,
                False,
                5,
                0,
                '/tmp/audio-separator-models/',
                'output',
                'wav',
                0.9,
                0,
                1,
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
            ],
            'event_data': None,
            'fn_index': 5,
            'trigger_id': 28,
            'session_hash': session_hash,
        }
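        # Join the Space's processing queue, then poll its server-sent-event stream
        # until a 'process_completed' message carries the result URL.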
        response = requests.post(
            'https://politrees-audio-separator-uvr.hf.space/gradio_api/queue/join',
            params=params,
            headers=headers,
            json=json_payload,
        )
        max_retries = 5
        retry_delay = 5
        retry_count = 0
        while retry_count < max_retries:
            try:
                print(f"Connecting to stream... Attempt {retry_count + 1}")
                r = requests.get(
                    f'https://politrees-audio-separator-uvr.hf.space/gradio_api/queue/data?session_hash={session_hash}',
                    stream=True
                )
                if r.status_code != 200:
                    raise Exception(f"Failed to connect: HTTP {r.status_code}")
                print("Connected successfully.")
                for line in r.iter_lines():
                    if line:
                        json_resp = json.loads(line.decode('utf-8').replace('data: ', ''))
                        print(json_resp)
                        if 'process_completed' in json_resp['msg']:
                            print("Process completed.")
                            output_url = json_resp['output']['data'][1]['url']
                            print(f"Output URL: {output_url}")
                            return output_url
                print("Stream ended prematurely. Reconnecting...")
            except Exception as e:
                print(f"Error occurred: {e}. Retrying...")
            retry_count += 1
            time.sleep(retry_delay)
        print("Max retries reached. Exiting.")
        return None
    except Exception as ex:
        print(f"Unexpected error in get_vocals: {ex}")
        return None
# -------------------------------
# Logging and Model Setup
# -------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("whisperx_app")
device = "cpu"
compute_type = "int8"
torch.set_num_threads(os.cpu_count())
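# Preload all configured model sizes at startup so switching in the UI requires
# no reload. Note that this keeps every checkpoint in RAM, which is memory-heavy
# on a CPU-only Space.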
models = {
    "tiny": whisperx.load_model("tiny", device, compute_type=compute_type, vad_method='silero'),
    "base": whisperx.load_model("base", device, compute_type=compute_type, vad_method='silero'),
    "small": whisperx.load_model("small", device, compute_type=compute_type, vad_method='silero'),
    "large": whisperx.load_model("large", device, compute_type=compute_type, vad_method='silero'),
    "large-v2": whisperx.load_model("large-v2", device, compute_type=compute_type, vad_method='silero'),
    "large-v3": whisperx.load_model("large-v3", device, compute_type=compute_type, vad_method='silero'),
}
def split_audio_by_pause(audio, sr, pause_threshold, top_db=30, energy_threshold=0.03):
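    """Split audio into segments at silences longer than pause_threshold seconds.

    Non-silent intervals (librosa.effects.split with threshold top_db) separated
    by a gap shorter than pause_threshold are merged; merged segments whose mean
    RMS energy falls below energy_threshold are discarded. Returns a list of
    (start_sample, end_sample) tuples.
    """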
    intervals = librosa.effects.split(audio, top_db=top_db)
    if len(intervals) == 0:
        # Entirely silent input: nothing to segment.
        return []
    merged_intervals = []
    current_start, current_end = intervals[0]
    for start, end in intervals[1:]:
        gap_duration = (start - current_end) / sr
        if gap_duration < pause_threshold:
            current_end = end
        else:
            merged_intervals.append((current_start, current_end))
            current_start, current_end = start, end
    merged_intervals.append((current_start, current_end))
    # Filter out segments with low average RMS energy
    filtered_intervals = []
    for start, end in merged_intervals:
        segment = audio[start:end]
        rms = np.mean(librosa.feature.rms(y=segment))
        if rms >= energy_threshold:
            filtered_intervals.append((start, end))
    return filtered_intervals
def seconds_to_srt_time(seconds):
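    """Convert seconds to an SRT timestamp, e.g. 3661.5 -> '01:01:01,500'."""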
    msec_total = int(round(seconds * 1000))
    hours, msec_remainder = divmod(msec_total, 3600 * 1000)
    minutes, msec_remainder = divmod(msec_remainder, 60 * 1000)
    sec, msec = divmod(msec_remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{sec:02d},{msec:03d}"
# -------------------------------
# Main Transcription Function
# -------------------------------
def transcribe(audio_file, model_size="base", debug=False, pause_threshold=0.0, vocal_extraction=False, language="en"):
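    """Transcribe an audio file with WhisperX and return word-level SRT text.

    Optionally runs remote vocal extraction first, and optionally splits the
    audio on long pauses so each chunk is transcribed and aligned separately.
    Returns (srt_text, debug_log); the debug log is empty unless debug=True.
    """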
    start_time = time.time()
    final_result = ""
    debug_log = []
    srt_entries = []
    try:
        # If vocal extraction is enabled, process the file first
        if vocal_extraction:
            debug_log.append("Vocal extraction enabled; processing input file for vocals...")
            extracted_url = get_vocals(audio_file)
            if extracted_url is not None:
                debug_log.append("Vocal extraction succeeded; downloading extracted audio...")
                response = requests.get(extracted_url)
                if response.status_code == 200:
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
                        tmp.write(response.content)
                        audio_file = tmp.name
                    debug_log.append("Extracted audio downloaded and saved for transcription.")
                else:
                    debug_log.append("Failed to download extracted audio; proceeding with original file.")
            else:
                debug_log.append("Vocal extraction failed; proceeding with original audio.")
        # Load the audio file at 16 kHz
        audio, sr = librosa.load(audio_file, sr=16000)
        debug_log.append(f"Audio loaded: {len(audio)/sr:.2f} seconds long at {sr} Hz")
        # Select the model and set batch size
        model = models[model_size]
        batch_size = 8 if model_size == "tiny" else 4
        # Use the provided language if set; otherwise detect it automatically.
        if language:
            transcript = model.transcribe(audio, batch_size=batch_size, language=language)
        else:
            transcript = model.transcribe(audio, batch_size=batch_size)
            language = transcript.get("language", "unknown")
        # Load alignment model using the specified language
        model_a, metadata = whisperx.load_align_model(language_code=language, device=device)
        # If pause_threshold > 0, split the audio on pauses and process each segment individually
        if pause_threshold > 0:
            segments = split_audio_by_pause(audio, sr, pause_threshold)
            debug_log.append(f"Audio split into {len(segments)} segment(s) using a pause threshold of {pause_threshold}s")
            for seg_idx, (seg_start, seg_end) in enumerate(segments):
                audio_segment = audio[seg_start:seg_end]
                seg_duration = (seg_end - seg_start) / sr
                debug_log.append(f"Segment {seg_idx+1}: start={seg_start/sr:.2f}s, duration={seg_duration:.2f}s")
                seg_transcript = model.transcribe(audio_segment, batch_size=batch_size, language=language)
                seg_aligned = whisperx.align(
                    seg_transcript["segments"], model_a, metadata, audio_segment, device
                )
                for segment in seg_aligned["segments"]:
                    for word in segment["words"]:
                        adjusted_start = word['start'] + seg_start / sr
                        adjusted_end = word['end'] + seg_start / sr
                        srt_entries.append({
                            'start': adjusted_start,
                            'end': adjusted_end,
                            'word': word['word'].strip()
                        })
        else:
            # Align the whole-file transcript produced above (no re-transcription needed)
            aligned = whisperx.align(
                transcript["segments"], model_a, metadata, audio, device
            )
            for segment in aligned["segments"]:
                for word in segment["words"]:
                    srt_entries.append({
                        'start': word['start'],
                        'end': word['end'],
                        'word': word['word'].strip()
                    })
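        # Build the SRT output: one numbered cue per aligned word.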
        srt_content = []
        for idx, entry in enumerate(srt_entries, start=1):
            start_time_srt = seconds_to_srt_time(entry['start'])
            end_time_srt = seconds_to_srt_time(entry['end'])
            srt_content.append(
                f"{idx}\n"
                f"{start_time_srt} --> {end_time_srt}\n"
                f"{entry['word']}\n"
            )
        final_result = "\n".join(srt_content)
        debug_log.append(f"Language used: {language}")
        debug_log.append(f"Batch size: {batch_size}")
        debug_log.append(f"Processed in {time.time()-start_time:.2f}s")
    except Exception as e:
        logger.error("Error during transcription:", exc_info=True)
        final_result = "Error occurred during transcription"
        debug_log.append(f"ERROR: {str(e)}")
    if debug:
        return final_result, "\n".join(debug_log)
    else:
        return final_result, ""
# -------------------------------
# Gradio Interface
# -------------------------------
with gr.Blocks(title="WhisperX CPU Transcription") as demo:
    gr.Markdown("# WhisperX CPU Transcription with Vocal Extraction Option")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Upload Audio File",
                type="filepath",
                sources=["upload", "microphone"],
                interactive=True,
            )
            model_selector = gr.Dropdown(
                choices=list(models.keys()),
                value="base",
                label="Model Size",
                interactive=True,
            )
            pause_threshold_slider = gr.Slider(
                minimum=0, maximum=5, step=0.1, value=0,
                label="Pause Threshold (seconds)",
                interactive=True,
                info="Set a pause duration threshold. Audio pauses longer than this will be used to split the audio into segments."
            )
            vocal_extraction_checkbox = gr.Checkbox(
                label="Extract Vocals (improves accuracy on noisy audio)",
                value=False
            )
            language_input = gr.Textbox(
                label="Language Code (e.g., en, es, fr)",
                placeholder="Enter language code",
                value="en"
            )
            debug_checkbox = gr.Checkbox(label="Enable Debug Mode", value=False)
            transcribe_btn = gr.Button("Transcribe", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(
                label="Transcription Output",
                lines=20,
                placeholder="Transcription will appear here..."
            )
            debug_output = gr.Textbox(
                label="Debug Information",
                lines=10,
                placeholder="Debug logs will appear here...",
                visible=False,
            )
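    # Reveal the debug panel only while debug mode is ticked.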
    def toggle_debug(debug_enabled):
        return gr.update(visible=debug_enabled)
    debug_checkbox.change(
        toggle_debug,
        inputs=[debug_checkbox],
        outputs=[debug_output]
    )
    transcribe_btn.click(
        transcribe,
        inputs=[audio_input, model_selector, debug_checkbox, pause_threshold_slider, vocal_extraction_checkbox, language_input],
        outputs=[output_text, debug_output]
    )
# -------------------------------
# Launch the App
# -------------------------------
if __name__ == "__main__":
    demo.queue(max_size=4).launch()