ssolito committed
Commit 0373cec · verified · 1 Parent(s): 038bca7

Update whisper_cs.py

Files changed (1)
  1. whisper_cs.py +38 -102
whisper_cs.py CHANGED
@@ -4,17 +4,15 @@ import os
 import torchaudio
 import torch
 import re
-from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor, GenerationConfig
-from pyannote.audio import Pipeline as DiarizationPipeline
-import whisperx
 import whisper_timestamped as whisper_ts
 from typing import Dict
+from faster_whisper import WhisperModel
 
 device = 0 if torch.cuda.is_available() else "cpu"
 torch_dtype = torch.float32
 
-MODEL_PATH_1 = "projecte-aina/whisper-large-v3-tiny-caesar"
-MODEL_PATH_2 = "langtech-veu/whisper-timestamped-cs"
+MODEL_PATH_V2 = "langtech-veu/whisper-timestamped-cs"
+MODEL_PATH_V2_FAST = "langtech-veu/faster-whisper-timestamped-cs"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 def clean_text(input_text):
@@ -42,19 +40,6 @@ def split_stereo_channels(audio_path):
     channels[1].export(f"temp_mono_speaker2.wav", format="wav") # Left
 
 
-def convert_to_mono(input_path):
-    audio = AudioSegment.from_file(input_path)
-    base, ext = os.path.splitext(input_path)
-    output_path = f"{base}_merged.wav"
-    print('output_path',output_path)
-    mono = audio.set_channels(1)
-    mono.export(output_path, format="wav")
-    return output_path
-
-def save_temp_audio(waveform, sample_rate, path):
-    waveform = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform
-    torchaudio.save(path, waveform, sample_rate)
-
 def format_audio(audio_path):
     input_audio, sample_rate = torchaudio.load(audio_path)
     if input_audio.shape[0] == 2:
@@ -63,52 +48,6 @@ def format_audio(audio_path):
         input_audio = resampler(input_audio)
         print('resampled')
     return input_audio.squeeze(), 16000
-
-def assign_timestamps(asr_segments, audio_path):
-    waveform, sr = format_audio(audio_path)
-    total_duration = waveform.shape[-1] / sr
-
-    total_words = sum(len(seg["text"].split()) for seg in asr_segments)
-    if total_words == 0:
-        raise ValueError("Total number of words in ASR segments is zero. Cannot assign timestamps.")
-
-    avg_word_duration = total_duration / total_words
-
-    current_time = 0.0
-    for segment in asr_segments:
-        word_count = len(segment["text"].split())
-        segment_duration = word_count * avg_word_duration
-        segment["start"] = round(current_time, 3)
-        segment["end"] = round(current_time + segment_duration, 3)
-        current_time += segment_duration
-
-    return asr_segments
-
-def hf_chunks_to_whisperx_segments(chunks):
-    return [
-        {
-            "text": chunk["text"],
-            "start": chunk["timestamp"][0],
-            "end": chunk["timestamp"][1],
-        }
-        for chunk in chunks
-        if chunk["timestamp"] and isinstance(chunk["timestamp"], (list, tuple))
-    ]
-
-def align_words_to_segments(words, segments, window=5.0):
-    aligned = []
-    seg_idx = 0
-    for word in words:
-        while seg_idx < len(segments) and segments[seg_idx]["end"] < word["start"] - window:
-            seg_idx += 1
-        for j in range(seg_idx, len(segments)):
-            seg = segments[j]
-            if seg["start"] > word["end"] + window:
-                break
-            if seg["start"] <= word["start"] < seg["end"]:
-                aligned.append((word, seg))
-                break
-    return aligned
 
 def post_process_transcription(transcription, max_repeats=2):
     tokens = re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription)
@@ -166,7 +105,7 @@ def cleanup_temp_files(*file_paths):
         if path and os.path.exists(path):
             os.remove(path)
 
