reab5555 committed
Commit b9f27c7 · verified · 1 Parent(s): 55553bb

Update diarization.py

Files changed (1)
  1. diarization.py +24 -4
diarization.py CHANGED
@@ -43,9 +43,13 @@ class LazyTranscriptionPipeline:
             model=self.model,
             tokenizer=self.processor.tokenizer,
             feature_extractor=self.processor.feature_extractor,
+            max_new_tokens=128,
             chunk_length_s=30,
+            batch_size=1,
             return_timestamps=True,
-            device=torch.device("cuda")
+            torch_dtype=torch.float16,
+            device=torch.device("cuda"),
+            generate_kwargs={"language": language}
         )
         return self.pipe
 
@@ -74,6 +78,9 @@ def transcribe_audio(audio_path, language):
         end = min((i + 1) * 30 * sr, len(audio))
         audio_chunk = audio[start:end]
 
+        # Convert the audio chunk to float32 numpy array
+        audio_chunk = (audio_chunk * 32767).astype(np.float32)
+
         result = pipe(audio_chunk)
         transcription_txt += result["text"]
         for chunk in result["chunks"]:
@@ -84,6 +91,8 @@ def transcribe_audio(audio_path, language):
                 "text": chunk["text"]
             })
 
+        print(f"Transcription Progress: {int(((i + 1) / n_chunks) * 100)}%")
+
     return transcription_txt, transcription_chunks
 
 def create_combined_srt(transcription_chunks, diarization, output_path):
@@ -102,19 +111,21 @@ def create_combined_srt(transcription_chunks, diarization, output_path):
             start_time, end_time = chunk["start"], chunk["end"]
             text = chunk["text"]
 
+            # Find the corresponding speaker
             current_speaker = "Unknown"
             for seg_start, seg_end, speaker in speaker_segments:
                 if seg_start <= start_time < seg_end:
                     current_speaker = speaker
                     break
 
+            # Format timecodes as h:mm:ss (without leading zeros for hours)
             start_str = format_timestamp(start_time).split('.')[0].lstrip('0')
             end_str = format_timestamp(end_time).split('.')[0].lstrip('0')
 
             srt_file.write(f"{i}\n")
-            srt_file.write(f"{start_str} --> {end_str}\n")
-            srt_file.write(f"{current_speaker}: {text}\n\n")
+            srt_file.write(f"{{{current_speaker}}}\n time: ({start_str} --> {end_str})\n text: {text}\n\n")
 
+        # Add dominant speaker information
         speaker_durations = defaultdict(float)
         for seg_start, seg_end, speaker in speaker_segments:
             speaker_durations[speaker] += seg_end - seg_start
@@ -132,14 +143,23 @@ def process_video(video_path, diarization_access_token, language):
     audio_path = f"{base_name}.wav"
     extract_audio(video_path, audio_path)
 
+    # Diarization
+    print("Performing diarization...")
     pipeline = lazy_diarization_pipeline.get_pipeline(diarization_access_token)
     diarization = pipeline(audio_path)
+    print("Diarization complete.")
 
+    # Transcription
+    print("Performing transcription...")
     transcription, chunks = transcribe_audio(audio_path, language)
+    print("Transcription complete.")
 
+    # Create combined SRT file
     combined_srt_path = f"{base_name}_combined.srt"
     create_combined_srt(chunks, diarization, combined_srt_path)
+    print(f"Combined SRT file created and saved to {combined_srt_path}")
 
+    # Clean up
     os.remove(audio_path)
 
-    return combined_srt_path
+    return combined_srt_path
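For readers who want to reproduce the reconfigured ASR pipeline from the first hunk outside this repo, here is a minimal, self-contained sketch. Only the keyword arguments come from this commit; the Whisper checkpoint id, the processor/model loading, and the hard-coded language value are illustrative assumptions.

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Assumption: the commit does not name a checkpoint; any Whisper model id works here.
model_id = "openai/whisper-large-v3"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch.float16)

language = "en"  # assumption: in the repo this arrives from the pipeline's caller

asr_pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,          # cap tokens generated per chunk
    chunk_length_s=30,           # Whisper's native 30 s window
    batch_size=1,                # one chunk at a time keeps GPU memory low
    return_timestamps=True,      # timestamps are needed to align with diarization
    torch_dtype=torch.float16,   # half precision, matching the loaded weights
    device=torch.device("cuda"),
    generate_kwargs={"language": language},  # pin the decoding language
)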
 
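And a short usage sketch of the module's entry point as it stands after this commit. `process_video` and its signature come from the diff above; the file name, token, and language code below are placeholders.

from diarization import process_video  # assumes diarization.py is on the import path

# Placeholder arguments: an input video, an access token for the gated
# diarization pipeline, and the language code forwarded to Whisper.
srt_path = process_video("interview.mp4", "hf_xxxxxxxx", "en")
print(f"Combined subtitle file: {srt_path}")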