import os
import torch
import gc
import math
from moviepy.editor import VideoFileClip
from pyannote.audio import Pipeline
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import librosa
import soundfile as sf
import datetime
from collections import defaultdict
import numpy as np

class LazyDiarizationPipeline:
    """Loads the pyannote diarization pipeline on first use and caches it."""
    def __init__(self):
        self.pipeline = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def get_pipeline(self, hf_token):
        if self.pipeline is None:
            self.pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1", use_auth_token=hf_token
            )
            self.pipeline = self.pipeline.to(self.device)
            torch.cuda.empty_cache()
            gc.collect()
        return self.pipeline

class LazyTranscriptionPipeline:
    """Loads the Whisper large-v3 ASR pipeline on first use and caches it."""
    def __init__(self):
        self.model = None
        self.processor = None
        self.pipe = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def get_pipeline(self):
        if self.pipe is None:
            model_id = "openai/whisper-large-v3"
            # Half precision on GPU to reduce memory use; full precision on CPU.
            torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
            )
            self.model.to(self.device)
            self.processor = AutoProcessor.from_pretrained(model_id)
            self.pipe = pipeline(
                "automatic-speech-recognition",
                model=self.model,
                tokenizer=self.processor.tokenizer,
                feature_extractor=self.processor.feature_extractor,
                chunk_length_s=30,
                return_timestamps=True,
                device=self.device,
            )
        return self.pipe

lazy_diarization_pipeline = LazyDiarizationPipeline()
lazy_transcription_pipeline = LazyTranscriptionPipeline()

def extract_audio(video_path, audio_path):
    # Extract the audio track as 16 kHz 16-bit PCM WAV (the rate Whisper and pyannote expect).
    video = VideoFileClip(video_path)
    audio = video.audio
    audio.write_audiofile(audio_path, codec='pcm_s16le', fps=16000)
    video.close()

def format_timestamp(seconds):
    # Format seconds as H:MM:SS, dropping any fractional part.
    return str(datetime.timedelta(seconds=seconds)).split('.')[0]

def transcribe_audio(audio_path, language):
    pipe = lazy_transcription_pipeline.get_pipeline()
    audio, sr = librosa.load(audio_path, sr=16000)
    duration = len(audio) / sr
    n_chunks = math.ceil(duration / 30)
    transcription_txt = ""
    transcription_chunks = []
    for i in range(n_chunks):
        # Process the audio in 30-second windows to bound memory usage.
        start = i * 30 * sr
        end = min((i + 1) * 30 * sr, len(audio))
        # Whisper's feature extractor expects float32 samples normalized to [-1, 1],
        # so the chunk is passed through without int16-style scaling.
        audio_chunk = audio[start:end].astype(np.float32)
        result = pipe(
            {"raw": audio_chunk, "sampling_rate": sr},
            generate_kwargs={"language": language, "task": "transcribe"},
        )
        transcription_txt += result["text"]
        for chunk in result["chunks"]:
            start_time, end_time = chunk["timestamp"]
            # Timestamps can be None at window boundaries; fall back to 0.
            if start_time is None:
                start_time = 0
            if end_time is None:
                end_time = 0
            # Offset per-window timestamps so they are relative to the full recording.
            transcription_chunks.append({
                "start": start_time + i * 30,
                "end": end_time + i * 30,
                "text": chunk["text"]
            })
    return transcription_txt, transcription_chunks

def diarize_audio(audio_path, pipeline, max_speakers):
    # Load the entire audio file at 16 kHz
    audio, sr = librosa.load(audio_path, sr=16000)
    # Write the audio to a temporary file for the pyannote pipeline
    temp_audio_path = f"{audio_path}_temp.wav"
    sf.write(temp_audio_path, audio, sr)
    # Perform speaker diarization on the entire audio file
    diarization = pipeline(temp_audio_path, num_speakers=max_speakers)
    # Clean up the temporary file and free GPU memory
    os.remove(temp_audio_path)
    torch.cuda.empty_cache()
    gc.collect()
    return diarization

def create_combined_srt(transcription_chunks, diarization, output_path, max_speakers):
    # Collect all diarization segments and the total speaking time per speaker.
    speaker_segments = []
    speaker_durations = defaultdict(float)
    for segment, _, speaker in diarization.itertracks(yield_label=True):
        speaker_durations[speaker] += segment.end - segment.start
        speaker_segments.append((segment.start, segment.end, speaker))
    # Keep the max_speakers speakers with the longest total duration and rename them
    # Speaker 1, Speaker 2, ... in that order.
    sorted_speakers = sorted(speaker_durations.items(), key=lambda x: x[1], reverse=True)[:max_speakers]
    speaker_map = {}
    for i, (speaker, _) in enumerate(sorted_speakers, start=1):
        speaker_map[speaker] = f"Speaker {i}"
    with open(output_path, 'w', encoding='utf-8') as srt_file:
        for i, chunk in enumerate(transcription_chunks, 1):
            start_time, end_time = chunk["start"], chunk["end"]
            text = chunk["text"]
            # Attribute the chunk to the speaker whose segment contains its start time.
            current_speaker = "Unknown"
            for seg_start, seg_end, speaker in speaker_segments:
                if seg_start <= start_time < seg_end:
                    current_speaker = speaker_map.get(speaker, "Unknown")
                    break
            start_str = format_timestamp(start_time)
            end_str = format_timestamp(end_time)
            srt_file.write(f"{i}\n")
            srt_file.write(f"{current_speaker}\n time: ({start_str} --> {end_str})\n text: {text}\n\n")
    # Append a per-speaker summary of total speaking time.
    with open(output_path, 'a', encoding='utf-8') as srt_file:
        for i, (speaker, duration) in enumerate(sorted_speakers, start=1):
            duration_str = format_timestamp(duration)
            srt_file.write(f"Speaker {i} (originally {speaker}): total duration {duration_str}\n")

def process_video(video_path, hf_token, language, max_speakers=3):
    base_name = os.path.splitext(video_path)[0]
    audio_path = f"{base_name}.wav"
    extract_audio(video_path, audio_path)
    pipeline = lazy_diarization_pipeline.get_pipeline(hf_token)
    diarization = diarize_audio(audio_path, pipeline, max_speakers)
    # Clear GPU memory after diarization
    torch.cuda.empty_cache()
    gc.collect()
    transcription, chunks = transcribe_audio(audio_path, language)
    # Clear GPU memory after transcription
    torch.cuda.empty_cache()
    gc.collect()
    combined_srt_path = f"{base_name}_combined.srt"
    create_combined_srt(chunks, diarization, combined_srt_path, max_speakers)
    os.remove(audio_path)
    # Final GPU memory clear
    torch.cuda.empty_cache()
    gc.collect()
    return combined_srt_path
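
# Example usage: a minimal sketch of how process_video would be driven directly.
# The video filename, the HF_TOKEN environment variable, and the "en" language code
# are placeholder assumptions, not values taken from this Space.
if __name__ == "__main__":
    hf_token = os.environ.get("HF_TOKEN")  # pyannote/speaker-diarization-3.1 requires an access token
    srt_path = process_video("example_video.mp4", hf_token, language="en", max_speakers=3)
    print(f"Combined transcript written to {srt_path}")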