# Hugging Face Space metadata: "Running on Zero" (ZeroGPU hardware)
from nemo.collections.asr.models import ASRModel
import torch
import gradio as gr
import spaces
import gc
from pathlib import Path
from pydub import AudioSegment
import numpy as np
import os
import tempfile
import gradio.themes as gr_themes

device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
model = ASRModel.from_pretrained(model_name=MODEL_NAME)
model.eval()
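# The ASR model is loaded once at startup and shared across requests; it is moved onto
# the GPU only while a transcription is running and back to the CPU afterwards
# (see get_transcripts_and_raw_times below), so GPU memory is held only per request.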
def get_audio_segment(audio_path, start_second, end_second):
    if not audio_path or not Path(audio_path).exists():
        print(f"Warning: Audio path '{audio_path}' not found or invalid for clipping.")
        return None
    try:
        start_ms = int(start_second * 1000)
        end_ms = int(end_second * 1000)
        start_ms = max(0, start_ms)
        if end_ms <= start_ms:
            print(f"Warning: End time ({end_second}s) is not after start time ({start_second}s). Adjusting end time.")
            end_ms = start_ms + 100
        # Use pydub for all supported types (.mp3, .wav, .mp4, etc.); pydub/ffmpeg handles most formats.
        audio = AudioSegment.from_file(audio_path)
        clipped_audio = audio[start_ms:end_ms]
        samples = np.array(clipped_audio.get_array_of_samples())
        if clipped_audio.channels == 2:
            samples = samples.reshape((-1, 2)).mean(axis=1).astype(samples.dtype)
        frame_rate = clipped_audio.frame_rate
        if frame_rate <= 0:
            print(f"Warning: Invalid frame rate ({frame_rate}) detected for clipped audio.")
            frame_rate = audio.frame_rate
        if samples.size == 0:
            print(f"Warning: Clipped audio resulted in empty samples array ({start_second}s to {end_second}s).")
            return None
        return (frame_rate, samples)
    except FileNotFoundError:
        print(f"Error: Audio file not found at path: {audio_path}")
        return None
    except Exception as e:
        print(f"Error clipping audio {audio_path} from {start_second}s to {end_second}s: {e}")
        return None
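# SRT timestamps use the form HH:MM:SS,mmm; e.g. 75.25 seconds -> "00:01:15,250".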
def seconds_to_srt_ts(seconds: float):
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    ms = int((seconds - int(seconds)) * 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"
@spaces.GPU  # Requests a ZeroGPU allocation for this call; assumed intended, since `spaces` is imported but otherwise unused
def get_transcripts_and_raw_times(file_path):
    if not file_path:
        gr.Error("No file path provided for transcription.", duration=None)
        return [], [], None, gr.DownloadButton(visible=False)
    vis_data = [["N/A", "N/A", "Processing failed"]]
    raw_times_data = [[0.0, 0.0]]
    temp_files = []  # Track all temporary files created
    srt_file_path = None
    original_path_name = Path(file_path).name
    try:
        try:
            gr.Info(f"Loading file: {original_path_name}", duration=2)
            # pydub/ffmpeg supports .mp3, .wav, .mp4, .m4a, .aac, etc.
            audio = AudioSegment.from_file(file_path)
        except Exception as load_e:
            gr.Error(f"Failed to load file {original_path_name}: {load_e}", duration=None)
            return [["Error", "Error", "Load failed"]], [[0.0, 0.0]], file_path, gr.DownloadButton(visible=False)
        # Prepare audio for transcription: 16 kHz, mono
        try:
            target_sr = 16000
            if audio.frame_rate != target_sr:
                audio = audio.set_frame_rate(target_sr)
            if audio.channels == 2:
                audio = audio.set_channels(1)
            elif audio.channels > 2:
                gr.Error(f"Audio has {audio.channels} channels. Only mono (1) or stereo (2) supported.", duration=None)
                return [["Error", "Error", f"{audio.channels}-channel audio not supported"]], [[0.0, 0.0]], file_path, gr.DownloadButton(visible=False)
        except Exception as process_e:
            gr.Error(f"Failed to process audio: {process_e}", duration=None)
            return [["Error", "Error", "Audio processing failed"]], [[0.0, 0.0]], file_path, gr.DownloadButton(visible=False)
        # Check whether the audio is longer than the chunk size
        audio_length_sec = len(audio) / 1000.0  # pydub lengths are in milliseconds
        # Chunking configuration: 10-minute chunks fit on a 24 GB RTX 3090
        chunk_size_sec = 10 * 60
        overlap_sec = 5  # 5 seconds of overlap between chunks
        # Convert to milliseconds for pydub
        chunk_size_ms = chunk_size_sec * 1000
        overlap_ms = overlap_sec * 1000
        # Decide whether chunking is needed
        need_chunking = audio_length_sec > chunk_size_sec
        # Collect segments from all chunks here
        all_segments = []
        if need_chunking:
            # Calculate number of chunks
            total_chunks = max(1, int(np.ceil(audio_length_sec / chunk_size_sec)))
            print(f"Audio length: {audio_length_sec:.2f} seconds ({audio_length_sec/60:.2f} minutes)")
            print(f"Chunk size: {chunk_size_sec} seconds ({chunk_size_sec/60:.2f} minutes)")
            print(f"Total chunks needed: {total_chunks}")
            gr.Info(f"Audio is {audio_length_sec/60:.1f} minutes long. Processing in {total_chunks} chunks...", duration=3)
            # Process each chunk
            for i in range(total_chunks):
                # Chunk boundaries in milliseconds
                chunk_start_ms = i * chunk_size_ms
                chunk_end_ms = min(len(audio), (i + 1) * chunk_size_ms)
                # Add overlap, except at the very start and very end of the file
                if i > 0:
                    chunk_start_ms -= overlap_ms  # Extend the start earlier
                if i < total_chunks - 1 and chunk_end_ms + overlap_ms <= len(audio):
                    chunk_end_ms += overlap_ms  # Extend the end later
                # Effective region: the part of the chunk excluding the overlaps
                effective_start_ms = chunk_start_ms
                effective_end_ms = chunk_end_ms
                if i > 0:
                    effective_start_ms += overlap_ms
                if i < total_chunks - 1:
                    effective_end_ms -= overlap_ms
                # Convert to seconds for logging
                chunk_start_sec = chunk_start_ms / 1000
                chunk_end_sec = chunk_end_ms / 1000
                effective_start_sec = effective_start_ms / 1000
                effective_end_sec = effective_end_ms / 1000
                print(f"Chunk {i+1} boundaries: {chunk_start_sec:.2f}s - {chunk_end_sec:.2f}s")
                print(f"Chunk {i+1} effective: {effective_start_sec:.2f}s - {effective_end_sec:.2f}s")
                # Extract the chunk and save it to a temporary WAV file
                chunk = audio[chunk_start_ms:chunk_end_ms]
                chunk_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
                chunk.export(chunk_file.name, format="wav")
                temp_files.append(chunk_file.name)
                chunk_file.close()
                try:
                    # Move the model to the GPU at the last possible moment
                    model.to(device)
                    # Process the chunk
                    chunk_duration = (chunk_end_ms - chunk_start_ms) / 1000.0
                    gr.Info(f"Transcribing chunk {i+1}/{total_chunks} ({chunk_start_sec:.1f}s - {chunk_end_sec:.1f}s, {chunk_duration:.1f}s)...", duration=2)
                    output = model.transcribe([chunk_file.name], timestamps=True)
                    # Move the model back to the CPU immediately after processing
                    if device == 'cuda':
                        model.cpu()
                    if (output and isinstance(output, list) and output[0] and
                            hasattr(output[0], 'timestamp') and output[0].timestamp and
                            'segment' in output[0].timestamp):
                        chunk_segments = output[0].timestamp['segment']
                        segments_before = len(all_segments)
                        print(f"Chunk {i+1}: Got {len(chunk_segments)} segments")
                        # Add all segments from this chunk, adjusting timestamps to the global timeline
                        for segment in chunk_segments:
                            segment_start = segment['start'] + chunk_start_sec
                            segment_end = segment['end'] + chunk_start_sec
                            # Only keep segments that are mostly within the effective region,
                            # using the segment midpoint to decide inclusion
                            segment_midpoint = (segment_start + segment_end) / 2
                            if effective_start_sec <= segment_midpoint <= effective_end_sec:
                                all_segments.append({
                                    'start': segment_start,
                                    'end': segment_end,
                                    'segment': segment['segment']
                                })
                        print(f"Chunk {i+1}: Added {len(all_segments) - segments_before} segments (total now: {len(all_segments)})")
                    # Free memory between chunks
                    gc.collect()
                    if device == 'cuda':
                        torch.cuda.empty_cache()
                except torch.cuda.OutOfMemoryError as oom_e:
                    print(f"CUDA Out of Memory error on chunk {i+1}: {oom_e}")
                    gr.Warning(f"CUDA memory error on chunk {i+1}. Trying to continue with remaining chunks.", duration=3)
                    if device == 'cuda':
                        model.cpu()  # Make sure the model is moved back to the CPU
                        torch.cuda.empty_cache()
                    gc.collect()
                    # Continue with the next chunk
                except Exception as chunk_e:
                    gr.Warning(f"Error processing chunk {i+1}: {chunk_e}", duration=3)
                    print(f"Error processing chunk {i+1}: {chunk_e}")
                    if device == 'cuda':
                        model.cpu()  # Make sure the model is moved back to the CPU
                    # Continue with other chunks even if one fails
        else:
            # For shorter audio, process the entire file at once
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
            audio.export(temp_file.name, format="wav")
            temp_files.append(temp_file.name)
            temp_file.close()
            try:
                model.to(device)
                gr.Info(f"Transcribing {original_path_name} on {device}...", duration=2)
                output = model.transcribe([temp_file.name], timestamps=True)
                # Move the model back to the CPU immediately
                if device == 'cuda':
                    model.cpu()
                if (not output or not isinstance(output, list) or not output[0]
                        or not hasattr(output[0], 'timestamp') or not output[0].timestamp
                        or 'segment' not in output[0].timestamp):
                    gr.Error("Transcription failed or produced unexpected output format.", duration=None)
                    return [["Error", "Error", "Transcription Format Issue"]], [[0.0, 0.0]], file_path, gr.DownloadButton(visible=False)
                chunk_segments = output[0].timestamp['segment']
                for segment in chunk_segments:
                    all_segments.append({
                        'start': segment['start'],
                        'end': segment['end'],
                        'segment': segment['segment']
                    })
                print(f"Single chunk processing: Got {len(all_segments)} segments")
            except torch.cuda.OutOfMemoryError as e:
                error_msg = 'CUDA out of memory. The file may be too large for available GPU memory.'
                print(f"CUDA OutOfMemoryError: {e}")
                gr.Error(error_msg, duration=None)
                return [["OOM", "OOM", error_msg]], [[0.0, 0.0]], file_path, gr.DownloadButton(visible=False)
            except Exception as e:
                error_msg = f"Transcription failed: {e}"
                print(f"Error during transcription processing: {e}")
                gr.Error(error_msg, duration=None)
                return [["Error", "Error", error_msg]], [[0.0, 0.0]], file_path, gr.DownloadButton(visible=False)
        # If there are no segments at all (every chunk failed), return an error
        if len(all_segments) == 0:
            gr.Error("Failed to transcribe any portion of the audio.", duration=None)
            return [["Error", "Error", "No transcription segments generated"]], [[0.0, 0.0]], file_path, gr.DownloadButton(visible=False)
        # Debug: print a few segments to sanity-check the timestamps
        print(f"All segments: {len(all_segments)}")
        all_segments.sort(key=lambda x: x['start'])  # Ensure chronological order
        print(f"First segment: {all_segments[0]['start']:.2f}s - {all_segments[0]['end']:.2f}s: {all_segments[0]['segment']}")
        if len(all_segments) > 1:
            print(f"Second segment: {all_segments[1]['start']:.2f}s - {all_segments[1]['end']:.2f}s: {all_segments[1]['segment']}")
        if len(all_segments) > 2:
            middle_idx = len(all_segments) // 2
            print(f"Middle segment: {all_segments[middle_idx]['start']:.2f}s - {all_segments[middle_idx]['end']:.2f}s: {all_segments[middle_idx]['segment']}")
        print(f"Last segment: {all_segments[-1]['start']:.2f}s - {all_segments[-1]['end']:.2f}s: {all_segments[-1]['segment']}")
        # Build the table data shown in the UI
        vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in all_segments]
        raw_times_data = [[ts['start'], ts['end']] for ts in all_segments]
        # Generate the SRT entries with the corrected timestamps
        srt_lines = []
        for i, ts in enumerate(all_segments, 1):
            start = seconds_to_srt_ts(ts['start'])
            end = seconds_to_srt_ts(ts['end'])
            text = ts['segment'].replace('\n', ' ').strip()
            srt_lines.append(f"{i}\n{start} --> {end}\n{text}\n")
        # Save the SRT file
        button_update = gr.DownloadButton(visible=False)
        try:
            temp_srt_file = tempfile.NamedTemporaryFile(delete=False, suffix=".srt", mode='w', encoding='utf-8')
            temp_srt_file.write('\n'.join(srt_lines))
            srt_file_path = temp_srt_file.name
            temp_srt_file.close()
            print(f"SRT transcript saved to temporary file: {srt_file_path}")
            button_update = gr.DownloadButton(value=srt_file_path, visible=True, label="Download Subtitle File (.srt)")
        except Exception as srt_e:
            gr.Error(f"Failed to create transcript SRT file: {srt_e}", duration=None)
            print(f"Error writing SRT: {srt_e}")
        gr.Info(f"Transcription complete! Generated {len(all_segments)} segments.", duration=2)
        return vis_data, raw_times_data, file_path, button_update
    finally:
        # Clean up all temporary chunk/audio files (the SRT file is kept for download)
        for temp_path in temp_files:
            if temp_path and os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                    print(f"Temporary file {temp_path} removed.")
                except Exception as e:
                    print(f"Error removing temporary file {temp_path}: {e}")
        # Final model/GPU cleanup
        try:
            # `model` is a module-level global, so check globals() rather than locals()
            if 'model' in globals() and hasattr(model, 'cpu'):
                if device == 'cuda':
                    model.cpu()
            gc.collect()
            if device == 'cuda':
                torch.cuda.empty_cache()
        except Exception as cleanup_e:
            print(f"Error during model cleanup: {cleanup_e}")
            gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5)
def play_segment(evt: gr.SelectData, raw_ts_list, current_audio_path):
    if not isinstance(raw_ts_list, list):
        print(f"Warning: raw_ts_list is not a list ({type(raw_ts_list)}). Cannot play segment.")
        return gr.Audio(value=None, label="Selected Segment")
    if not current_audio_path:
        print("No audio path available to play segment from.")
        return gr.Audio(value=None, label="Selected Segment")
    selected_index = evt.index[0]
    if selected_index < 0 or selected_index >= len(raw_ts_list):
        print(f"Invalid index {selected_index} selected for list of length {len(raw_ts_list)}.")
        return gr.Audio(value=None, label="Selected Segment")
    if (not isinstance(raw_ts_list[selected_index], (list, tuple))
            or len(raw_ts_list[selected_index]) != 2):
        print(f"Warning: Data at index {selected_index} is not in the expected format [start, end].")
        return gr.Audio(value=None, label="Selected Segment")
    start_time_s, end_time_s = raw_ts_list[selected_index]
    print(f"Attempting to play segment: {current_audio_path} from {start_time_s:.2f}s to {end_time_s:.2f}s")
    segment_data = get_audio_segment(current_audio_path, start_time_s, end_time_s)
    if segment_data:
        print("Segment data retrieved successfully.")
        return gr.Audio(value=segment_data, autoplay=True, label=f"Segment: {start_time_s:.2f}s - {end_time_s:.2f}s", interactive=False)
    else:
        print("Failed to get audio segment data.")
        return gr.Audio(value=None, label="Selected Segment")
article = (
    "<p style='font-size: 1.1em;'>"
    "Upload an <b>audio file</b> (wav, mp3, m4a, etc.) <b>or a video file</b> (mp4, mov, etc.) and this tool will extract the audio stream and generate subtitles in .srt format.<br>"
    "Files longer than 10 minutes are automatically split into chunks for processing.</p>"
)
# NVIDIA-inspired theme
nvidia_theme = gr_themes.Default(
    primary_hue=gr_themes.Color(
        c50="#E6F1D9",
        c100="#CEE3B3",
        c200="#B5D58C",
        c300="#9CC766",
        c400="#84B940",
        c500="#76B900",
        c600="#68A600",
        c700="#5A9200",
        c800="#4C7E00",
        c900="#3E6A00",
        c950="#2F5600",
    ),
    neutral_hue="gray",
    font=[gr_themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
).set()
with gr.Blocks(theme=nvidia_theme) as demo:
    model_display_name = MODEL_NAME.split('/')[-1] if '/' in MODEL_NAME else MODEL_NAME
    gr.Markdown(f"<h1 style='text-align: center; margin: 0 auto;'>Subtitle Generation (en) with {model_display_name}</h1>")
    gr.HTML(article)
    current_audio_path_state = gr.State(None)
    raw_timestamps_list_state = gr.State([])
    # Use gr.File instead of gr.Audio so that video files can also be uploaded
    file_input = gr.File(
        label="Upload Audio or Video File (MP3, WAV, MP4, etc)",
        file_types=[".mp3", ".wav", ".mp4", ".m4a", ".aac", ".ogg", ".flac", ".mov", ".mkv"],
    )
    file_transcribe_btn = gr.Button("Transcribe Uploaded File", variant="primary")
    gr.Markdown("---")
    gr.Markdown("<p><strong style='color: #FF0000; font-size: 1.2em;'>Transcription Results (Click row to play segment)</strong></p>")
    download_btn = gr.DownloadButton(label="Download Subtitle File (.srt)", visible=False)
    vis_timestamps_df = gr.DataFrame(
        headers=["Start (s)", "End (s)", "Segment"],
        datatype=["number", "number", "str"],
        wrap=True,
        label="Transcription Segments"
    )
    selected_segment_player = gr.Audio(label="Selected Segment", interactive=False)
    file_transcribe_btn.click(
        fn=get_transcripts_and_raw_times,
        inputs=[file_input],
        outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn],
        api_name="transcribe_file"
    )
    vis_timestamps_df.select(
        fn=play_segment,
        inputs=[raw_timestamps_list_state, current_audio_path_state],
        outputs=[selected_segment_player],
    )
if __name__ == "__main__":
    print("Launching Gradio Demo...")
    demo.queue()
    demo.launch(server_name="0.0.0.0")
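# Usage sketch (not part of the app): calling the `transcribe_file` endpoint remotely with
# gradio_client. "user/space-name" and "meeting.mp3" are hypothetical placeholders; the exact
# shape of the returned values depends on the Gradio version and how State outputs are exposed.
#
#   from gradio_client import Client, handle_file
#   client = Client("user/space-name")          # hypothetical Space id
#   result = client.predict(
#       handle_file("meeting.mp3"),             # local audio/video file to transcribe
#       api_name="/transcribe_file",
#   )
#   print(result)                               # includes the segments table and the .srt download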