import nemo.collections.asr.models as nemo_asr_models
import torch
import gradio as gr
import spaces
import gc
import shutil
from pathlib import Path
from pydub import AudioSegment
import os
# --- Global Settings ---
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
# --- Model Loading (cached) ---
@spaces.GPU(enable_queue=True)
@torch.no_grad()
def load_model():
print(f"Loading ASR model: {MODEL_NAME} on {device}...")
try:
model = nemo_asr_models.ASRModel.from_pretrained(model_name=MODEL_NAME)
model.to(device)
model.eval()
print("ASR Model loaded successfully.")
return model
except Exception as e:
print(f"Error loading ASR model: {e}")
raise gr.Error(f"Failed to load ASR model: {MODEL_NAME}. Check logs for details.")
ASR_MODEL = load_model()
# --- Helper Functions ---
def create_session_dir(request: gr.Request):
session_hash = request.session_hash
session_dir = Path(f'/tmp/gradio/{session_hash}') # Gradio's default tmp path might be better
session_dir.mkdir(parents=True, exist_ok=True)
print(f"Session directory created: {session_dir}")
return session_dir
def cleanup_session_dir(session_dir: Path):
if session_dir.exists():
try:
shutil.rmtree(session_dir)
print(f"Session directory cleaned: {session_dir}")
except Exception as e:
print(f"Error cleaning session directory {session_dir}: {e}")
def get_audio_duration(audio_path):
try:
audio = AudioSegment.from_file(audio_path)
return audio.duration_seconds
except Exception as e:
print(f"Error getting audio duration for {audio_path}: {e}")
return 0
def format_timestamp(seconds, decimal_sep="."):
    """Converts seconds to HH:MM:SS.mmm (pass decimal_sep="," for SRT-style HH:MM:SS,mmm)."""
    ms = int((seconds - int(seconds)) * 1000)
    s = int(seconds) % 60
    m = int(seconds // 60) % 60
    h = int(seconds // 3600)
    return f"{h:02d}:{m:02d}:{s:02d}{decimal_sep}{ms:03d}"
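# Worked example: format_timestamp(3725.5) -> "01:02:05.500"
# (1 hour, 2 minutes, 5 seconds, 500 milliseconds).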
# --- Core Transcription Logic ---
@spaces.GPU(enable_queue=True)
@torch.no_grad()
def transcribe_audio_long(audio_path, session_dir_path_str, chunk_duration_min=15, overlap_sec=5, progress=gr.Progress(track_tqdm=True)):
if not audio_path:
raise gr.Error("No audio file provided!")
if not ASR_MODEL:
raise gr.Error("ASR Model not loaded. Please check server logs.")
session_dir = Path(session_dir_path_str)
chunk_duration_ms = chunk_duration_min * 60 * 1000
overlap_ms = overlap_sec * 1000
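    # Worked example with the defaults (chunk_duration_min=15, overlap_sec=5):
    #   chunk_duration_ms = 15 * 60 * 1000 = 900,000 ms and overlap_ms = 5,000 ms,
    #   so consecutive chunks start 900,000 - 5,000 = 895,000 ms (14 min 55 s) apart.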
try:
gr.Info(f"Loading audio file: {Path(audio_path).name}")
original_audio = AudioSegment.from_file(audio_path)
original_duration_sec = original_audio.duration_seconds
print(f"Original audio duration: {original_duration_sec} seconds")
# Resample and convert to mono if necessary (as in original Space)
target_sr = 16000
processed_audio = original_audio
if original_audio.frame_rate != target_sr:
gr.Info(f"Resampling audio to {target_sr}Hz...")
processed_audio = processed_audio.set_frame_rate(target_sr)
if original_audio.channels != 1:
gr.Info("Converting audio to mono...")
processed_audio = processed_audio.set_channels(1)
except Exception as e:
print(f"Error loading/preprocessing audio: {e}")
raise gr.Error(f"Failed to load or preprocess audio: {e}")
# Chunking logic
chunks = []
current_pos_ms = 0
chunk_id = 0
while current_pos_ms < len(processed_audio):
end_pos_ms = current_pos_ms + chunk_duration_ms
chunk = processed_audio[current_pos_ms:end_pos_ms]
chunk_file_path = session_dir / f"chunk_{chunk_id}.wav"
chunk.export(chunk_file_path, format="wav")
chunks.append({"id": chunk_id, "path": str(chunk_file_path), "start_ms": current_pos_ms, "end_ms": min(end_pos_ms, len(processed_audio))})
print(f"Created chunk {chunk_id}: {chunk_file_path} ({(current_pos_ms/1000):.2f}s - {min(end_pos_ms, len(processed_audio))/1000:.2f}s)")
current_pos_ms += chunk_duration_ms - overlap_ms
        # No special-casing is needed for the tail: pydub clips slices that run past the
        # end of the audio, and the loop exits below once end_pos_ms reaches the end.
chunk_id += 1
if end_pos_ms >= len(processed_audio):
break
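    # Illustrative layout for a hypothetical 40-minute recording with the defaults above:
    #   chunk 0: 00:00-15:00, chunk 1: 14:55-29:55, chunk 2: 29:50-40:00
    # (the last slice is simply clipped to the end of the audio).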
if not chunks:
raise gr.Error("Audio too short to be chunked or chunking failed.")
all_word_timestamps = []
full_transcript_text = ""
global_offset_ms = 0
# Apply long audio settings to the model (if not already applied or if model was reloaded)
# These settings are generally for the Nemo model's internal handling of longer sequences within a chunk.
try:
ASR_MODEL.change_attention_model("rel_pos_local_attn", [256,256])
ASR_MODEL.change_subsampling_conv_chunking_factor(1)
print("Applied long audio settings to ASR model.")
except Exception as e:
print(f"Warning: Could not apply all long audio settings to model: {e}")
progress(0, desc="Starting transcription...")
for i, chunk_info in enumerate(chunks):
gr.Info(f"Transcribing chunk {i+1}/{len(chunks)}...")
print(f"Transcribing chunk {chunk_info['id']} at {chunk_info['path']}")
try:
hypotheses = ASR_MODEL.transcribe([chunk_info["path"]], timestamps=True, batch_size=1) # Batch size 1 for simplicity with timestamps
if hypotheses and hypotheses[0] and hasattr(hypotheses[0], 'timestamp') and hypotheses[0].timestamp.get('word'):
chunk_word_timestamps = hypotheses[0].timestamp['word']
chunk_text = hypotheses[0].text
print(f"Chunk {chunk_info['id']} text: {chunk_text}")
# Adjust timestamps to global time and handle overlap
# The first chunk is taken as is. For subsequent chunks, find the overlap point.
current_chunk_global_start_ms = chunk_info['start_ms']
if not all_word_timestamps: # First chunk
for ts in chunk_word_timestamps:
ts['start'] += current_chunk_global_start_ms / 1000.0
ts['end'] += current_chunk_global_start_ms / 1000.0
all_word_timestamps.append(ts)
full_transcript_text = chunk_text
else:
                    # Determine where the new chunk's non-overlapping part begins.
                    # Words that start before the previous chunk's end were already transcribed
                    # by the previous chunk (that is the point of the overlap), so only words
                    # starting at or after that point contribute new information here.
                    previous_chunk_effective_end_ms = chunks[i-1]['end_ms']
                    # The timestamps from NeMo are relative to the start of the current CHUNK,
                    # so a word counts as new once
                    # ts['start'] * 1000 + current_chunk_global_start_ms >= previous_chunk_effective_end_ms.
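                    # Worked example (hypothetical numbers): if chunk 0 covered 0-900,000 ms and the
                    # current chunk starts at 895,000 ms, words whose global start falls inside the
                    # 5 s overlap [895.0 s, 900.0 s) are skipped and only words from 900.0 s onward
                    # are appended below.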
first_new_word_index = 0
for k, ts in enumerate(chunk_word_timestamps):
word_global_start_ms = (ts['start'] * 1000) + current_chunk_global_start_ms
if word_global_start_ms >= previous_chunk_effective_end_ms:
first_new_word_index = k
break
# If all words are within overlap, take the last one to avoid losing too much
if k == len(chunk_word_timestamps) -1:
first_new_word_index = k
# Add words from the non-overlapping part
appended_text_segment = ""
for k in range(first_new_word_index, len(chunk_word_timestamps)):
ts = chunk_word_timestamps[k]
ts['start'] += current_chunk_global_start_ms / 1000.0
ts['end'] += current_chunk_global_start_ms / 1000.0
# Avoid duplicate words if timestamps are very close due to overlap logic
if not all_word_timestamps or \
(abs(all_word_timestamps[-1]['end'] - ts['start']) > 0.01 and all_word_timestamps[-1]['word'] != ts['word']):
all_word_timestamps.append(ts)
appended_text_segment += ((" " if appended_text_segment else "") + ts['word'])
if appended_text_segment:
full_transcript_text += (" " + appended_text_segment)
else:
print(f"Warning: No word timestamps found for chunk {chunk_info['id']}. Text: {hypotheses[0].text if hypotheses and hypotheses[0] else 'N/A'}")
except Exception as e:
print(f"Error transcribing chunk {chunk_info['id']}: {e}")
# Optionally, allow continuing with next chunk or raise error
# raise gr.Error(f"Error during transcription of chunk {chunk_info['id']}: {e}")
finally:
# Clean up individual chunk file
if Path(chunk_info["path"]).exists():
try:
os.remove(chunk_info["path"])
except Exception as e_del:
print(f"Could not delete chunk file {chunk_info['path']}: {e_del}")
progress((i + 1) / len(chunks), desc=f"Transcribed chunk {i+1}/{len(chunks)}")
gc.collect()
torch.cuda.empty_cache()
# Prepare output formats
# SRT format
srt_content = ""
for i, ts_word_group in enumerate(all_word_timestamps): # Assuming all_word_timestamps is now a list of words
# For SRT, it's better to group words into segments. Here, we'll do one word per segment for simplicity.
# A more advanced approach would group words based on pauses or sentence structure from segment_timestamps if available and reliable.
        start_time_srt = format_timestamp(ts_word_group['start'], ",")  # SRT uses a comma before milliseconds
        end_time_srt = format_timestamp(ts_word_group['end'], ",")
srt_content += f"{i+1}\n{start_time_srt} --> {end_time_srt}\n{ts_word_group['word']}\n\n"
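    # Illustrative SRT cue produced by the loop above (one word per cue), hypothetical values:
    #   1
    #   00:00:01,250 --> 00:00:01,600
    #   hello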
    # VTT format (WEBVTT header; keeps '.' as the millisecond separator)
vtt_content = "WEBVTT\n\n"
for i, ts_word_group in enumerate(all_word_timestamps):
start_time_vtt = format_timestamp(ts_word_group['start'])
end_time_vtt = format_timestamp(ts_word_group['end'])
vtt_content += f"{start_time_vtt} --> {end_time_vtt}\n{ts_word_group['word']}\n\n"
# Raw text output
raw_text_path = session_dir / "full_transcript.txt"
with open(raw_text_path, "w", encoding="utf-8") as f:
f.write(full_transcript_text)
srt_file_path = session_dir / "full_transcript.srt"
with open(srt_file_path, "w", encoding="utf-8") as f:
f.write(srt_content)
vtt_file_path = session_dir / "full_transcript.vtt"
with open(vtt_file_path, "w", encoding="utf-8") as f:
f.write(vtt_content)
# Gradio DataFrame for word timestamps
word_vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['word']] for ts in all_word_timestamps]
gr.Info("Transcription complete!")
return (
full_transcript_text,
gr.DataFrame(value=word_vis_data, headers=["Start (s)", "End (s)", "Word"]),
gr.File(value=str(raw_text_path), visible=True, label="Download TXT"),
gr.File(value=str(srt_file_path), visible=True, label="Download SRT"),
gr.File(value=str(vtt_file_path), visible=True, label="Download VTT")
)
# --- Gradio Interface ---
def main_interface(audio_file, request: gr.Request):
session_dir = create_session_dir(request)
try:
if audio_file is None:
raise gr.Error("Please upload an audio file.")
        # gr.Audio(type="filepath") passes the uploaded file's path as a plain string;
        # fall back to .name in case a file-like object is received instead.
        audio_path = audio_file if isinstance(audio_file, str) else audio_file.name
# Call the long audio transcription function
transcript, word_df, txt_file, srt_file, vtt_file = transcribe_audio_long(audio_path, str(session_dir))
return transcript, word_df, txt_file, srt_file, vtt_file
except Exception as e:
print(f"Error in main_interface: {e}")
# Ensure some valid return structure for Gradio even on error
return str(e), gr.DataFrame(value=None), gr.File(visible=False), gr.File(visible=False), gr.File(visible=False)
finally:
# It's tricky to clean up session_dir immediately if files are being served by Gradio for download.
# Gradio usually handles its own temp files. If we create a custom session_dir,
# we might need a separate mechanism or rely on Space's ephemeral filesystem.
# For now, let's comment out direct cleanup here, assuming Gradio's tmp or Space cleanup handles it.
# cleanup_session_dir(session_dir)
pass
# Theme and Blocks
theme = gr.themes.Base(
primary_hue=gr.themes.colors.blue,
secondary_hue=gr.themes.colors.sky,
).set(
button_primary_background_fill="*primary_500",
button_primary_background_fill_hover="*primary_400",
button_primary_text_color="white",
)
css = """
.gradio-container { max-width: 980px !important; margin: auto !important; }
footer { display: none !important; }
"""
with gr.Blocks(theme=theme, css=css, title="Parakeet TDT Long Audio Transcription") as demo:
gr.Markdown("""
# 🦜 Parakeet TDT - Long Audio Transcription
Transcribe long audio files (up to several hours) using NVIDIA Parakeet-TDT-0.6B-v2.
The audio is automatically chunked with overlap for robust transcription.
Default chunk size: 15 minutes, Overlap: 5 seconds.
""")
with gr.Row():
with gr.Column(scale=1):
audio_input = gr.Audio(type="filepath", label="Upload Long Audio File (WAV, MP3, FLAC, etc.)")
# Advanced options (could be hidden by default)
# chunk_min_slider = gr.Slider(minimum=5, maximum=60, value=15, step=1, label="Chunk Duration (minutes)")
# overlap_sec_slider = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Overlap (seconds)")
transcribe_button = gr.Button("Transcribe Audio", variant="primary")
with gr.Column(scale=2):
output_text = gr.Textbox(label="Full Transcription", lines=15, interactive=False)
output_word_df = gr.DataFrame(label="Word Timestamps", headers=["Start (s)", "End (s)", "Word"], interactive=False)
with gr.Row():
download_txt_btn = gr.File(label="Download TXT", interactive=False, visible=True)
download_srt_btn = gr.File(label="Download SRT", interactive=False, visible=True)
download_vtt_btn = gr.File(label="Download VTT", interactive=False, visible=True)
transcribe_button.click(
fn=main_interface,
inputs=[audio_input],
outputs=[output_text, output_word_df, download_txt_btn, download_srt_btn, download_vtt_btn],
api_name="transcribe_long_audio"
)
gr.Markdown("""
---
    Powered by [NVIDIA NeMo Parakeet-TDT](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/parakeet_tdt_0.6b_v2) and Gradio.
Note: Transcription of very long files can take a significant amount of time and resources.
""")
if __name__ == "__main__":
demo.queue().launch(debug=True, share=False) # Set share=True to create a public link if running locally
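# Illustrative client usage (a sketch, not part of the app): with the Space running, the
# "transcribe_long_audio" endpoint can be called programmatically via gradio_client.
# The URL and the file name "long_meeting.wav" below are hypothetical placeholders.
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   outputs = client.predict(handle_file("long_meeting.wav"), api_name="/transcribe_long_audio")
#   print(outputs[0])  # full transcript text; outputs[1:] are the timestamp table and download files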