import datetime
import json
import os
import shutil
import subprocess
import tempfile
from pathlib import Path

import gradio as gr
import spaces
import torch
from loguru import logger
from transformers import pipeline

# Configure loguru: rotate the log file once it reaches 500 MB
logger.add("app.log", rotation="500 MB", level="DEBUG")

MODEL_NAME = "muhtasham/whisper-tg"
def format_time(seconds):
    """Convert seconds to SRT time format (HH:MM:SS,mmm)."""
    td = datetime.timedelta(seconds=float(seconds))
    # Use total_seconds() rather than td.seconds: td.seconds drops whole days,
    # so anything past 24 hours would silently wrap around
    total_seconds = int(td.total_seconds())
    hours = total_seconds // 3600
    minutes = (total_seconds % 3600) // 60
    seconds = total_seconds % 60
    milliseconds = td.microseconds // 1000
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
def generate_srt(chunks):
    """Generate SRT-format subtitles from a list of timestamped chunks."""
    srt_content = []
    for i, chunk in enumerate(chunks, 1):
        start_time = format_time(chunk["timestamp"][0])
        end_time = format_time(chunk["timestamp"][1])
        text = chunk["text"].strip()
        srt_content.append(f"{i}\n{start_time} --> {end_time}\n{text}\n\n")
    return "".join(srt_content)
def save_srt_to_file(srt_content):
    """Save SRT content to a temporary file and return the file path."""
    if not srt_content:
        return None

    # Create a temporary file with an .srt extension; delete=False keeps it
    # on disk so Gradio can serve it for download
    temp_file = tempfile.NamedTemporaryFile(suffix=".srt", delete=False)
    temp_file.write(srt_content.encode("utf-8"))
    temp_file.close()
    return temp_file.name
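# Note: with delete=False these .srt files accumulate in the temp directory.
# On a long-running Space a periodic sweep is one option (a sketch, assuming
# nothing else writes .srt files there):
#   for f in Path(tempfile.gettempdir()).glob("*.srt"):
#       f.unlink(missing_ok=True)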
# Check that ffmpeg is installed (the ASR pipeline uses it to decode audio files)
def check_ffmpeg():
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
        logger.info("ffmpeg check passed successfully")
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        logger.error(f"ffmpeg check failed: {str(e)}")
        raise gr.Error("ffmpeg is not installed. Please install ffmpeg to use this application.")

# Run the ffmpeg check at startup
check_ffmpeg()
device = 0 if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

def create_pipeline(chunk_length_s):
    """Create a new ASR pipeline with the specified chunk length."""
    return pipeline(
        task="automatic-speech-recognition",
        model=MODEL_NAME,
        chunk_length_s=chunk_length_s,
        device=device,
    )
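# chunk_length_s enables transformers' chunked long-form decoding: the audio
# is split into windows of roughly that length with small overlapping strides,
# the windows are transcribed batch_size at a time, and the pieces are merged
# back into one transcript. Shorter chunks lower peak memory per window;
# longer chunks give the model more context.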
# Initialize a default pipeline at startup so the model is downloaded and
# cached before the first request
pipe = create_pipeline(30)
logger.info(f"Pipeline initialized: {pipe}")
@spaces.GPU  # allocate ZeroGPU hardware for this call; a no-op outside Hugging Face Spaces
def transcribe(inputs, return_timestamps, generate_subs, batch_size, chunk_length_s):
    if inputs is None:
        logger.warning("No audio file submitted")
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    try:
        logger.info(f"Processing audio file: {inputs}")

        # Create a new pipeline with the requested chunk length
        current_pipe = create_pipeline(chunk_length_s)
        result = current_pipe(inputs, batch_size=batch_size, return_timestamps=return_timestamps)
        logger.debug(f"Pipeline result: {result}")

        # Format the response as JSON
        formatted_result = {"text": result["text"]}

        chunks = []
        if return_timestamps and "chunks" in result:
            logger.info(f"Processing {len(result['chunks'])} chunks")
            for i, chunk in enumerate(result["chunks"]):
                logger.debug(f"Processing chunk {i}: {chunk}")
                try:
                    start_time = chunk.get("timestamp", [None, None])[0]
                    end_time = chunk.get("timestamp", [None, None])[1]
                    text = chunk.get("text", "").strip()
                    if start_time is not None and end_time is not None:
                        chunks.append({"text": text, "timestamp": [start_time, end_time]})
                    else:
                        logger.warning(f"Invalid timestamp in chunk {i}: {chunk}")
                except Exception as chunk_error:
                    logger.error(f"Error processing chunk {i}: {str(chunk_error)}")
                    continue
            formatted_result["chunks"] = chunks

        logger.info(f"Successfully processed transcription with {len(chunks)} chunks")

        # Generate subtitles if requested
        srt_file = None
        if generate_subs and chunks:
            logger.info("Generating SRT subtitles")
            srt_content = generate_srt(chunks)
            srt_file = save_srt_to_file(srt_content)
            logger.info("SRT subtitles generated successfully")

        # Return three values to match the three output components
        # (JSON transcription, SRT file, hidden correction textbox)
        return formatted_result, srt_file, ""
    except Exception as e:
        logger.exception(f"Error during transcription: {str(e)}")
        raise gr.Error(f"Failed to transcribe audio: {str(e)}")
# Custom flagging callback that stores flagged examples as JSONL
class TranscriptionFlaggingCallback(gr.FlaggingCallback):
    def __init__(self, flagging_dir):
        self.flagging_dir = Path(flagging_dir)
        self.flagging_dir.mkdir(exist_ok=True)
        self.log_file = self.flagging_dir / "flagged_data.jsonl"

    def setup(self, components, flagging_dir):
        # Nothing to do here; directories are created in __init__
        pass

    def flag(self, flag_data, flag_option="", username=None):
        # Gradio passes the values of the input components followed by the
        # output components: (audio, timestamps, subtitles, batch size,
        # chunk length, JSON transcription, SRT file, correction text)
        audio_file = flag_data[0]
        if audio_file:
            audio_filename = os.path.basename(audio_file)
            # Copy the audio file into the flagged directory
            audio_dir = self.flagging_dir / "audio"
            audio_dir.mkdir(exist_ok=True)
            shutil.copy2(audio_file, audio_dir / audio_filename)
        else:
            audio_filename = None

        # Prepare the record to save
        data = {
            "timestamp": datetime.datetime.now().isoformat(),
            "audio_file": audio_filename,
            "transcription": flag_data[5],  # JSON output
            "feedback": flag_option,
            "correction": flag_data[7] if len(flag_data) > 7 else None,  # correction text if provided
            "username": username,
        }

        # Append to the JSONL file; ensure_ascii=False keeps non-Latin transcripts readable
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(data, ensure_ascii=False) + "\n")
        logger.info(f"Saved flagged data: {data}")

        # FlaggingCallback.flag is expected to return the number of flagged samples
        with open(self.log_file, encoding="utf-8") as f:
            return sum(1 for _ in f)
demo = gr.Blocks(theme=gr.themes.Ocean())

# Create the flagging callback
flagging_callback = TranscriptionFlaggingCallback("flagged_data")
# Define the two interfaces first
mf_transcribe = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(sources="microphone", type="filepath"),
        gr.Checkbox(label="Include timestamps", value=True),
        gr.Checkbox(label="Generate subtitles", value=True),
        gr.Slider(minimum=1, maximum=128, value=8, step=1, label="Batch Size"),
        gr.Slider(minimum=5, maximum=30, value=15, step=5, label="Chunk Length (seconds)"),
    ],
    outputs=[
        gr.JSON(label="Transcription", open=True),
        gr.File(label="Subtitles (SRT)", visible=True),
        gr.Textbox(label="Correction", visible=False),  # hidden correction field, captured when flagging
    ],
    title="Whisper: Transcribe Audio",
    description=(
        "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the"
        f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
        " of arbitrary length."
    ),
    flagging_mode="manual",
    flagging_options=["👍 Good", "👎 Bad"],
    flagging_dir="flagged_data",
    flagging_callback=flagging_callback,
)
file_transcribe = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(sources="upload", type="filepath", label="Audio file"),
        gr.Checkbox(label="Include timestamps", value=True),
        gr.Checkbox(label="Generate subtitles", value=True),
        gr.Slider(minimum=1, maximum=128, value=8, step=1, label="Batch Size"),
        gr.Slider(minimum=5, maximum=30, value=15, step=5, label="Chunk Length (seconds)"),
    ],
    outputs=[
        gr.JSON(label="Transcription", open=True),
        gr.File(label="Subtitles (SRT)", visible=True),
        gr.Textbox(label="Correction", visible=False),  # hidden correction field, captured when flagging
    ],
    title="Whisper: Transcribe Audio",
    description=(
        "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the"
        f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
        " of arbitrary length."
    ),
    flagging_mode="manual",
    flagging_options=["👍 Good", "👎 Bad"],
    flagging_dir="flagged_data",
    flagging_callback=flagging_callback,
)
# Then assemble the demo with the two interfaces as tabs
with demo:
    gr.TabbedInterface([file_transcribe, mf_transcribe], ["Audio file", "Microphone"])

logger.info("Starting Gradio interface")
demo.queue().launch(ssr_mode=False)