reab5555 committed
Commit 9deffb0 · verified · 1 Parent(s): d8a2311

Update diarization.py

Files changed (1)
  1. diarization.py +149 -148
diarization.py CHANGED
@@ -1,149 +1,150 @@
-import os
-import torch
-import math
-from moviepy.editor import VideoFileClip, AudioFileClip
-from pyannote.audio import Pipeline
-from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
-import librosa
-import datetime
-from collections import defaultdict
-import numpy as np
-
-def extract_audio(video_path, audio_path):
-    video = VideoFileClip(video_path)
-    audio = video.audio
-    audio.write_audiofile(audio_path, codec='pcm_s16le', fps=16000)
-
-def format_timestamp(seconds):
-    return str(datetime.timedelta(seconds=seconds)).split('.')[0]
-
-def transcribe_audio(audio_path, language):
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-    model_id = "openai/whisper-large-v3"
-
-    model = AutoModelForSpeechSeq2Seq.from_pretrained(
-        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
-    )
-    model.to(device)
-
-    processor = AutoProcessor.from_pretrained(model_id)
-
-    pipe = pipeline(
-        "automatic-speech-recognition",
-        model=model,
-        tokenizer=processor.tokenizer,
-        feature_extractor=processor.feature_extractor,
-        max_new_tokens=128,
-        chunk_length_s=30,
-        batch_size=1,
-        return_timestamps=True,
-        torch_dtype=torch_dtype,
-        device=device,
-        generate_kwargs={"language": language}
-    )
-
-    audio, sr = librosa.load(audio_path, sr=16000)
-    duration = len(audio) / sr
-    n_chunks = math.ceil(duration / 30)
-    transcription_txt = ""
-    transcription_chunks = []
-
-    for i in range(n_chunks):
-        start = i * 30 * sr
-        end = min((i + 1) * 30 * sr, len(audio))
-        audio_chunk = audio[start:end]
-
-        # Convert the audio chunk to float32 numpy array
-        audio_chunk = (audio_chunk * 32767).astype(np.float32)
-
-        result = pipe(audio_chunk)
-        transcription_txt += result["text"]
-        for chunk in result["chunks"]:
-            start_time, end_time = chunk["timestamp"]
-            transcription_chunks.append({
-                "start": start_time + i * 30,
-                "end": end_time + i * 30,
-                "text": chunk["text"]
-            })
-
-        print(f"Transcription Progress: {int(((i + 1) / n_chunks) * 100)}%")
-
-    return transcription_txt, transcription_chunks
-
-def create_combined_srt(transcription_chunks, diarization, output_path):
-    speaker_segments = []
-    speaker_map = {}
-    current_speaker_num = 1
-
-    for segment, _, speaker in diarization.itertracks(yield_label=True):
-        if speaker not in speaker_map:
-            speaker_map[speaker] = f"Speaker {current_speaker_num}"
-            current_speaker_num += 1
-        speaker_segments.append((segment.start, segment.end, speaker_map[speaker]))
-
-    with open(output_path, 'w', encoding='utf-8') as srt_file:
-        for i, chunk in enumerate(transcription_chunks, 1):
-            start_time, end_time = chunk["start"], chunk["end"]
-            text = chunk["text"]
-
-            # Find the corresponding speaker
-            current_speaker = "Unknown"
-            for seg_start, seg_end, speaker in speaker_segments:
-                if seg_start <= start_time < seg_end:
-                    current_speaker = speaker
-                    break
-
-            # Format timecodes as h:mm:ss (without leading zeros for hours)
-            start_str = format_timestamp(start_time).split('.')[0].lstrip('0')
-            end_str = format_timestamp(end_time).split('.')[0].lstrip('0')
-
-            srt_file.write(f"{i}\n")
-            srt_file.write(f"{{{current_speaker}}}\n time: ({start_str} --> {end_str})\n text: {text}\n\n")
-
-    # Add dominant speaker information
-    speaker_durations = defaultdict(float)
-    for seg_start, seg_end, speaker in speaker_segments:
-        speaker_durations[speaker] += seg_end - seg_start
-
-    dominant_speaker = max(speaker_durations, key=speaker_durations.get)
-    dominant_duration = speaker_durations[dominant_speaker]
-
-    with open(output_path, 'a', encoding='utf-8') as srt_file:
-        dominant_duration_str = format_timestamp(dominant_duration).split('.')[0].lstrip('0')
-        srt_file.write(f"\nMost dominant speaker: {dominant_speaker} with total duration {dominant_duration_str}\n")
-
-def process_video(video_path, diarization_access_token, language):
-    base_name = os.path.splitext(video_path)[0]
-    audio_path = f"{base_name}.wav"
-    extract_audio(video_path, audio_path)
-
-    # Diarization
-    print("Performing diarization...")
-    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=diarization_access_token)
-    pipeline = pipeline.to(torch.device("cpu"))
-    diarization = pipeline(audio_path)
-    print("Diarization complete.")
-
-    # Transcription
-    print("Performing transcription...")
-    transcription, chunks = transcribe_audio(audio_path, language)
-    print("Transcription complete.")
-
-    # Create combined SRT file
-    combined_srt_path = f"{base_name}_combined.srt"
-    create_combined_srt(chunks, diarization, combined_srt_path)
-    print(f"Combined SRT file created and saved to {combined_srt_path}")
-
-    # Clean up
-    os.remove(audio_path)
-
-if __name__ == "__main__":
-    video_path = r"C:\Users\reab5\Downloads\MediaHuman\Music\test1.mp4"
-    # Get Hugging Face token from Space secret
-    access_token = os.environ.get('hf_secret')
-    if not access_token:
-        raise ValueError("HF_TOKEN not found in environment variables. Please set it in the Space secrets.")
-
-    language = "en"
+import os
+import torch
+import torchvision
+import math
+from moviepy.editor import VideoFileClip, AudioFileClip
+from pyannote.audio import Pipeline
+from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+import librosa
+import datetime
+from collections import defaultdict
+import numpy as np
+
+def extract_audio(video_path, audio_path):
+    video = VideoFileClip(video_path)
+    audio = video.audio
+    audio.write_audiofile(audio_path, codec='pcm_s16le', fps=16000)
+
+def format_timestamp(seconds):
+    return str(datetime.timedelta(seconds=seconds)).split('.')[0]
+
+def transcribe_audio(audio_path, language):
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+    model_id = "openai/whisper-large-v3"
+
+    model = AutoModelForSpeechSeq2Seq.from_pretrained(
+        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+    )
+    model.to(device)
+
+    processor = AutoProcessor.from_pretrained(model_id)
+
+    pipe = pipeline(
+        "automatic-speech-recognition",
+        model=model,
+        tokenizer=processor.tokenizer,
+        feature_extractor=processor.feature_extractor,
+        max_new_tokens=128,
+        chunk_length_s=30,
+        batch_size=1,
+        return_timestamps=True,
+        torch_dtype=torch_dtype,
+        device=device,
+        generate_kwargs={"language": language}
+    )
+
+    audio, sr = librosa.load(audio_path, sr=16000)
+    duration = len(audio) / sr
+    n_chunks = math.ceil(duration / 30)
+    transcription_txt = ""
+    transcription_chunks = []
+
+    for i in range(n_chunks):
+        start = i * 30 * sr
+        end = min((i + 1) * 30 * sr, len(audio))
+        audio_chunk = audio[start:end]
+
+        # Convert the audio chunk to float32 numpy array
+        audio_chunk = (audio_chunk * 32767).astype(np.float32)
+
+        result = pipe(audio_chunk)
+        transcription_txt += result["text"]
+        for chunk in result["chunks"]:
+            start_time, end_time = chunk["timestamp"]
+            transcription_chunks.append({
+                "start": start_time + i * 30,
+                "end": end_time + i * 30,
+                "text": chunk["text"]
+            })
+
+        print(f"Transcription Progress: {int(((i + 1) / n_chunks) * 100)}%")
+
+    return transcription_txt, transcription_chunks
+
+def create_combined_srt(transcription_chunks, diarization, output_path):
+    speaker_segments = []
+    speaker_map = {}
+    current_speaker_num = 1
+
+    for segment, _, speaker in diarization.itertracks(yield_label=True):
+        if speaker not in speaker_map:
+            speaker_map[speaker] = f"Speaker {current_speaker_num}"
+            current_speaker_num += 1
+        speaker_segments.append((segment.start, segment.end, speaker_map[speaker]))
+
+    with open(output_path, 'w', encoding='utf-8') as srt_file:
+        for i, chunk in enumerate(transcription_chunks, 1):
+            start_time, end_time = chunk["start"], chunk["end"]
+            text = chunk["text"]
+
+            # Find the corresponding speaker
+            current_speaker = "Unknown"
+            for seg_start, seg_end, speaker in speaker_segments:
+                if seg_start <= start_time < seg_end:
+                    current_speaker = speaker
+                    break
+
+            # Format timecodes as h:mm:ss (without leading zeros for hours)
+            start_str = format_timestamp(start_time).split('.')[0].lstrip('0')
+            end_str = format_timestamp(end_time).split('.')[0].lstrip('0')
+
+            srt_file.write(f"{i}\n")
+            srt_file.write(f"{{{current_speaker}}}\n time: ({start_str} --> {end_str})\n text: {text}\n\n")
+
+    # Add dominant speaker information
+    speaker_durations = defaultdict(float)
+    for seg_start, seg_end, speaker in speaker_segments:
+        speaker_durations[speaker] += seg_end - seg_start
+
+    dominant_speaker = max(speaker_durations, key=speaker_durations.get)
+    dominant_duration = speaker_durations[dominant_speaker]
+
+    with open(output_path, 'a', encoding='utf-8') as srt_file:
+        dominant_duration_str = format_timestamp(dominant_duration).split('.')[0].lstrip('0')
+        srt_file.write(f"\nMost dominant speaker: {dominant_speaker} with total duration {dominant_duration_str}\n")
+
+def process_video(video_path, diarization_access_token, language):
+    base_name = os.path.splitext(video_path)[0]
+    audio_path = f"{base_name}.wav"
+    extract_audio(video_path, audio_path)
+
+    # Diarization
+    print("Performing diarization...")
+    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=diarization_access_token)
+    pipeline = pipeline.to(torch.device("cpu"))
+    diarization = pipeline(audio_path)
+    print("Diarization complete.")
+
+    # Transcription
+    print("Performing transcription...")
+    transcription, chunks = transcribe_audio(audio_path, language)
+    print("Transcription complete.")
+
+    # Create combined SRT file
+    combined_srt_path = f"{base_name}_combined.srt"
+    create_combined_srt(chunks, diarization, combined_srt_path)
+    print(f"Combined SRT file created and saved to {combined_srt_path}")
+
+    # Clean up
+    os.remove(audio_path)
+
+if __name__ == "__main__":
+    video_path = r"C:\Users\reab5\Downloads\MediaHuman\Music\test1.mp4"
+    # Get Hugging Face token from Space secret
+    access_token = os.environ.get('hf_secret')
+    if not access_token:
+        raise ValueError("HF_TOKEN not found in environment variables. Please set it in the Space secrets.")
+
+    language = "en"
 process_video(video_path, access_token, language)