-
+faster_model = WhisperModel(MODEL_PATH_V2_FAST, device=DEVICE, compute_type="int8")
 
 def load_whisper_model(model_path: str):
     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -241,33 +180,47 @@ def asr(audio_path):
     asr_segments = assign_timestamps(asr_segments, audio_path)
     return asr_segments
 
-def align_asr_to_diarization(asr_segments, diarized_segments, audio_path):
-    waveform, sample_rate = format_audio(audio_path)
-
-    word_segments = whisperx.align(asr_segments, align_model, metadata, waveform, DEVICE)
-    words = word_segments['word_segments']
-
-    diarized = [{"start": segment.start,"end": segment.end,"speaker": speaker} for segment, _, speaker in diarized_segments]
-
-    aligned_pairs = align_words_to_segments(words, diarized)
-
-    output = []
-    segment_map = {}
-    for word, segment in aligned_pairs:
-        key = (segment["start"], segment["end"], segment["speaker"])
-        if key not in segment_map:
-            segment_map[key] = []
-        segment_map[key].append(word["word"])
-
-    for (start, end, speaker), words in sorted(segment_map.items()):
-        output.append(f"[{speaker}] {' '.join(words)}")
-
-    return output
-
-def generate(audio_path, use_v2):
-
-    if use_v2:
-        model = load_whisper_model(MODEL_PATH_2)
+
+def generate(audio_path, use_v2_fast):
+
+    if use_v2_fast:
+        left_channel_path = "temp_mono_speaker2.wav"
+        right_channel_path = "temp_mono_speaker1.wav"
+
+        left_waveform, left_sr = format_audio(left_channel_path)
+        right_waveform, right_sr = format_audio(right_channel_path)
+
+        left_waveform = left_waveform.numpy().astype("float32")
+        right_waveform = right_waveform.numpy().astype("float32")
+
+        left_result, info = faster_model.transcribe(left_waveform, beam_size=5, task="transcribe")
+        right_result, info = faster_model.transcribe(right_waveform, beam_size=5, task="transcribe")
+
+        left_result = list(left_result)
+        right_result = list(right_result)
+
+        def get_faster_segments(segments, speaker_label):
+            return [
+                (seg.start, seg.end, speaker_label, post_process_transcription(seg.text.strip()))
+                for seg in segments if seg.text
+            ]
+
+        left_segs = get_faster_segments(left_result, "Speaker 1")
+        right_segs = get_faster_segments(right_result, "Speaker 2")
+
+        merged_transcript = sorted(
+            left_segs + right_segs,
+            key=lambda x: float(x[0]) if x[0] is not None else float("inf")
+        )
+
+        clean_output = ""
+        for start, end, speaker, text in merged_transcript:
+            clean_output += f"[{speaker}]: {text}\n"
+        clean_output = post_merge_consecutive_segments_from_text(clean_output)
+
+
+    else:
+        model = load_whisper_model(MODEL_PATH_V2)
         split_stereo_channels(audio_path)
 
         left_channel_path = "temp_mono_speaker2.wav"
@@ -300,23 +253,6 @@ def generate(audio_path, use_v2):
             output += f"[{speaker}]: {text}\n"
 
         clean_output = output.strip()
-
-    else:
-        mono_audio_path = convert_to_mono(audio_path)
-        waveform, sr = format_audio(mono_audio_path)
-        tmp_full_path = "tmp_full.wav"
-        save_temp_audio(waveform, sr, tmp_full_path)
-        diarized_segments = diarization(tmp_full_path)
-        asr_segments = asr(tmp_full_path)
-        for segment in asr_segments:
-            segment["text"] = post_process_transcription(segment["text"])
-        aligned_text = align_asr_to_diarization(asr_segments, diarized_segments, tmp_full_path)
-
-        clean_output = ""
-        for line in aligned_text:
-            clean_output += f"{line}\n"
-        clean_output = post_merge_consecutive_segments_from_text(clean_output)
-        cleanup_temp_files(mono_audio_path,tmp_full_path)
 
     cleanup_temp_files(
         "temp_mono_speaker1.wav",