camparchimedes committed on
Commit
329c8dd
·
verified ·
1 Parent(s): 65ac4ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -3
app.py CHANGED
@@ -16,29 +16,45 @@ def convert_to_wav(audio_file):
16
  return wav_file
17
 
18
  import torch
19
- from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
21
  pipe = pipeline("automatic-speech-recognition", model="NbAiLab/nb-whisper-large", device=device)
22
 
23
  def transcribe_audio(audio_file):
 
24
  if audio_file.endswith(".m4a"):
25
  audio_file = convert_to_wav(audio_file)
26
 
27
  start_time = time.time()
28
 
29
- # with torch.no_grad():
30
  output = pipe(audio_file)
31
 
 
32
  text = output["text"]
33
  end_time = time.time()
34
  output_time = end_time - start_time
35
  word_count = len(text.split())
36
 
 
37
  result = f"Time taken: {output_time:.2f} seconds\nNumber of words: {word_count}"
38
 
39
  return text, result
40
 
41
-
42
  import nltk
43
  from nltk.tokenize import word_tokenize, sent_tokenize
44
  from nltk.corpus import stopwords
 
16
  return wav_file
17
 
18
  import torch
19
+ from transformers import pipeline, WhisperForConditionalGeneration
20
+
21
+ # prepare decoder input IDs for generation
22
def prepare_decoder_input_ids_for_generation_patch(
    self,
    batch_size,
    model_input_name,
    model_kwargs,
    decoder_start_token_id,
    bos_token_id=None,
    device=None,
):
    """Build the initial decoder input IDs for generation.

    Intended as a replacement for a model's
    ``_prepare_decoder_input_ids_for_generation`` method.

    Args:
        self: The model instance the method is bound to.
        batch_size: Number of sequences in the batch.
        model_input_name: Accepted for signature compatibility; unused.
        model_kwargs: Generation keyword arguments. If it contains
            ``decoder_input_ids``, that entry is removed and returned.
        decoder_start_token_id: Token ID used to fill the start column.
        bos_token_id: Accepted for signature compatibility; unused.
            Optional because callers may not pass it
            (NOTE(review): confirm against the installed transformers).
        device: Device for the new tensor. Falls back to ``self.device``
            when omitted.

    Returns:
        A ``(decoder_input_ids, model_kwargs)`` tuple.
    """
    # Caller-supplied decoder inputs win; pop them so they are not passed twice.
    if "decoder_input_ids" in model_kwargs:
        return model_kwargs.pop("decoder_input_ids"), model_kwargs

    if device is None:
        # getattr keeps this safe when ``self`` has no ``device`` attribute;
        # torch then uses its default device.
        device = getattr(self, "device", None)

    # One start token per sequence, shape (batch_size, 1).
    start_ids = (
        torch.ones((batch_size, 1), dtype=torch.long, device=device)
        * decoder_start_token_id
    )
    return start_ids, model_kwargs
27
+
28
# Install the custom decoder-input preparation on the Whisper model class so
# every instance created afterwards uses it.
setattr(
    WhisperForConditionalGeneration,
    "_prepare_decoder_input_ids_for_generation",
    prepare_decoder_input_ids_for_generation_patch,
)

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Speech-recognition pipeline built from the NbAiLab/nb-whisper-large
# checkpoint, created once at import time.
pipe = pipeline(
    "automatic-speech-recognition",
    model="NbAiLab/nb-whisper-large",
    device=device,
)
36
 
37
  def transcribe_audio(audio_file):
38
+
39
  if audio_file.endswith(".m4a"):
40
  audio_file = convert_to_wav(audio_file)
41
 
42
  start_time = time.time()
43
 
44
+ # transcribe
45
  output = pipe(audio_file)
46
 
47
+ # get text
48
  text = output["text"]
49
  end_time = time.time()
50
  output_time = end_time - start_time
51
  word_count = len(text.split())
52
 
53
+ # summary
54
  result = f"Time taken: {output_time:.2f} seconds\nNumber of words: {word_count}"
55
 
56
  return text, result
57
 
 
58
  import nltk
59
  from nltk.tokenize import word_tokenize, sent_tokenize
60
  from nltk.corpus import stopwords