LAP-DEV committed on
Commit 74a7e9b · verified · 1 Parent(s): 76b1e23

Delete modules/whisper/whisper_base_old.py

Files changed (1)
  1. modules/whisper/whisper_base_old.py +0 -758
modules/whisper/whisper_base_old.py DELETED
@@ -1,758 +0,0 @@
- import os
- import torch
- import whisper
- import gradio as gr
- import torchaudio
- from abc import ABC, abstractmethod
- from typing import BinaryIO, Union, Tuple, List, Optional
- import numpy as np
- from datetime import datetime
- from faster_whisper.vad import VadOptions
- from dataclasses import astuple
- import gc
- from copy import deepcopy
- from modules.vad.silero_vad import merge_chunks, Segment
- from modules.uvr.music_separator import MusicSeparator
- from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
-                                  UVR_MODELS_DIR)
- from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, get_plaintext, get_csv, write_file, safe_filename
- from modules.utils.youtube_manager import get_ytdata, get_ytaudio
- from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
- from modules.whisper.whisper_parameter import *
- from modules.diarize.diarizer import Diarizer
- from modules.vad.silero_vad import SileroVAD
- from modules.translation.nllb_inference import NLLBInference
- from modules.translation.nllb_inference import NLLB_AVAILABLE_LANGS
- import faster_whisper
-
- class WhisperBase(ABC):
-     def __init__(self,
-                  model_dir: str = WHISPER_MODELS_DIR,
-                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
-                  uvr_model_dir: str = UVR_MODELS_DIR,
-                  output_dir: str = OUTPUT_DIR,
-                  ):
-         self.model_dir = model_dir
-         self.output_dir = output_dir
-         os.makedirs(self.output_dir, exist_ok=True)
-         os.makedirs(self.model_dir, exist_ok=True)
-         self.diarizer = Diarizer(
-             model_dir=diarization_model_dir
-         )
-         self.vad = SileroVAD()
-         self.music_separator = MusicSeparator(
-             model_dir=uvr_model_dir,
-             output_dir=os.path.join(output_dir, "UVR")
-         )
-
-         self.model = None
-         self.current_model_size = None
-         self.available_models = whisper.available_models()
-         self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
-         # self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
-         self.translatable_models = whisper.available_models()
-         self.device = self.get_device()
-         self.available_compute_types = ["float16", "float32"]
-         self.current_compute_type = "float16" if self.device == "cuda" else "float32"
-
-     @abstractmethod
-     def transcribe(self,
-                    audio: Union[str, BinaryIO, np.ndarray],
-                    progress: gr.Progress = gr.Progress(),
-                    *whisper_params,
-                    ):
-         """Run inference with the Whisper model to transcribe audio"""
-         pass
-
-     @abstractmethod
-     def update_model(self,
-                      model_size: str,
-                      compute_type: str,
-                      progress: gr.Progress = gr.Progress()
-                      ):
-         """Initialize the Whisper model"""
-         pass
-
-     def run(self,
-             audio: Union[str, BinaryIO, np.ndarray],
-             progress: gr.Progress = gr.Progress(),
-             add_timestamp: bool = True,
-             *whisper_params,
-             ) -> Tuple[List[dict], float]:
-         """
-         Run transcription with conditional pre-processing and post-processing.
-         The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
-         The diarization will be performed in post-processing, if enabled.
-
-         Parameters
-         ----------
-         audio: Union[str, BinaryIO, np.ndarray]
-             Audio input. This can be a file path or binary type.
-         progress: gr.Progress
-             Indicator to show progress directly in gradio.
-         add_timestamp: bool
-             Whether to add a timestamp at the end of the filename.
-         *whisper_params: tuple
-             Parameters related to Whisper. These are handled by the "WhisperParameters" data class.
-
-         Returns
-         ----------
-         segments_result: List[dict]
-             List of dicts that include start/end timestamps and the transcribed text
-         elapsed_time: float
-             Elapsed time for the run
-         """
-
-         start_time = datetime.now()
-         params = WhisperParameters.as_value(*whisper_params)
-
-         # Get the offload params
-         default_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
-         whisper_params = default_params["whisper"]
-         diarization_params = default_params["diarization"]
-         bool_whisper_enable_offload = whisper_params["enable_offload"]
-         bool_diarization_enable_offload = diarization_params["enable_offload"]
-
-         if params.lang is None:
-             pass
-         elif params.lang == "Automatic Detection":
-             params.lang = None
-         else:
-             language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
-             params.lang = language_code_dict[params.lang]
-
-         if params.is_bgm_separate:
-             music, audio, _ = self.music_separator.separate(
-                 audio=audio,
-                 model_name=params.uvr_model_size,
-                 device=params.uvr_device,
-                 segment_size=params.uvr_segment_size,
-                 save_file=params.uvr_save_file,
-                 progress=progress
-             )
-
-             if audio.ndim >= 2:
-                 audio = audio.mean(axis=1)
-                 if self.music_separator.audio_info is None:
-                     origin_sample_rate = 16000
-                 else:
-                     origin_sample_rate = self.music_separator.audio_info.sample_rate
-                 audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
-
-             if params.uvr_enable_offload:
-                 self.music_separator.offload()
-             elapsed_time_bgm_sep = datetime.now() - start_time
-
-         origin_audio = deepcopy(audio)
-
-         if params.vad_filter:
-             # Explicit value set for float('inf') from gr.Number()
-             if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
-                 params.max_speech_duration_s = float('inf')
-
-             progress(0, desc="Filtering silent parts from audio...")
-             vad_options = VadOptions(
-                 threshold=params.threshold,
-                 min_speech_duration_ms=params.min_speech_duration_ms,
-                 max_speech_duration_s=params.max_speech_duration_s,
-                 min_silence_duration_ms=params.min_silence_duration_ms,
-                 speech_pad_ms=params.speech_pad_ms
-             )
-
-             vad_processed, speech_chunks = self.vad.run(
-                 audio=audio,
-                 vad_parameters=vad_options,
-                 progress=progress
-             )
-
-             try:
-                 if vad_processed.size > 0 and speech_chunks:
-                     if not isinstance(audio, np.ndarray):
-                         loaded_audio = faster_whisper.decode_audio(audio, sampling_rate=self.vad.sampling_rate)
-                     else:
-                         loaded_audio = audio
-                     # Convert speech_chunks to Segment objects and convert samples to seconds
-                     segments = [Segment(start=chunk['start']/self.vad.sampling_rate, end=chunk['end']/self.vad.sampling_rate) for chunk in speech_chunks]
-                     # merge_chunks only works on segments expressed in seconds!
-                     merged_chunks = merge_chunks(segments, chunk_size=300, onset=0.0, offset=None)
-                     all_segments = []
-                     total_elapsed_time = 0.0
-                     for merged in merged_chunks:
-                         chunk_start = merged['start']
-                         chunk_end = merged['end']
-
-                         # To slice the audio, convert chunk_start and chunk_end from seconds to samples by multiplying by the sampling rate.
-                         start_sample = int(chunk_start*self.vad.sampling_rate)
-                         end_sample = int(chunk_end*self.vad.sampling_rate)
-
-                         chunk_audio = loaded_audio[start_sample:end_sample]
-
-                         chunk_result, chunk_time = self.transcribe(
-                             chunk_audio,
-                             progress,
-                             *astuple(params)
-                         )
-                         # Offset timestamps
-                         for seg in chunk_result:
-                             seg['start'] += chunk_start
-                             seg['end'] += chunk_start
-                         all_segments.extend(chunk_result)
-                         total_elapsed_time += chunk_time
-                     result = all_segments
-                     elapsed_time = total_elapsed_time
-                 else:
-                     params.vad_filter = False
-             except Exception as e:
-                 print(f"Error transcribing file: {e}")
-
-         if not params.vad_filter:
-             result, elapsed_time = self.transcribe(
-                 audio,
-                 progress,
-                 *astuple(params)
-             )
-         if bool_whisper_enable_offload:
-             self.offload()
-
-         if params.is_diarize:
-             progress(0.99, desc="Diarizing speakers...")
-             result, elapsed_time_diarization = self.diarizer.run(
-                 audio=origin_audio,
-                 use_auth_token=params.hf_token,
-                 transcribed_result=result,
-                 device=params.diarization_device
-             )
-             if bool_diarization_enable_offload:
-                 self.diarizer.offload()
-
-         if not result:
-             print("Whisper did not detect any speech segments in the audio.")
-             result = list()
-
-         progress(1.0, desc="Processing done!")
-         total_elapsed_time = datetime.now() - start_time
-         return result, elapsed_time
-
-     def transcribe_file(self,
-                         files_audio: Optional[List] = None,
-                         files_video: Optional[List] = None,
-                         files_multi: Optional[List] = None,
-                         input_multi: str = "Audio",
-                         input_folder_path: Optional[str] = None,
-                         file_format: list = ["CSV"],
-                         add_timestamp: bool = True,
-                         translate_output: bool = False,
-                         translate_model: str = "",
-                         target_lang: str = "",
-                         add_timestamp_preview: bool = False,
-                         progress=gr.Progress(),
-                         *whisper_params,
-                         ) -> list:
-         """
-         Write subtitle files from input files
-
-         Parameters
-         ----------
-         files_audio: list
-             List of files to transcribe from gr.Audio()
-         files_video: list
-             List of files to transcribe from gr.Video()
-         files_multi: list
-             List of files to transcribe from gr.Files_multi()
-         input_multi: str
-             Selects which file input to process ("Audio", "Video", or multiple files)
-         input_folder_path: str
-             Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
-             this will be used instead.
-         file_format: list
-             Subtitle file formats to write from gr.Dropdown(). Supported formats: [CSV, SRT, TXT]
-         add_timestamp: bool
-             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
-         translate_output: bool
-             Whether to translate the output
-         translate_model: str
-             Translation model to use
-         target_lang: str
-             Target language to use
-         add_timestamp_preview: bool
-             Boolean value from gr.Checkbox() that determines whether to add timestamps to the output preview
-         progress: gr.Progress
-             Indicator to show progress directly in gradio.
-         *whisper_params: tuple
-             Parameters related to Whisper. These are handled by the "WhisperParameters" data class.
-
-         Returns
-         ----------
-         result_str:
-             Result of transcription to return to gr.Textbox()
-         result_file_path:
-             Output file paths to return to gr.Files()
-         total_info:
-             Summary of the processed files (language, translation notes, processing time) to return to gr.Textbox()
-         """
-
-         try:
-             file_count_total = 0
-             files = ""
-
-             if input_multi == "Audio":
-                 files = files_audio
-             elif input_multi == "Video":
-                 files = files_video
-             else:
-                 files = files_multi
-             file_count_total = len(files)
-
-             if input_folder_path:
-                 files = get_media_files(input_folder_path)
-             if isinstance(files, str):
-                 files = [files]
-             if files and isinstance(files[0], gr.utils.NamedString):
-                 files = [file.name for file in files]
-
-             ## Initialize variables & start time
-             files_info = {}
-             files_to_download = {}
-             time_start = datetime.now()
-
-             ## Load parameters related to whisper
-             params = WhisperParameters.as_value(*whisper_params)
-
-             ## Load model to detect language
-             model = whisper.load_model("base")
-
-             for file in files:
-                 print(file)
-                 ## Detect language
-                 mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(file))).to(model.device)
-                 _, probs = model.detect_language(mel)
-                 file_language = ""
-                 file_lang_probs = ""
-                 for key, value in whisper.tokenizer.LANGUAGES.items():
-                     if key == str(max(probs, key=probs.get)):
-                         file_language = value.capitalize()
-                         for key_prob, value_prob in probs.items():
-                             if key == key_prob:
-                                 file_lang_probs = str((round(value_prob*100)))
-                                 break
-                         break
-                 transcribed_segments, time_for_task = self.run(
-                     file,
-                     progress,
-                     add_timestamp,
-                     *whisper_params,
-                 )
-                 # Define source language
-                 # source_lang = file_language
-                 if params.lang == "Automatic Detection" or (params.lang).strip() == "":
-                     source_lang = file_language
-                 else:
-                     source_lang = ((params.lang).strip()).capitalize()
-
-                 # Translate to English using Whisper built-in functionality
-                 transcription_note = ""
-                 if params.is_translate:
-                     if source_lang != "English":
-                         transcription_note = "To English"
-                         source_lang = "English"
-                     else:
-                         transcription_note = "Already in English"
-
-                 # Translate the transcribed segments
-                 translation_note = ""
-                 if translate_output:
-                     if source_lang != target_lang:
-                         self.nllb_inf = NLLBInference()
-                         if source_lang in NLLB_AVAILABLE_LANGS.keys():
-                             transcribed_segments = self.nllb_inf.translate_text(
-                                 input_list_dict=transcribed_segments,
-                                 model_size=translate_model,
-                                 src_lang=source_lang,
-                                 tgt_lang=target_lang,
-                                 speaker_diarization=params.is_diarize
-                             )
-                             translation_note = "To " + target_lang
-                         else:
-                             translation_note = source_lang + " not supported"
-                     else:
-                         translation_note = "Already in " + target_lang
-
-                 ## Get input filename & extension
-                 file_name, file_ext = os.path.splitext(os.path.basename(file))
-
-                 ## Get output as preview with or without timestamps
-                 if add_timestamp_preview:
-                     subtitle = get_txt(transcribed_segments)
-                 else:
-                     subtitle = get_plaintext(transcribed_segments)
-                 files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "lang": file_language, "lang_prob": file_lang_probs, "input_source_file": (file_name+file_ext), "translation": translation_note, "transcription": transcription_note}
-
-                 ## Add output file as txt, srt and/or csv
-                 for output_format in file_format:
-                     subtitle, file_path = self.generate_and_write_file(
-                         file_name=file_name,
-                         transcribed_segments=transcribed_segments,
-                         add_timestamp=add_timestamp,
-                         file_format=output_format.lower(),
-                         output_dir=self.output_dir
-                     )
-                     files_to_download[file_name+"_"+output_format.lower()] = {"path": file_path}
-
-             total_result = ""
-             total_info = ""
-             total_time = 0
-             file_count = 0
-             for file_name, info in files_info.items():
-
-                 file_count += 1
-
-                 if file_count > 1:
-                     total_info += f'\n'
-
-                 if file_count_total > 1:
-                     if file_count > 1:
-                         total_result += f'\n'
-                     total_result += f'« Transcription of media file \'{info["input_source_file"]}\': »\n\n'
-
-                 total_time += info["time_for_task"]
-                 total_result += f'{info["subtitle"]}'
-                 total_info += f'Media file:\t{info["input_source_file"]}\nLanguage:\t{info["lang"]} (probability {info["lang_prob"]}%)\n'
-
-                 if params.is_translate:
-                     total_info += f'Translation:\t{info["transcription"]}\n\t⤷ Handled by OpenAI Whisper\n'
-
-                 if translate_output:
-                     total_info += f'Translation:\t{info["translation"]}\n\t⤷ Handled by Facebook NLLB\n'
-
-             time_end = datetime.now()
-             # total_info += f"\nTotal processing time:\t{self.format_time((time_end-time_start).total_seconds())}"
-
-             temp_file_count_text = "file"
-             if file_count != 1:
-                 temp_file_count_text += "s"
-             total_info += f"\nProcessed {file_count} {temp_file_count_text} in {self.format_time((time_end-time_start).total_seconds())}"
-
-             result_str = total_result.rstrip("\n")
-             result_file_path = [info['path'] for info in files_to_download.values()]
-
-             return [result_str, result_file_path, total_info]
-
-         except Exception as e:
-             print(f"Error transcribing file: {e}")
-         finally:
-             self.release_cuda_memory()
-
-     def transcribe_mic(self,
-                        mic_audio: str,
-                        file_format: str = "SRT",
-                        add_timestamp: bool = True,
-                        progress=gr.Progress(),
-                        *whisper_params,
-                        ) -> list:
-         """
-         Write subtitle file from microphone
-
-         Parameters
-         ----------
-         mic_audio: str
-             Audio file path from gr.Microphone()
-         file_format: str
-             Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
-         add_timestamp: bool
-             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
-         progress: gr.Progress
-             Indicator to show progress directly in gradio.
-         *whisper_params: tuple
-             Parameters related to Whisper. These are handled by the "WhisperParameters" data class.
-
-         Returns
-         ----------
-         result_str:
-             Result of transcription to return to gr.Textbox()
-         result_file_path:
-             Output file path to return to gr.Files()
-         """
-         try:
-             progress(0, desc="Loading Audio...")
-             transcribed_segments, time_for_task = self.run(
-                 mic_audio,
-                 progress,
-                 add_timestamp,
-                 *whisper_params,
-             )
-             progress(1, desc="Completed!")
-
-             subtitle, result_file_path = self.generate_and_write_file(
-                 file_name="Mic",
-                 transcribed_segments=transcribed_segments,
-                 add_timestamp=add_timestamp,
-                 file_format=file_format,
-                 output_dir=self.output_dir
-             )
-
-             result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
-             return [result_str, result_file_path]
-         except Exception as e:
-             print(f"Error transcribing file: {e}")
-         finally:
-             self.release_cuda_memory()
-
-     def transcribe_youtube(self,
-                            youtube_link: str,
-                            file_format: str = "SRT",
-                            add_timestamp: bool = True,
-                            progress=gr.Progress(),
-                            *whisper_params,
-                            ) -> list:
-         """
-         Write subtitle file from a YouTube video
-
-         Parameters
-         ----------
-         youtube_link: str
-             URL of the YouTube video to transcribe from gr.Textbox()
-         file_format: str
-             Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
-         add_timestamp: bool
-             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
-         progress: gr.Progress
-             Indicator to show progress directly in gradio.
-         *whisper_params: tuple
-             Parameters related to Whisper. These are handled by the "WhisperParameters" data class.
-
-         Returns
-         ----------
-         result_str:
-             Result of transcription to return to gr.Textbox()
-         result_file_path:
-             Output file path to return to gr.Files()
-         """
-         try:
-             progress(0, desc="Loading Audio from Youtube...")
-             yt = get_ytdata(youtube_link)
-             audio = get_ytaudio(yt)
-
-             transcribed_segments, time_for_task = self.run(
-                 audio,
-                 progress,
-                 add_timestamp,
-                 *whisper_params,
-             )
-
-             progress(1, desc="Completed!")
-
-             file_name = safe_filename(yt.title)
-             subtitle, result_file_path = self.generate_and_write_file(
-                 file_name=file_name,
-                 transcribed_segments=transcribed_segments,
-                 add_timestamp=add_timestamp,
-                 file_format=file_format,
-                 output_dir=self.output_dir
-             )
-             result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
-
-             if os.path.exists(audio):
-                 os.remove(audio)
-
-             return [result_str, result_file_path]
-
-         except Exception as e:
-             print(f"Error transcribing file: {e}")
-         finally:
-             self.release_cuda_memory()
-
-     @staticmethod
-     def generate_and_write_file(file_name: str,
-                                 transcribed_segments: list,
-                                 add_timestamp: bool,
-                                 file_format: str,
-                                 output_dir: str
-                                 ) -> Tuple[str, str]:
-         """
-         Writes subtitle file
-
-         Parameters
-         ----------
-         file_name: str
-             Output file name
-         transcribed_segments: list
-             Text segments transcribed from audio
-         add_timestamp: bool
-             Determines whether to add a timestamp to the end of the filename.
-         file_format: str
-             File format to write. Supported formats: [SRT, WebVTT, txt, csv]
-         output_dir: str
-             Directory path of the output
-
-         Returns
-         ----------
-         content: str
-             Result of the transcription
-         output_path: str
-             Output file path
-         """
-         if add_timestamp:
-             # timestamp = datetime.now().strftime("%m%d%H%M%S")
-             timestamp = datetime.now().strftime("%Y%m%d %H%M%S")
-             output_path = os.path.join(output_dir, f"{file_name} - {timestamp}")
-         else:
-             output_path = os.path.join(output_dir, f"{file_name}")
-
-         file_format = file_format.strip().lower()
-         if file_format == "srt":
-             content = get_srt(transcribed_segments)
-             output_path += '.srt'
-
-         elif file_format == "webvtt":
-             content = get_vtt(transcribed_segments)
-             output_path += '.vtt'
-
-         elif file_format == "txt":
-             content = get_txt(transcribed_segments)
-             output_path += '.txt'
-
-         elif file_format == "csv":
-             content = get_csv(transcribed_segments)
-             output_path += '.csv'
-
-         write_file(content, output_path)
-         return content, output_path
-
-     def offload(self):
-         """Offload the model and free up the memory"""
-         if self.model is not None:
-             del self.model
-             self.model = None
-         if self.device == "cuda":
-             self.release_cuda_memory()
-         gc.collect()
-
-     @staticmethod
-     def format_time(elapsed_time: float) -> str:
-         """
-         Get {hours} {minutes} {seconds} time format string
-
-         Parameters
-         ----------
-         elapsed_time: float
-             Elapsed time for transcription, in seconds
-
-         Returns
-         ----------
-         Time format string
-         """
-         hours, rem = divmod(elapsed_time, 3600)
-         minutes, seconds = divmod(rem, 60)
-
-         time_str = ""
-
-         hours = round(hours)
-         if hours:
-             if hours == 1:
-                 time_str += f"{hours} hour "
-             else:
-                 time_str += f"{hours} hours "
-
-         minutes = round(minutes)
-         if minutes:
-             if minutes == 1:
-                 time_str += f"{minutes} minute "
-             else:
-                 time_str += f"{minutes} minutes "
-
-         seconds = round(seconds)
-         if seconds == 1:
-             time_str += f"{seconds} second"
-         else:
-             time_str += f"{seconds} seconds"
-
-         return time_str.strip()
-
-     @staticmethod
-     def get_device():
-         if torch.cuda.is_available():
-             return "cuda"
-         elif torch.backends.mps.is_available():
-             if not WhisperBase.is_sparse_api_supported():
-                 # Device `SparseMPS` is not supported for now. See: https://github.com/pytorch/pytorch/issues/87886
-                 return "cpu"
-             return "mps"
-         else:
-             return "cpu"
-
-     @staticmethod
-     def is_sparse_api_supported():
-         if not torch.backends.mps.is_available():
-             return False
-
-         try:
-             device = torch.device("mps")
-             sparse_tensor = torch.sparse_coo_tensor(
-                 indices=torch.tensor([[0, 1], [2, 3]]),
-                 values=torch.tensor([1, 2]),
-                 size=(4, 4),
-                 device=device
-             )
-             return True
-         except RuntimeError:
-             return False
-
-     @staticmethod
-     def release_cuda_memory():
-         """Release memory"""
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-             torch.cuda.reset_max_memory_allocated()
-
-     @staticmethod
-     def remove_input_files(file_paths: List[str]):
-         """Remove gradio cached files"""
-         if not file_paths:
-             return
-
-         for file_path in file_paths:
-             if file_path and os.path.exists(file_path):
-                 os.remove(file_path)
-
-     @staticmethod
-     def cache_parameters(
-             params: WhisperValues,
-             file_format: str = "SRT",
-             add_timestamp: bool = True
-     ):
-         """Cache parameters to the yaml file"""
-         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
-         param_to_cache = params.to_dict()
-
-         cached_yaml = {**cached_params, **param_to_cache}
-         cached_yaml["whisper"]["add_timestamp"] = add_timestamp
-         cached_yaml["whisper"]["file_format"] = file_format
-
-         suppress_token = cached_yaml["whisper"].get("suppress_tokens", None)
-         if suppress_token and isinstance(suppress_token, list):
-             cached_yaml["whisper"]["suppress_tokens"] = str(suppress_token)
-
-         if cached_yaml["whisper"].get("lang", None) is None:
-             cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
-         else:
-             language_dict = whisper.tokenizer.LANGUAGES
-             cached_yaml["whisper"]["lang"] = language_dict[cached_yaml["whisper"]["lang"]]
-
-         if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'):
-             cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX
-
-         if cached_yaml is not None and cached_yaml:
-             save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
-
-     @staticmethod
-     def resample_audio(audio: Union[str, np.ndarray],
-                        new_sample_rate: int = 16000,
-                        original_sample_rate: Optional[int] = None,) -> np.ndarray:
-         """Resamples audio to a 16 kHz sample rate, the standard for Whisper models"""
-         if isinstance(audio, str):
-             audio, original_sample_rate = torchaudio.load(audio)
-         else:
-             if original_sample_rate is None:
-                 raise ValueError("original_sample_rate must be provided when audio is a numpy array.")
-             audio = torch.from_numpy(audio)
-         resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate)
-         resampled_audio = resampler(audio).numpy()
-         return resampled_audio