import nemo.collections.asr.models as nemo_asr_models
import torch
import gradio as gr
import spaces
import gc
import shutil
from pathlib import Path
from pydub import AudioSegment
import numpy as np
import os
import gradio.themes as gr_themes
import csv
import json
import time
import math

# --- Global Settings ---
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"

# --- Model Loading (cached) ---
def load_model():
    print(f"Loading ASR model: {MODEL_NAME} on {device}...")
    try:
        model = nemo_asr_models.ASRModel.from_pretrained(model_name=MODEL_NAME)
        model.to(device)
        model.eval()
        print("ASR Model loaded successfully.")
        return model
    except Exception as e:
        print(f"Error loading ASR model: {e}")
        raise gr.Error(f"Failed to load ASR model: {MODEL_NAME}. Check logs for details.")

ASR_MODEL = load_model()
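# Loading at module import means the checkpoint is downloaded and moved to the target device
# once per process (not once per request); every Gradio call below reuses this ASR_MODEL instance.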
# --- Helper Functions ---
def create_session_dir(request: gr.Request):
    session_hash = request.session_hash
    session_dir = Path(f'/tmp/gradio/{session_hash}')  # Gradio's default tmp path might be better
    session_dir.mkdir(parents=True, exist_ok=True)
    print(f"Session directory created: {session_dir}")
    return session_dir

def cleanup_session_dir(session_dir: Path):
    if session_dir.exists():
        try:
            shutil.rmtree(session_dir)
            print(f"Session directory cleaned: {session_dir}")
        except Exception as e:
            print(f"Error cleaning session directory {session_dir}: {e}")

def get_audio_duration(audio_path):
    try:
        audio = AudioSegment.from_file(audio_path)
        return audio.duration_seconds
    except Exception as e:
        print(f"Error getting audio duration for {audio_path}: {e}")
        return 0
def format_timestamp(seconds, decimal_marker="."):
    """Converts seconds to HH:MM:SS.mmm format (pass decimal_marker="," for SRT)."""
    ms = int((seconds - int(seconds)) * 1000)
    s = int(seconds) % 60
    m = int(seconds // 60) % 60
    h = int(seconds // 3600)
    return f"{h:02d}:{m:02d}:{s:02d}{decimal_marker}{ms:03d}"
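# Worked example of the helper above:
#   format_timestamp(3661.5)      -> "01:01:01.500"  (WebVTT style)
#   format_timestamp(3661.5, ",") -> "01:01:01,500"  (SRT style)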
# --- Core Transcription Logic ---
def transcribe_audio_long(audio_path, session_dir_path_str, chunk_duration_min=15, overlap_sec=5, progress=gr.Progress(track_tqdm=True)):
    if not audio_path:
        raise gr.Error("No audio file provided!")
    if not ASR_MODEL:
        raise gr.Error("ASR Model not loaded. Please check server logs.")

    session_dir = Path(session_dir_path_str)
    chunk_duration_ms = chunk_duration_min * 60 * 1000
    overlap_ms = overlap_sec * 1000

    try:
        gr.Info(f"Loading audio file: {Path(audio_path).name}")
        original_audio = AudioSegment.from_file(audio_path)
        original_duration_sec = original_audio.duration_seconds
        print(f"Original audio duration: {original_duration_sec} seconds")
        # Resample and convert to mono if necessary (as in the original Space).
        target_sr = 16000
        processed_audio = original_audio
        if original_audio.frame_rate != target_sr:
            gr.Info(f"Resampling audio to {target_sr}Hz...")
            processed_audio = processed_audio.set_frame_rate(target_sr)
        if original_audio.channels != 1:
            gr.Info("Converting audio to mono...")
            processed_audio = processed_audio.set_channels(1)
    except Exception as e:
        print(f"Error loading/preprocessing audio: {e}")
        raise gr.Error(f"Failed to load or preprocess audio: {e}")
    # Chunking logic: step through the audio in chunk_duration_ms windows, each overlapping
    # the previous one by overlap_ms.
    chunks = []
    current_pos_ms = 0
    chunk_id = 0
    while current_pos_ms < len(processed_audio):
        end_pos_ms = min(current_pos_ms + chunk_duration_ms, len(processed_audio))
        chunk = processed_audio[current_pos_ms:end_pos_ms]
        chunk_file_path = session_dir / f"chunk_{chunk_id}.wav"
        chunk.export(chunk_file_path, format="wav")
        chunks.append({"id": chunk_id, "path": str(chunk_file_path), "start_ms": current_pos_ms, "end_ms": end_pos_ms})
        print(f"Created chunk {chunk_id}: {chunk_file_path} ({current_pos_ms / 1000:.2f}s - {end_pos_ms / 1000:.2f}s)")
        chunk_id += 1
        if end_pos_ms >= len(processed_audio):
            break
        # Advance by the chunk length minus the overlap so consecutive chunks share overlap_sec seconds of audio.
        current_pos_ms += chunk_duration_ms - overlap_ms
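    # Worked example of the schedule above with the defaults (15-minute chunks, 5-second overlap):
    # a 40-minute file yields three chunks covering roughly 00:00-15:00, 14:55-29:55 and 29:50-40:00,
    # so every chunk boundary is covered twice and no words are lost between chunks.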
    if not chunks:
        raise gr.Error("Audio too short to be chunked, or chunking failed.")

    all_word_timestamps = []
    full_transcript_text = ""

    # Apply long-audio settings to the model (if not already applied or if the model was reloaded).
    # These settings govern the NeMo model's internal handling of long sequences within a chunk.
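    # Presumably following NeMo's long-audio guidance for Conformer/FastConformer encoders:
    # "rel_pos_local_attn" with a [256, 256] left/right context restricts self-attention to a local
    # window (so cost grows with the window rather than the full chunk length), and the conv
    # subsampling chunking factor bounds memory in the downsampling layers. Treat the exact values
    # as assumptions inherited from the original Space rather than tuned settings.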
    try:
        ASR_MODEL.change_attention_model("rel_pos_local_attn", [256, 256])
        ASR_MODEL.change_subsampling_conv_chunking_factor(1)
        print("Applied long audio settings to ASR model.")
    except Exception as e:
        print(f"Warning: Could not apply all long audio settings to model: {e}")

    progress(0, desc="Starting transcription...")
    for i, chunk_info in enumerate(chunks):
        gr.Info(f"Transcribing chunk {i+1}/{len(chunks)}...")
        print(f"Transcribing chunk {chunk_info['id']} at {chunk_info['path']}")
        try:
            hypotheses = ASR_MODEL.transcribe([chunk_info["path"]], timestamps=True, batch_size=1)  # batch size 1 keeps timestamp handling simple
            if hypotheses and hypotheses[0] and hasattr(hypotheses[0], 'timestamp') and hypotheses[0].timestamp.get('word'):
                chunk_word_timestamps = hypotheses[0].timestamp['word']
                chunk_text = hypotheses[0].text
                print(f"Chunk {chunk_info['id']} text: {chunk_text}")
                # Adjust timestamps to global time and handle the overlap:
                # the first chunk is taken as is; for subsequent chunks, find where the overlap ends.
                current_chunk_global_start_ms = chunk_info['start_ms']
                if not all_word_timestamps:  # First chunk
                    for ts in chunk_word_timestamps:
                        ts['start'] += current_chunk_global_start_ms / 1000.0
                        ts['end'] += current_chunk_global_start_ms / 1000.0
                        all_word_timestamps.append(ts)
                    full_transcript_text = chunk_text
                else:
                    # Determine where the new chunk's non-overlapping part begins. This is the previous
                    # chunk's end time in the original audio minus the overlap, i.e. the point from which
                    # this chunk starts contributing new information.
                    previous_chunk_effective_end_ms = chunks[i-1]['end_ms'] - overlap_ms
                    # Find the first word in the current chunk that starts AFTER the overlap with the
                    # previous chunk ends. NeMo's timestamps are relative to the start of the current CHUNK,
                    # so a word's global start (in ms) is ts['start'] * 1000 + current_chunk_global_start_ms.
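                    # e.g. if the previous chunk covered 0-900 s and the overlap is 5 s, words of this chunk
                    # whose global start falls before 895 s are assumed already transcribed and are skipped.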
                    first_new_word_index = 0
                    for k, ts in enumerate(chunk_word_timestamps):
                        word_global_start_ms = (ts['start'] * 1000) + current_chunk_global_start_ms
                        if word_global_start_ms >= previous_chunk_effective_end_ms:
                            first_new_word_index = k
                            break
                        # If all words fall within the overlap, keep the last one to avoid losing too much.
                        if k == len(chunk_word_timestamps) - 1:
                            first_new_word_index = k
                    # Add words from the non-overlapping part.
                    appended_text_segment = ""
                    for k in range(first_new_word_index, len(chunk_word_timestamps)):
                        ts = chunk_word_timestamps[k]
                        ts['start'] += current_chunk_global_start_ms / 1000.0
                        ts['end'] += current_chunk_global_start_ms / 1000.0
                        # Skip a word only if it duplicates the last appended word (same text and
                        # near-identical start time), which can happen right at the overlap boundary.
                        is_duplicate = (all_word_timestamps
                                        and all_word_timestamps[-1]['word'] == ts['word']
                                        and abs(all_word_timestamps[-1]['start'] - ts['start']) < 0.01)
                        if not is_duplicate:
                            all_word_timestamps.append(ts)
                            appended_text_segment += ((" " if appended_text_segment else "") + ts['word'])
                    if appended_text_segment:
                        full_transcript_text += (" " + appended_text_segment)
            else:
                print(f"Warning: No word timestamps found for chunk {chunk_info['id']}. Text: {hypotheses[0].text if hypotheses and hypotheses[0] else 'N/A'}")
        except Exception as e:
            print(f"Error transcribing chunk {chunk_info['id']}: {e}")
            # Optionally continue with the next chunk, or re-raise:
            # raise gr.Error(f"Error during transcription of chunk {chunk_info['id']}: {e}")
        finally:
            # Clean up the individual chunk file.
            if Path(chunk_info["path"]).exists():
                try:
                    os.remove(chunk_info["path"])
                except Exception as e_del:
                    print(f"Could not delete chunk file {chunk_info['path']}: {e_del}")
        progress((i + 1) / len(chunks), desc=f"Transcribed chunk {i+1}/{len(chunks)}")
        gc.collect()
        torch.cuda.empty_cache()
    # Prepare output formats.
    # SRT: one word per cue for simplicity. A more advanced approach would group words into segments
    # based on pauses, or on NeMo's segment timestamps if they are available and reliable.
    srt_content = ""
    for i, ts_word in enumerate(all_word_timestamps):
        start_time_srt = format_timestamp(ts_word['start'], ",")  # SRT uses a comma before the milliseconds
        end_time_srt = format_timestamp(ts_word['end'], ",")
        srt_content += f"{i+1}\n{start_time_srt} --> {end_time_srt}\n{ts_word['word']}\n\n"

    # VTT: similar to SRT, but with a WEBVTT header and a dot before the milliseconds.
    vtt_content = "WEBVTT\n\n"
    for ts_word in all_word_timestamps:
        start_time_vtt = format_timestamp(ts_word['start'])
        end_time_vtt = format_timestamp(ts_word['end'])
        vtt_content += f"{start_time_vtt} --> {end_time_vtt}\n{ts_word['word']}\n\n"
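    # Each SRT cue written above looks like:
    #   1
    #   00:00:01,230 --> 00:00:01,480
    #   hello
    # The VTT cues are the same except for the WEBVTT header, the missing cue number,
    # and a dot instead of a comma in the timestamps.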
    # Raw text output.
    raw_text_path = session_dir / "full_transcript.txt"
    with open(raw_text_path, "w", encoding="utf-8") as f:
        f.write(full_transcript_text)

    srt_file_path = session_dir / "full_transcript.srt"
    with open(srt_file_path, "w", encoding="utf-8") as f:
        f.write(srt_content)

    vtt_file_path = session_dir / "full_transcript.vtt"
    with open(vtt_file_path, "w", encoding="utf-8") as f:
        f.write(vtt_content)

    # Gradio DataFrame rows for the word timestamps.
    word_vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['word']] for ts in all_word_timestamps]

    gr.Info("Transcription complete!")
    return (
        full_transcript_text,
        gr.DataFrame(value=word_vis_data, headers=["Start (s)", "End (s)", "Word"]),
        gr.File(value=str(raw_text_path), visible=True, label="Download TXT"),
        gr.File(value=str(srt_file_path), visible=True, label="Download SRT"),
        gr.File(value=str(vtt_file_path), visible=True, label="Download VTT")
    )
# --- Gradio Interface ---
def main_interface(audio_file, request: gr.Request):
    session_dir = create_session_dir(request)
    try:
        if audio_file is None:
            raise gr.Error("Please upload an audio file.")
        # gr.Audio(type="filepath") passes the uploaded file's path as a plain string;
        # fall back to .name in case a file-like object is ever received instead.
        audio_path = audio_file if isinstance(audio_file, str) else audio_file.name
        # Call the long-audio transcription function.
        transcript, word_df, txt_file, srt_file, vtt_file = transcribe_audio_long(audio_path, str(session_dir))
        return transcript, word_df, txt_file, srt_file, vtt_file
    except Exception as e:
        print(f"Error in main_interface: {e}")
        # Return a valid output structure for Gradio even on error.
        return str(e), gr.DataFrame(value=None), gr.File(visible=False), gr.File(visible=False), gr.File(visible=False)
    finally:
        # Cleaning up session_dir here is tricky because Gradio may still be serving the files for download.
        # Gradio manages its own temp files; a custom session_dir would need a separate cleanup mechanism,
        # or we can rely on the Space's ephemeral filesystem. For now, skip direct cleanup.
        # cleanup_session_dir(session_dir)
        pass
# Theme and Blocks
theme = gr.themes.Base(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.sky,
).set(
    button_primary_background_fill="*primary_500",
    button_primary_background_fill_hover="*primary_400",
    button_primary_text_color="white",
)

css = """
.gradio-container { max-width: 980px !important; margin: auto !important; }
footer { display: none !important; }
"""
with gr.Blocks(theme=theme, css=css, title="Parakeet TDT Long Audio Transcription") as demo:
    gr.Markdown("""
    # 🦜 Parakeet TDT - Long Audio Transcription
    Transcribe long audio files (up to several hours) using NVIDIA Parakeet-TDT-0.6B-v2.
    The audio is automatically chunked with overlap for robust transcription.
    Default chunk size: 15 minutes; overlap: 5 seconds.
    """)
    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(type="filepath", label="Upload Long Audio File (WAV, MP3, FLAC, etc.)")
            # Advanced options (hidden by default):
            # chunk_min_slider = gr.Slider(minimum=5, maximum=60, value=15, step=1, label="Chunk Duration (minutes)")
            # overlap_sec_slider = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Overlap (seconds)")
            transcribe_button = gr.Button("Transcribe Audio", variant="primary")
        with gr.Column(scale=2):
            output_text = gr.Textbox(label="Full Transcription", lines=15, interactive=False)
            output_word_df = gr.DataFrame(label="Word Timestamps", headers=["Start (s)", "End (s)", "Word"], interactive=False)
            with gr.Row():
                download_txt_btn = gr.File(label="Download TXT", interactive=False, visible=True)
                download_srt_btn = gr.File(label="Download SRT", interactive=False, visible=True)
                download_vtt_btn = gr.File(label="Download VTT", interactive=False, visible=True)

    transcribe_button.click(
        fn=main_interface,
        inputs=[audio_input],
        outputs=[output_text, output_word_df, download_txt_btn, download_srt_btn, download_vtt_btn],
        api_name="transcribe_long_audio"
    )
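    # Setting api_name also exposes this handler on the Space's API. A minimal sketch of calling it
    # programmatically (assuming the standard gradio_client workflow; "<user>/<space>" is a placeholder):
    #   from gradio_client import Client, handle_file
    #   Client("<user>/<space>").predict(handle_file("audio.wav"), api_name="/transcribe_long_audio")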
    gr.Markdown("""
    ---
    Powered by [NVIDIA NeMo Parakeet-TDT](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/parakeet_tdt_0.6b_v2) and Gradio.
    Note: Transcription of very long files can take a significant amount of time and resources.
    """)
if __name__ == "__main__":
    demo.queue().launch(debug=True, share=False)  # Set share=True to create a public link when running locally