import os
import math
import re
import gradio as gr
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from moviepy.editor import VideoFileClip
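
# Note: this script assumes the moviepy 1.x API (the `moviepy.editor` module and
# the `verbose=` argument to write_audiofile); newer moviepy releases changed these.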


def timestamp_to_seconds(timestamp):
    """Convert an SRT timestamp (HH:MM:SS,mmm) to seconds."""
    # Split hours, minutes, and the remaining seconds/milliseconds part
    hours, minutes, rest = timestamp.split(':')
    # Seconds and milliseconds are separated by a comma
    seconds, milliseconds = rest.split(',')
    total_seconds = (
        int(hours) * 3600 +
        int(minutes) * 60 +
        int(seconds) +
        int(milliseconds) / 1000
    )
    return total_seconds
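
# Illustrative sanity check (not executed by the app): "00:01:05,250" is one
# minute, five seconds, and 250 milliseconds, so:
#   >>> timestamp_to_seconds("00:01:05,250")
#   65.25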


def format_time(seconds):
    """Convert seconds to SRT timestamp format (HH:MM:SS,mmm)."""
    m, s = divmod(seconds, 60)
    h, m = divmod(m, 60)
    return f"{int(h):02d}:{int(m):02d}:{s:06.3f}".replace('.', ',')


def clean_srt_duplicates(srt_content, time_threshold=30, similarity_threshold=0.9):
    """
    Remove duplicate captions that repeat within `time_threshold` seconds of each
    other in SRT content, keeping only the last occurrence.
    (`similarity_threshold` is accepted for compatibility but currently unused.)
    """
    # Pattern to match each SRT block, including newlines in the text
    srt_pattern = re.compile(
        r"(\d+)\n(\d{2}:\d{2}:\d{2},\d{3}) --> (\d{2}:\d{2}:\d{2},\d{3})\n(.*?)(?=\n\n|\Z)",
        re.DOTALL
    )

    blocks = []      # Surviving blocks as (index, start_time, end_time, text)
    seen_texts = {}  # text -> (start_seconds, index) of its most recent occurrence

    for match in srt_pattern.finditer(srt_content):
        index, start_time, end_time, text = match.groups()
        text = text.strip()
        # Convert the start time to seconds for comparison
        start_seconds = timestamp_to_seconds(start_time)

        # Drop any earlier caption whose text is identical (or contains / is
        # contained in this one) and that started within the time threshold;
        # the current block supersedes it.
        for existing_text, (existing_time, existing_idx) in list(seen_texts.items()):
            time_diff = abs(start_seconds - existing_time)
            texts_similar = (
                text == existing_text or
                (len(text) > 0 and len(existing_text) > 0 and
                 (text in existing_text or existing_text in text))
            )
            if texts_similar and time_diff < time_threshold:
                blocks = [b for b in blocks if b[0] != existing_idx]
                del seen_texts[existing_text]

        # Keep the current block (the last occurrence wins)
        blocks.append((index, start_time, end_time, text))
        seen_texts[text] = (start_seconds, index)

    # Rebuild the SRT content with sequential numbering
    cleaned_srt = []
    for i, (_, start_time, end_time, text) in enumerate(blocks, 1):
        cleaned_srt.append(f"{i}\n{start_time} --> {end_time}\n{text}\n\n")
    return ''.join(cleaned_srt)
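
# Illustrative example (not executed by the app): if the caption "Hello there"
# appears at 00:00:05 and again at 00:00:12 (within the 30-second threshold),
# only the later block survives and the result is renumbered:
#   1
#   00:00:12,000 --> 00:00:15,000
#   Hello there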


def transcribe(video_file, transcribe_to_text, transcribe_to_srt, language):
    """
    Main transcription function that processes a video file and generates
    text and/or SRT transcriptions, yielding progress updates along the way.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model_id = "openai/whisper-large-v3"

    try:
        # Initialize model and processor
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True
        )
        model.to(device)
        processor = AutoProcessor.from_pretrained(model_id)

        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            max_new_tokens=128,
            chunk_length_s=60,
            batch_size=4,
            return_timestamps=True,
            torch_dtype=torch_dtype,
            device=device,
        )

        if video_file is None:
            yield "Error: No video file provided.", None
            return

        # Handle the video file path (gr.Video may pass a file object or a plain path)
        video_path = video_file.name if hasattr(video_file, 'name') else video_file
        try:
            video = VideoFileClip(video_path)
        except Exception as e:
            yield f"Error processing video file: {str(e)}", None
            return

        # Process the video's audio track in 10-second chunks
        audio = video.audio
        duration = video.duration
        n_chunks = math.ceil(duration / 10)
        transcription_txt = ""
        transcription_srt = []

        for i in range(n_chunks):
            start = i * 10
            end = min((i + 1) * 10, duration)
            audio_chunk = audio.subclip(start, end)
            temp_file_path = f"temp_audio_{i}.wav"
            try:
                # Save the audio chunk to a temporary WAV file
                audio_chunk.write_audiofile(
                    temp_file_path,
                    codec='pcm_s16le',
                    verbose=False,
                    logger=None
                )
                # Transcribe the audio chunk
                result = pipe(
                    temp_file_path,
                    generate_kwargs={"language": language}
                )
                transcription_txt += result["text"]
                if transcribe_to_srt:
                    for chunk in result["chunks"]:
                        start_time, end_time = chunk["timestamp"]
                        if start_time is not None and end_time is not None:
                            # Offset timestamps by this chunk's position in the video
                            transcription_srt.append({
                                "start": start_time + i * 10,
                                "end": end_time + i * 10,
                                "text": chunk["text"].strip()
                            })
            finally:
                # Clean up the temporary file
                if os.path.exists(temp_file_path):
                    os.remove(temp_file_path)

            # Report progress after each chunk
            yield f"Progress: {int(((i + 1) / n_chunks) * 100)}%", None

        # Prepare output
        output = ""
        srt_file_path = None
        if transcribe_to_text:
            output += "Text Transcription:\n" + transcription_txt.strip() + "\n\n"
        if transcribe_to_srt:
            output += "SRT Transcription:\n"
            srt_content = ""
            # Generate initial SRT content
            for i, sub in enumerate(transcription_srt, 1):
                srt_entry = f"{i}\n{format_time(sub['start'])} --> {format_time(sub['end'])}\n{sub['text']}\n\n"
                srt_content += srt_entry
            # Remove near-duplicate captions
            cleaned_srt_content = clean_srt_duplicates(srt_content)
            # Save the SRT content to a file
            srt_file_path = "transcription.srt"
            with open(srt_file_path, "w", encoding="utf-8") as srt_file:
                srt_file.write(cleaned_srt_content)
            output += f"\nSRT file saved as: {srt_file_path}"

        # Clean up the video object and yield the final result
        video.close()
        yield output, srt_file_path

    except Exception as e:
        yield f"Error during transcription: {str(e)}", None


# Create the Gradio interface
iface = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Checkbox(label="Transcribe to Text", value=True),
        gr.Checkbox(label="Transcribe to SRT", value=True),
        gr.Dropdown(
            choices=['en', 'he', 'it', 'es', 'fr', 'de', 'zh', 'ar'],
            value='en',
            label="Input Video Language"
        )
    ],
    outputs=[
        gr.Textbox(label="Transcription Output"),
        gr.File(label="Download SRT")
    ],
    title="WhisperCap Video Transcription",
    description="Upload a video file to transcribe.",
    allow_flagging="never"
)

# Launch the interface
if __name__ == "__main__":
    iface.launch(share=True)