reab5555 committed
Commit a08b017 · verified · 1 Parent(s): fb33e5c

Update diarization.py

Files changed (1)
  1. diarization.py +154 -157
diarization.py CHANGED
@@ -1,168 +1,165 @@
  import os
- import gradio as gr
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
- from langchain.llms import HuggingFacePipeline
- from langchain_community.document_loaders import TextLoader
- from langchain_community.vectorstores import FAISS
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain.chains import RetrievalQA
- from huggingface_hub import login
- import diarization
- import shutil
  import spaces
- import time

- # Get Hugging Face token from Space secret
- hf_token = os.environ.get('hf_secret')
- if not hf_token:
-     raise ValueError("HF_TOKEN not found in environment variables. Please set it in the Space secrets.")
-
- # Login to Hugging Face
- login(token=hf_token)
-
- # Lazy initialization for the pipeline
- class LazyPipeline:
      def __init__(self):
          self.pipeline = None

-     @spaces.GPU(duration=250)
-     def get_pipeline(self):
          if self.pipeline is None:
-             import torch
-             model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-             tokenizer = AutoTokenizer.from_pretrained(model_name)
-             model = AutoModelForCausalLM.from_pretrained(
-                 model_name,
-                 torch_dtype=torch.float16,
-                 device_map="auto",
-             )
-             self.pipeline = pipeline(
-                 "text-generation",
-                 model=model,
-                 tokenizer=tokenizer,
-                 max_new_tokens=512,
-                 temperature=0.5,
-                 top_p=0.95,
-                 repetition_penalty=1.15
-             )
          return self.pipeline

- lazy_pipe = LazyPipeline()
-
- # Create a LangChain wrapper around the pipeline
- class LazyLLM:
-     def __init__(self, lazy_pipeline):
-         self.lazy_pipeline = lazy_pipeline
-         self.llm = None
-
-     @spaces.GPU(duration=250)
-     def get_llm(self):
-         if self.llm is None:
-             pipe = self.lazy_pipeline.get_pipeline()
-             self.llm = HuggingFacePipeline(pipeline=pipe)
-         return self.llm
-
- lazy_llm = LazyLLM(lazy_pipe)
-
- # Load instruction files
- def load_instructions(file_path):
-     with open(file_path, 'r') as file:
-         return file.read()
-
- general_task = load_instructions("tasks/general_task.txt")
- attachments_task = load_instructions("tasks/Attachments_task.txt")
- bigfive_task = load_instructions("tasks/BigFive_task.txt")
- personalities_task = load_instructions("tasks/Personalities_task.txt")
-
- # Load knowledge files
- def load_knowledge(file_path):
-     with open(file_path, 'r') as file:
-         return file.read()
-
- attachments_knowledge = load_knowledge("knowledge/bartholomew_attachments_definitions.txt")
- bigfive_knowledge = load_knowledge("knowledge/bigfive_definitions.txt")
- personalities_knowledge = load_knowledge("knowledge/personalities_definitions.txt")
-
- # Lazy initialization for retrieval chains
- class LazyChains:
-     def __init__(self, lazy_llm):
-         self.lazy_llm = lazy_llm
-         self.attachments_chain = None
-         self.bigfive_chain = None
-         self.personalities_chain = None

      @spaces.GPU(duration=120)
-     def get_chains(self):
-         if self.attachments_chain is None:
-             llm = self.lazy_llm.get_llm()
-             self.attachments_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=attachments_knowledge)
-             self.bigfive_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=bigfive_knowledge)
-             self.personalities_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=personalities_knowledge)
-         return self.attachments_chain, self.bigfive_chain, self.personalities_chain
-
- lazy_chains = LazyChains(lazy_llm)
-
- # Function to process video file
- @spaces.GPU(duration=120)
- def process_video(video_file):
-     start_time = time.time()
-
-     # Copy the uploaded video file to a temporary location
-     temp_video_path = "temp_video.mp4"
-     shutil.copy2(video_file.name, temp_video_path)
-
-     # Initialize progress bar
-     progress = gr.Progress()
-
-     # Display progress bar for diarization
-     progress(0, desc="Starting Diarization...")
-     # Process the video using the diarization script
-     language = "en"
-     diarization.process_video(temp_video_path, hf_token, language)
-     progress(50, desc="Diarization Complete.")
-
-     # The SRT file will be created with the same name as the video file but with .srt extension
-     srt_path = temp_video_path.replace(".mp4", "_combined.srt")
-
-     # Read the content of the SRT file
-     with open(srt_path, 'r', encoding='utf-8') as file:
-         srt_content = file.read()
-
-     # Get the chains
-     attachments_chain, bigfive_chain, personalities_chain = lazy_chains.get_chains()
-
-     # Process with LangChain and display progress bars
-     progress(50, desc="Processing Attachments Analysis...")
-     attachments_result = attachments_chain.run(srt_content)
-     progress(70, desc="Attachments Analysis Complete.")
-
-     progress(70, desc="Processing Big Five Analysis...")
-     bigfive_result = bigfive_chain.run(srt_content)
-     progress(90, desc="Big Five Analysis Complete.")
-
-     progress(90, desc="Processing Personalities Analysis...")
-     personalities_result = personalities_chain.run(srt_content)
-     progress(100, desc="Personalities Analysis Complete.")
-
-     # Combine results
-     final_result = f"Attachments Analysis:\n{attachments_result}\n\nBig Five Analysis:\n{bigfive_result}\n\nPersonalities Analysis:\n{personalities_result}"
-
-     end_time = time.time()
-     execution_time = end_time - start_time
-
-     # Only return execution time and final result
-     final_result_with_time = f"Execution Time: {execution_time:.2f} seconds\n\n{final_result}"
-
-     return final_result_with_time
-
- # Create Gradio interface
- iface = gr.Interface(
-     fn=process_video,
-     inputs=gr.File(label="Upload Video File"),
-     outputs=gr.Textbox(label="Analysis Result"),
-     title="Video Analysis with Meta-Llama-3.1-8B-Instruct",
-     description="Upload a video file to analyze using RAG techniques with Meta-Llama-3.1-8B-Instruct."
- )
-
- # Launch the app
- iface.launch()
  import os
+ import torch
+ import math
+ from moviepy.editor import VideoFileClip, AudioFileClip
+ from pyannote.audio import Pipeline
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+ import librosa
+ import datetime
+ from collections import defaultdict
+ import numpy as np
  import spaces

+ class LazyDiarizationPipeline:
      def __init__(self):
          self.pipeline = None

+     @spaces.GPU(duration=120)
+     def get_pipeline(self, diarization_access_token):
          if self.pipeline is None:
+             self.pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=diarization_access_token)
+             self.pipeline = self.pipeline.to(torch.device("cuda"))
          return self.pipeline

+ lazy_diarization_pipeline = LazyDiarizationPipeline()
+
+ class LazyTranscriptionPipeline:
+     def __init__(self):
+         self.model = None
+         self.processor = None
+         self.pipe = None

      @spaces.GPU(duration=120)
+     def get_pipeline(self, language):
+         if self.pipe is None:
+             model_id = "openai/whisper-large-v3"
+             self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
+                 model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True, use_safetensors=True
+             )
+             self.model.to(torch.device("cuda"))
+             self.processor = AutoProcessor.from_pretrained(model_id)
+             self.pipe = pipeline(
+                 "automatic-speech-recognition",
+                 model=self.model,
+                 tokenizer=self.processor.tokenizer,
+                 feature_extractor=self.processor.feature_extractor,
+                 max_new_tokens=128,
+                 chunk_length_s=30,
+                 batch_size=1,
+                 return_timestamps=True,
+                 torch_dtype=torch.float16,
+                 device=torch.device("cuda"),
+                 generate_kwargs={"language": language}
+             )
+         return self.pipe
+
+ lazy_transcription_pipeline = LazyTranscriptionPipeline()
+
+ def extract_audio(video_path, audio_path):
+     video = VideoFileClip(video_path)
+     audio = video.audio
+     audio.write_audiofile(audio_path, codec='pcm_s16le', fps=16000)
+
+ def format_timestamp(seconds):
+     return str(datetime.timedelta(seconds=seconds)).split('.')[0]
+
+ @spaces.GPU(duration=100)
+ def transcribe_audio(audio_path, language):
+     pipe = lazy_transcription_pipeline.get_pipeline(language)
+
+     audio, sr = librosa.load(audio_path, sr=16000)
+     duration = len(audio) / sr
+     n_chunks = math.ceil(duration / 30)
+     transcription_txt = ""
+     transcription_chunks = []
+
+     for i in range(n_chunks):
+         start = i * 30 * sr
+         end = min((i + 1) * 30 * sr, len(audio))
+         audio_chunk = audio[start:end]
+
+         # Convert the audio chunk to float32 numpy array
+         audio_chunk = (audio_chunk * 32767).astype(np.float32)
+
+         result = pipe(audio_chunk)
+         transcription_txt += result["text"]
+         for chunk in result["chunks"]:
+             start_time, end_time = chunk["timestamp"]
+             transcription_chunks.append({
+                 "start": start_time + i * 30,
+                 "end": end_time + i * 30,
+                 "text": chunk["text"]
+             })
+
+         print(f"Transcription Progress: {int(((i + 1) / n_chunks) * 100)}%")
+
+     return transcription_txt, transcription_chunks
+
+ def create_combined_srt(transcription_chunks, diarization, output_path):
+     speaker_segments = []
+     speaker_map = {}
+     current_speaker_num = 1
+
+     for segment, _, speaker in diarization.itertracks(yield_label=True):
+         if speaker not in speaker_map:
+             speaker_map[speaker] = f"Speaker {current_speaker_num}"
+             current_speaker_num += 1
+         speaker_segments.append((segment.start, segment.end, speaker_map[speaker]))
+
+     with open(output_path, 'w', encoding='utf-8') as srt_file:
+         for i, chunk in enumerate(transcription_chunks, 1):
+             start_time, end_time = chunk["start"], chunk["end"]
+             text = chunk["text"]
+
+             # Find the corresponding speaker
+             current_speaker = "Unknown"
+             for seg_start, seg_end, speaker in speaker_segments:
+                 if seg_start <= start_time < seg_end:
+                     current_speaker = speaker
+                     break
+
+             # Format timecodes as h:mm:ss (without leading zeros for hours)
+             start_str = format_timestamp(start_time).split('.')[0].lstrip('0')
+             end_str = format_timestamp(end_time).split('.')[0].lstrip('0')
+
+             srt_file.write(f"{i}\n")
+             srt_file.write(f"{{{current_speaker}}}\n time: ({start_str} --> {end_str})\n text: {text}\n\n")
+
+     # Add dominant speaker information
+     speaker_durations = defaultdict(float)
+     for seg_start, seg_end, speaker in speaker_segments:
+         speaker_durations[speaker] += seg_end - seg_start
+
+     dominant_speaker = max(speaker_durations, key=speaker_durations.get)
+     dominant_duration = speaker_durations[dominant_speaker]
+
+     with open(output_path, 'a', encoding='utf-8') as srt_file:
+         dominant_duration_str = format_timestamp(dominant_duration).split('.')[0].lstrip('0')
+         srt_file.write(f"\nMost dominant speaker: {dominant_speaker} with total duration {dominant_duration_str}\n")
+
+ @spaces.GPU(duration=100)
+ def process_video(video_path, diarization_access_token, language):
+     base_name = os.path.splitext(video_path)[0]
+     audio_path = f"{base_name}.wav"
+     extract_audio(video_path, audio_path)
+
+     # Diarization
+     print("Performing diarization...")
+     pipeline = lazy_diarization_pipeline.get_pipeline(diarization_access_token)
+     diarization = pipeline(audio_path)
+     print("Diarization complete.")
+
+     # Transcription
+     print("Performing transcription...")
+     transcription, chunks = transcribe_audio(audio_path, language)
+     print("Transcription complete.")
+
+     # Create combined SRT file
+     combined_srt_path = f"{base_name}_combined.srt"
+     create_combined_srt(chunks, diarization, combined_srt_path)
+     print(f"Combined SRT file created and saved to {combined_srt_path}")
+
+     # Clean up
+     os.remove(audio_path)
+
+     return combined_srt_path
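
For context, a minimal sketch of how the updated module's entry point could be driven by a caller. The video file name, the 'hf_secret' environment variable, and the "en" language code are assumptions carried over from the Gradio app code removed above; they are not part of this commit.

# Hypothetical caller sketch (not part of this commit).
import os
import diarization

# Assumption: the pyannote access token is stored in a Space secret named
# 'hf_secret', as in the removed app code.
hf_token = os.environ.get("hf_secret")

# process_video extracts a 16 kHz WAV from the video, runs pyannote
# diarization and Whisper transcription, writes "<video>_combined.srt",
# and returns that path.
srt_path = diarization.process_video("example_video.mp4", hf_token, "en")

with open(srt_path, "r", encoding="utf-8") as f:
    print(f.read())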