Michael Natanael committed
Commit 268f7eb · 1 Parent(s): 7c09bf0

change whisper_open_ai to faster_whisper

Files changed (2):
  1. app.py +56 -80
  2. requirements.txt +4 -3
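
This commit replaces the transformers Whisper pipeline with faster-whisper, a CTranslate2 reimplementation that runs the same "large-v3" weights with INT8 quantization on CPU, trading the GPU pipeline's batched throughput for a much smaller memory footprint. A minimal sketch of the API the commit adopts (the audio path is a placeholder; note that transcribe() returns a lazy generator of segments, so the text only exists once the generator is consumed):

    from faster_whisper import WhisperModel

    # CPU + INT8, the same configuration the commit ships.
    model = WhisperModel("large-v3", device="cpu", compute_type="int8")

    # transcribe() is lazy: decoding happens while iterating the generator.
    segments, info = model.transcribe("audio.mp3", language="id", beam_size=1)
    print(info.language, info.language_probability)
    text = " ".join(segment.text.strip() for segment in segments)
    print(text)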
app.py CHANGED
@@ -1,5 +1,5 @@
 from flask import Flask, render_template, request
-# import whisper
+from faster_whisper import WhisperModel
 import tempfile
 import os
 import time
@@ -7,7 +7,7 @@ import torch
 import numpy as np
 import requests
 from tqdm import tqdm
-from transformers import BertTokenizer, AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+from transformers import BertTokenizer
 from model.multi_class_model import MultiClassModel  # Adjust if needed
 
 app = Flask(__name__)
@@ -49,38 +49,58 @@ model = MultiClassModel.load_from_checkpoint(
 )
 model.eval()
 
-# === INITIAL SETUP: Whisper Pipeline ===
-# https://huggingface.co/openai/whisper-large-v3
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-
-model_id = "openai/whisper-large-v3"
-
-whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
-    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
-)
-whisper_model.to(device)
-
-processor = AutoProcessor.from_pretrained(model_id)
-
-pipe = pipeline(
-    "automatic-speech-recognition",
-    model=whisper_model,
-    tokenizer=processor.tokenizer,
-    feature_extractor=processor.feature_extractor,
-    chunk_length_s=30,
-    batch_size=128,  # batch size for inference - set based on your device
-    torch_dtype=torch_dtype,
-    device=device,
-    max_new_tokens=128,  # Limit text generation
-    return_timestamps=False,  # Save memory
-)
-
-
-def whisper_api(temp_audio_path):
-    result = pipe(temp_audio_path, generate_kwargs={"language": "indonesian", "task": "transcribe"})
-    print(result["text"])
-    return result
+# === INITIAL SETUP: Faster Whisper ===
+# https://github.com/SYSTRAN/faster-whisper
+faster_whisper_model_size = "large-v3"
+
+# Run on GPU with FP16
+# faster_whisper_model = WhisperModel(faster_whisper_model_size, device="cuda", compute_type="float16")
+# or run on GPU with INT8
+# faster_whisper_model = WhisperModel(faster_whisper_model_size, device="cuda", compute_type="int8_float16")
+# or run on CPU with INT8
+faster_whisper_model = WhisperModel(faster_whisper_model_size, device="cpu", compute_type="int8")
+
+
+def faster_whisper(temp_audio_path):
+    segments, info = faster_whisper_model.transcribe(
+        temp_audio_path,
+        language="id",
+        beam_size=1  # Lower beam_size is faster but may miss words
+    )
+
+    print("Detected language '%s' with probability %f" % (info.language, info.language_probability))
+
+    # segments is a lazy generator; collect every chunk so the full
+    # transcription is returned, not just the final segment.
+    texts = []
+    for segment in segments:
+        print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
+        texts.append(segment.text.strip())
+
+    return " ".join(texts)
+
+
+def bert_predict(input_lyric):
+    encoding = tokenizer.encode_plus(
+        input_lyric,
+        add_special_tokens=True,
+        max_length=512,
+        return_token_type_ids=True,
+        padding="max_length",
+        return_attention_mask=True,
+        return_tensors='pt',
+    )
+
+    with torch.no_grad():
+        prediction = model(
+            encoding["input_ids"],
+            encoding["attention_mask"],
+            encoding["token_type_ids"]
+        )
+
+    logits = prediction
+    probabilities = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
+    predicted_class = np.argmax(probabilities)
+    predicted_label = AGE_LABELS[predicted_class]
+
+    prob_results = [(label, f"{prob:.4f}") for label, prob in zip(AGE_LABELS, probabilities)]
+    return predicted_label, prob_results
 
 
 # === ROUTES ===
@@ -108,35 +128,11 @@ def transcribe():
        temp_audio_path = temp_audio.name
 
    # Step 1: Transcribe
-   # transcription = whisper_model.transcribe(temp_audio_path, language="id")
-   transcription = whisper_api(temp_audio_path)
+   transcribed_text = faster_whisper(temp_audio_path)
    os.remove(temp_audio_path)
-   transcribed_text = transcription["text"]
 
    # Step 2: BERT Prediction
-   encoding = tokenizer.encode_plus(
-       transcribed_text,
-       add_special_tokens=True,
-       max_length=512,
-       return_token_type_ids=True,
-       padding="max_length",
-       return_attention_mask=True,
-       return_tensors='pt',
-   )
-
-   with torch.no_grad():
-       prediction = model(
-           encoding["input_ids"],
-           encoding["attention_mask"],
-           encoding["token_type_ids"]
-       )
-
-   logits = prediction
-   probabilities = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
-   predicted_class = np.argmax(probabilities)
-   predicted_label = AGE_LABELS[predicted_class]
-
-   prob_results = [(label, f"{prob:.4f}") for label, prob in zip(AGE_LABELS, probabilities)]
+   predicted_label, prob_results = bert_predict(transcribed_text)
 
    # Stop timer
    end_time = time.time()
@@ -167,28 +163,8 @@ def predict_text():
    # Start timer
    start_time = time.time()
 
-   encoding = tokenizer.encode_plus(
-       user_lyrics,
-       add_special_tokens=True,
-       max_length=512,
-       return_token_type_ids=True,
-       padding="max_length",
-       return_attention_mask=True,
-       return_tensors='pt',
-   )
-
-   with torch.no_grad():
-       prediction = model(
-           encoding["input_ids"],
-           encoding["attention_mask"],
-           encoding["token_type_ids"]
-       )
-
-   logits = prediction
-   probabilities = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
-   predicted_class = np.argmax(probabilities)
-   predicted_label = AGE_LABELS[predicted_class]
-   prob_results = [(label, f"{prob:.4f}") for label, prob in zip(AGE_LABELS, probabilities)]
+   # Step 1: BERT Prediction
+   predicted_label, prob_results = bert_predict(user_lyrics)
 
    # End timer
    end_time = time.time()
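
The new setup hard-codes the CPU/INT8 variant and leaves the GPU configurations as comments. A hedged sketch of a runtime switch, assuming CUDA plus the cuBLAS/cuDNN libraries that CTranslate2 needs are present whenever a GPU is detected:

    import torch
    from faster_whisper import WhisperModel

    # Prefer FP16 on GPU when available; otherwise fall back to INT8 on CPU,
    # which is the configuration the commit ships.
    if torch.cuda.is_available():
        faster_whisper_model = WhisperModel("large-v3", device="cuda", compute_type="float16")
    else:
        faster_whisper_model = WhisperModel("large-v3", device="cpu", compute_type="int8")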
requirements.txt CHANGED
@@ -7,12 +7,13 @@ Jinja2==2.11.3
 MarkupSafe==1.1.1
 SQLAlchemy==1.3.22
 Werkzeug==1.0.1
-openai-whisper
-setuptools-rust
+faster_whisper
+# openai-whisper
+# setuptools-rust
 # ffmpeg
 # ffmpeg-python
 # imageio[ffmpeg]
-accelerate
+# accelerate
 pytorch-lightning==2.2.1
 lightning==2.4.0
 torch==2.2.0
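
pip normalizes package names, so the underscored faster_whisper entry resolves to the PyPI project faster-whisper; since the entry is unpinned, installs will float to the latest release. A quick post-install smoke test (a sketch, run after pip install -r requirements.txt):

    from importlib.metadata import version

    # Confirms the dependency resolved and reports which release was installed.
    print(version("faster-whisper"))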