sanchit-gandhi committed
Commit 1923ff8 · 1 Parent(s): c220da3

back to batched (b2b)

Files changed (1)
app.py +23 -9
app.py CHANGED
@@ -24,7 +24,6 @@ language_names = sorted(TO_LANGUAGE_CODE.keys())
 CHUNK_LENGTH_S = 30
 BATCH_SIZE = 16
 NUM_PROC = 8
-SAMPLING_RATE = 16000
 FILE_LIMIT_MB = 1000
 
 
@@ -71,7 +70,10 @@ def forward(batch, task=None, return_timestamps=False):
 
 
 if __name__ == "__main__":
-    def transcribe_audio(microphone, file_upload, task, return_timestamps):
+    processor = WhisperPrePostProcessor.from_pretrained("openai/whisper-large-v2")
+    pool = Pool(NUM_PROC)
+
+    def transcribe_chunked_audio(microphone, file_upload, task, return_timestamps):
         warn_output = ""
         if (microphone is not None) and (file_upload is not None):
             warn_output = (
@@ -80,19 +82,31 @@ if __name__ == "__main__":
             )
 
         elif (microphone is None) and (file_upload is None):
-            return "ERROR: You have to either use the microphone or upload an audio file"
+            return "ERROR: You have to either use the microphone or upload an audio file", None
 
         inputs = microphone if microphone is not None else file_upload
 
+        file_size_mb = os.stat(inputs).st_size / (1024 * 1024)
+        if file_size_mb > FILE_LIMIT_MB:
+            return f"ERROR: File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB.", None
+
         with open(inputs, "rb") as f:
             inputs = f.read()
 
-        inputs = ffmpeg_read(inputs, SAMPLING_RATE)
-        inputs = {"array": base64.b64encode(inputs.tobytes()).decode(), "sampling_rate": SAMPLING_RATE}
+        inputs = ffmpeg_read(inputs, processor.feature_extractor.sampling_rate)
+        inputs = {"array": inputs, "sampling_rate": processor.feature_extractor.sampling_rate}
+
+        dataloader = processor.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
 
-        text, timestamps = inference(inputs=inputs, task=task, return_timestamps=return_timestamps)
+        try:
+            model_outputs = pool.map(partial(forward, task=task, return_timestamps=return_timestamps), dataloader)
+        except ValueError as err:
+            # pre-processor does all the necessary compatibility checks for our audio inputs
+            return err, None
 
-        return warn_output + text, timestamps
+        post_processed = processor.postprocess(model_outputs, return_timestamps=return_timestamps)
+        timestamps = post_processed.get("chunks")
+        return warn_output + post_processed["text"], timestamps
 
     def _return_yt_html_embed(yt_url):
         video_id = yt_url.split("?v=")[-1]
@@ -110,7 +124,7 @@ if __name__ == "__main__":
         return html_embed_str, text, timestamps
 
     audio_chunked = gr.Interface(
-        fn=transcribe_audio,
+        fn=transcribe_chunked_audio,
        inputs=[
            gr.inputs.Audio(source="microphone", optional=True, type="filepath"),
            gr.inputs.Audio(source="upload", optional=True, type="filepath"),
@@ -152,5 +166,5 @@ if __name__ == "__main__":
     with demo:
         gr.TabbedInterface([audio_chunked, youtube], ["Transcribe Audio", "Transcribe YouTube"])
 
-    demo.queue(concurrency_count=5, max_size=10)
+    demo.queue(concurrency_count=3, max_size=10)
     demo.launch()
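
Note on the change: instead of sending the whole file through a single inference call, the new transcribe_chunked_audio cuts the audio into CHUNK_LENGTH_S-second chunks, groups the chunks into batches of BATCH_SIZE, maps forward over the batches in parallel via the Pool, and lets processor.postprocess stitch the per-batch outputs back into one transcription with optional timestamp chunks. The snippet below is only a toy sketch of the chunk-and-batch step on a raw NumPy array; it is not the WhisperPrePostProcessor API (which also performs feature extraction and uses overlapping strides between chunks), and SAMPLING_RATE here merely stands in for processor.feature_extractor.sampling_rate.

import numpy as np

# Toy sketch of the chunk-and-batch strategy behind transcribe_chunked_audio.
# NOT the WhisperPrePostProcessor API: it simply cuts a waveform into
# CHUNK_LENGTH_S-second pieces and groups them into batches of BATCH_SIZE,
# with no overlapping stride between chunks.
CHUNK_LENGTH_S = 30
BATCH_SIZE = 16
SAMPLING_RATE = 16000  # stand-in for processor.feature_extractor.sampling_rate


def chunk_and_batch(array, sampling_rate=SAMPLING_RATE):
    chunk_len = CHUNK_LENGTH_S * sampling_rate
    chunks = [array[i : i + chunk_len] for i in range(0, len(array), chunk_len)]
    for start in range(0, len(chunks), BATCH_SIZE):
        yield chunks[start : start + BATCH_SIZE]


audio = np.zeros(10 * 60 * SAMPLING_RATE, dtype=np.float32)  # 10 minutes of audio
batches = list(chunk_and_batch(audio))
print([len(batch) for batch in batches])  # [16, 4] -> two forward passes

In the app itself this step is delegated to processor.preprocess_batch, and the dataloader it returns is what pool.map iterates over.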
 
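Pool.map hands each worker exactly one positional argument (a batch from the dataloader), which is why the fixed keyword arguments task and return_timestamps are bound up front with functools.partial. Below is a minimal, self-contained sketch of that pattern, assuming a standard multiprocessing.Pool (as suggested by Pool(NUM_PROC) in the diff) and using a stand-in for the app's forward function.

from functools import partial
from multiprocessing import Pool

NUM_PROC = 8


def forward(batch, task=None, return_timestamps=False):
    # Stand-in for the app's forward(): just echo what each worker received.
    return {"batch": batch, "task": task, "return_timestamps": return_timestamps}


if __name__ == "__main__":
    dataloader = [[1, 2], [3, 4], [5, 6]]  # pretend batches from preprocess_batch
    with Pool(NUM_PROC) as pool:
        # partial() fixes task/return_timestamps so pool.map only supplies the batch
        model_outputs = pool.map(
            partial(forward, task="transcribe", return_timestamps=True), dataloader
        )
    print(model_outputs)

In the app the pool is created once at start-up and reused across requests, so the Gradio queue settings (concurrency_count=3, max_size=10) are what effectively limit how many requests contend for it at a time.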