import os
import numpy as np
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperFeatureExtractor
from moviepy.editor import VideoFileClip
import nltk
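# Fetch the 'punkt' sentence tokenizer data used for SRT segmentation (no-op if already cached)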
nltk.download('punkt', quiet=True)
from nltk.tokenize import sent_tokenize
import librosa


def transcribe(video_file, transcribe_to_text=True, transcribe_to_srt=True, target_language='en'):
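    """Transcribe a video's audio with openai/whisper-large-v3.

    The audio track is processed in 60-second chunks. SRT timestamps are
    estimated by spreading the total duration across sentences in proportion
    to their character counts, so they are approximate.

    Returns a string with the plain-text and/or SRT transcription, depending
    on the transcribe_to_text / transcribe_to_srt flags.
    """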
    # Use the GPU in half precision when available; fall back to CPU in float32
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    model_id = "openai/whisper-large-v3"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_id)
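
    # Extract the audio track from the video; MoviePy reports duration in seconds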
    video = VideoFileClip(video_file)
    audio = video.audio
    duration = audio.duration

    chunk_duration = 60  # seconds
    n_chunks = int(np.ceil(duration / chunk_duration))
    full_transcription = ""
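
    # Write each chunk to a temporary WAV file, load it as 16 kHz mono, and run Whisper on it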
    for i in range(n_chunks):
        start_time = i * chunk_duration
        end_time = min((i + 1) * chunk_duration, duration)
        audio_chunk = audio.subclip(start_time, end_time)

        temp_file_path = f"temp_audio_chunk_{i}.wav"
        audio_chunk.write_audiofile(temp_file_path, codec='pcm_s16le')

        try:
            sound_array, _ = librosa.load(temp_file_path, sr=16000)
        except Exception as e:
            print(f"Error reading audio file: {e}")
            os.remove(temp_file_path)
            continue  # Skip this chunk if there's an error

        if sound_array.ndim > 1:
            sound_array = np.mean(sound_array, axis=1)

        input_features = feature_extractor(sound_array, sampling_rate=16000, return_tensors="pt").input_features
        input_features = input_features.to(device=device, dtype=torch_dtype)

        with torch.no_grad():
            if target_language:
                # Force decoding in the requested language instead of auto-detection
                model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
                    language=target_language, task="transcribe"
                )
            generated_ids = model.generate(input_features, max_length=448)

        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        full_transcription += transcription + " "

        os.remove(temp_file_path)
        print(f"Processed chunk {i + 1}/{n_chunks}")
    # Done with MoviePy; release its file handles
    video.close()

    # Split the transcription into sentences
    sentences = sent_tokenize(full_transcription.strip())

    # Estimate time for each sentence from its share of the total character count
    total_chars = sum(len(s) for s in sentences)
    sentence_times = []
    current_time = 0
    for sentence in sentences:
        sentence_duration = (len(sentence) / total_chars) * duration
        sentence_times.append((current_time, current_time + sentence_duration))
        current_time += sentence_duration
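
    # Assemble the requested output formats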
    output = ""
    if transcribe_to_text:
        output += "Text Transcription:\n" + full_transcription + "\n\n"
    if transcribe_to_srt:
        output += "SRT Transcription:\n"
        for i, (sentence, (start, end)) in enumerate(zip(sentences, sentence_times), 1):
            output += f"{i}\n{format_time(start)} --> {format_time(end)}\n{sentence}\n\n"

    return output

def format_time(seconds):
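    """Format seconds as an SRT timestamp: HH:MM:SS,mmm."""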
    m, s = divmod(seconds, 60)
    h, m = divmod(m, 60)
    return f"{int(h):02d}:{int(m):02d}:{s:06.3f}".replace('.', ',')
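

# Minimal usage sketch ("example_video.mp4" is a hypothetical path):
if __name__ == "__main__":
    print(transcribe("example_video.mp4", transcribe_to_text=True,
                     transcribe_to_srt=True, target_language='en'))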