import os
import torch
import gc
import math
from moviepy.editor import VideoFileClip
from pyannote.audio import Pipeline
import librosa
import soundfile as sf
import datetime
from collections import defaultdict
import numpy as np
import openai
from config import openai_api_key

openai.api_key = openai_api_key


class LazyDiarizationPipeline:
    """Loads the pyannote diarization pipeline on first use and keeps it cached."""

    def __init__(self):
        self.pipeline = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def get_pipeline(self, hf_token):
        if self.pipeline is None:
            self.pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=hf_token
            )
            self.pipeline = self.pipeline.to(self.device)
            torch.cuda.empty_cache()
            gc.collect()
        return self.pipeline


lazy_diarization_pipeline = LazyDiarizationPipeline()


def extract_audio(video_path, audio_path):
    # Extract the audio track as 16 kHz PCM WAV
    video = VideoFileClip(video_path)
    audio = video.audio
    audio.write_audiofile(audio_path, codec='pcm_s16le', fps=16000)
    video.close()  # release the underlying file readers


def format_timestamp(seconds):
    # Drop fractional seconds, e.g. "0:01:23.456789" -> "0:01:23"
    return str(datetime.timedelta(seconds=seconds)).split('.')[0]


def transcribe_audio(audio_path, language):
    # Uses the legacy (pre-1.0) openai SDK interface for Whisper transcription
    with open(audio_path, "rb") as audio_file:
        transcript = openai.Audio.transcribe(
            file=audio_file,
            model="whisper-1",
            language=language,
            response_format="verbose_json"
        )
    transcription_txt = transcript["text"]
    transcription_chunks = []
    for segment in transcript["segments"]:
        transcription_chunks.append({
            "start": segment["start"],
            "end": segment["end"],
            "text": segment["text"]
        })
    return transcription_txt, transcription_chunks


def diarize_audio(audio_path, pipeline, max_speakers):
    # Load the audio as 16 kHz mono and write it to a temporary file for the pipeline
    audio, sr = librosa.load(audio_path, sr=16000)
    temp_audio_path = f"{audio_path}_temp.wav"
    sf.write(temp_audio_path, audio, sr)
    # Perform speaker diarization on the entire audio file; max_speakers caps
    # the speaker count rather than forcing an exact number
    diarization = pipeline(temp_audio_path, max_speakers=max_speakers)
    # Clean up the temporary file and free GPU memory
    os.remove(temp_audio_path)
    torch.cuda.empty_cache()
    gc.collect()
    return diarization


def create_combined_srt(transcription_chunks, diarization, output_path, max_speakers):
    # Collect diarization segments and total speaking time per speaker
    speaker_segments = []
    speaker_durations = defaultdict(float)
    for segment, _, speaker in diarization.itertracks(yield_label=True):
        speaker_durations[speaker] += segment.end - segment.start
        speaker_segments.append((segment.start, segment.end, speaker))

    # Keep the max_speakers most talkative speakers and relabel them Speaker 1..N
    sorted_speakers = sorted(speaker_durations.items(), key=lambda x: x[1], reverse=True)[:max_speakers]
    speaker_map = {}
    for i, (speaker, _) in enumerate(sorted_speakers, start=1):
        speaker_map[speaker] = f"Speaker {i}"

    with open(output_path, 'w', encoding='utf-8') as srt_file:
        for i, chunk in enumerate(transcription_chunks, 1):
            start_time, end_time = chunk["start"], chunk["end"]
            text = chunk["text"]
            # Attribute the chunk to whichever speaker segment contains its start time
            current_speaker = "Unknown"
            for seg_start, seg_end, speaker in speaker_segments:
                if seg_start <= start_time < seg_end:
                    current_speaker = speaker_map.get(speaker, "Unknown")
                    break
            start_str = format_timestamp(start_time)
            end_str = format_timestamp(end_time)
            srt_file.write(f"{i}\n")
            srt_file.write(f"{current_speaker}\n time: ({start_str} --> {end_str})\n text: {text}\n\n")

    # Append a per-speaker summary of total speaking time
    with open(output_path, 'a', encoding='utf-8') as srt_file:
        for i, (speaker, duration) in enumerate(sorted_speakers, start=1):
            duration_str = format_timestamp(duration)
            srt_file.write(f"Speaker {i} (originally {speaker}): total duration {duration_str}\n")


def process_video(video_path, hf_token, language, max_speakers=3):
    base_name = os.path.splitext(video_path)[0]
    audio_path = f"{base_name}.wav"
    extract_audio(video_path, audio_path)

    pipeline = lazy_diarization_pipeline.get_pipeline(hf_token)
    diarization = diarize_audio(audio_path, pipeline, max_speakers)
    # Clear GPU memory after diarization
    torch.cuda.empty_cache()
    gc.collect()

    transcription, chunks = transcribe_audio(audio_path, language)
    # Clear GPU memory after transcription
    torch.cuda.empty_cache()
    gc.collect()

    combined_srt_path = f"{base_name}_combined.srt"
    create_combined_srt(chunks, diarization, combined_srt_path, max_speakers)
    os.remove(audio_path)
    # Final GPU memory clear
    torch.cuda.empty_cache()
    gc.collect()
    return combined_srt_path
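

# Illustrative usage sketch (an assumption, not part of the original Space): the video
# path below is a placeholder, and HF_TOKEN is assumed to hold a Hugging Face access
# token that has been granted access to pyannote/speaker-diarization-3.1.
if __name__ == "__main__":
    example_video = "meeting.mp4"  # hypothetical input file
    hf_token = os.environ.get("HF_TOKEN", "")
    srt_path = process_video(example_video, hf_token, language="en", max_speakers=3)
    print(f"Combined transcript written to {srt_path}")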