Support progress for multiple devices

Files changed:
- app.py +2 -2
- src/vad.py +78 -66
- src/vadParallel.py +50 -8
app.py CHANGED

@@ -279,7 +279,6 @@ class WhisperTranscriber:
             # No parallel devices, so just run the VAD and Whisper in sequence
             return vadModel.transcribe(audio_path, whisperCallable, vadConfig, progressListener=progressListener)
 
-        # TODO: Handle progress listener
         gpu_devices = self.parallel_device_list
 
         if (gpu_devices is None or len(gpu_devices) == 0):
@@ -297,7 +296,8 @@ class WhisperTranscriber:
         parallel_vad = ParallelTranscription()
         return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
                                                 config=vadConfig, cpu_device_count=self.vad_cpu_cores, gpu_devices=gpu_devices,
-                                                cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context)
+                                                cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context,
+                                                progress_listener=progressListener)
 
     def _has_parallel_devices(self):
         return (self.parallel_device_list is not None and len(self.parallel_device_list) > 0) or self.vad_cpu_cores > 1
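Note: the sequential path (line 280) already forwarded progressListener, so this change brings the multi-device path to parity and retires the TODO. Any ProgressListener implementation can be passed in; below is a minimal sketch, assuming only the two hooks this diff actually calls (on_progress(current, total) and on_finished()) and using a hypothetical listener name that is not part of this commit:

    from src.hooks.whisperProgressHook import ProgressListener

    class PrintingProgressListener(ProgressListener):
        # Hypothetical example listener for illustration only.
        def on_progress(self, current, total):
            # current and total are measured in seconds of audio processed
            print(f"Progress: {current / total * 100:.1f}%")

        def on_finished(self):
            print("Transcription finished.")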
src/vad.py CHANGED

@@ -153,84 +153,96 @@ class AbstractTranscription(ABC):
         A list of start and end timestamps, in fractional seconds.
         """
 
-        max_audio_duration = get_audio_duration(audio)
-        timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)
+        try:
+            max_audio_duration = self.get_audio_duration(audio, config)
+            timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)
 
-        # Get speech timestamps from full audio file
-        merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)
+            # Get speech timestamps from full audio file
+            merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)
 
-        # A deque of transcribed segments that is passed to the next segment as a prompt
-        prompt_window = deque()
+            # A deque of transcribed segments that is passed to the next segment as a prompt
+            prompt_window = deque()
 
-        print("Processing timestamps:")
-        pprint(merged)
+            print("Processing timestamps:")
+            pprint(merged)
 
-        result = {
-            'text': "",
-            'segments': [],
-            'language': ""
-        }
-        languageCounter = Counter()
-        detected_language = None
+            result = {
+                'text': "",
+                'segments': [],
+                'language': ""
+            }
+            languageCounter = Counter()
+            detected_language = None
 
-        segment_index = config.initial_segment_index
+            segment_index = config.initial_segment_index
 
-        # For each time segment, run whisper
-        for segment in merged:
-            segment_index += 1
-            segment_start = segment['start']
-            segment_end = segment['end']
-            segment_expand_amount = segment.get('expand_amount', 0)
-            segment_gap = segment.get('gap', False)
+            # Calculate progress
+            progress_start_offset = merged[0]['start'] if len(merged) > 0 else 0
+            progress_total_duration = sum([segment['end'] - segment['start'] for segment in merged])
+
+            # For each time segment, run whisper
+            for segment in merged:
+                segment_index += 1
+                segment_start = segment['start']
+                segment_end = segment['end']
+                segment_expand_amount = segment.get('expand_amount', 0)
+                segment_gap = segment.get('gap', False)
 
-            segment_duration = segment_end - segment_start
+                segment_duration = segment_end - segment_start
 
-            if segment_duration < MIN_SEGMENT_DURATION:
-                continue
+                if segment_duration < MIN_SEGMENT_DURATION:
+                    continue
 
-            # Audio to run on Whisper
-            segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
-            # Previous segments to use as a prompt
-            segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
-
-            # Detected language
-            detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
-
-            print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
-                  segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
-
-            scaled_progress_listener = SubTaskProgressListener(progressListener, base_task_total=max_audio_duration, sub_task_start=segment_start, sub_task_total=segment_duration)
-            segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language, progress_listener=scaled_progress_listener)
-
-            adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
-
-            # Propagate expand amount to the segments
-            if (segment_expand_amount > 0):
-                segment_without_expansion = segment_duration - segment_expand_amount
-
-                for adjusted_segment in adjusted_segments:
-                    adjusted_segment_end = adjusted_segment['end']
-
-                    # Add expand amount if the segment got expanded
-                    if (adjusted_segment_end > segment_without_expansion):
-                        adjusted_segment["expand_amount"] = adjusted_segment_end - segment_without_expansion
-
-            # Append to output
-            result['text'] += segment_result['text']
-            result['segments'].extend(adjusted_segments)
-
-            # Increment detected language
-            if not segment_gap:
-                languageCounter[segment_result['language']] += 1
-
-            # Update prompt window
-            self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
+                # Audio to run on Whisper
+                segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
+                # Previous segments to use as a prompt
+                segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
+
+                # Detected language
+                detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
+
+                print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
+                      segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
+
+                scaled_progress_listener = SubTaskProgressListener(progressListener, base_task_total=progress_total_duration,
+                                                                   sub_task_start=segment_start - progress_start_offset, sub_task_total=segment_duration)
+                segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language, progress_listener=scaled_progress_listener)
+
+                adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
+
+                # Propagate expand amount to the segments
+                if (segment_expand_amount > 0):
+                    segment_without_expansion = segment_duration - segment_expand_amount
+
+                    for adjusted_segment in adjusted_segments:
+                        adjusted_segment_end = adjusted_segment['end']
+
+                        # Add expand amount if the segment got expanded
+                        if (adjusted_segment_end > segment_without_expansion):
+                            adjusted_segment["expand_amount"] = adjusted_segment_end - segment_without_expansion
+
+                # Append to output
+                result['text'] += segment_result['text']
+                result['segments'].extend(adjusted_segments)
+
+                # Increment detected language
+                if not segment_gap:
+                    languageCounter[segment_result['language']] += 1
+
+                # Update prompt window
+                self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
 
-        if detected_language is not None:
-            result['language'] = detected_language
+            if detected_language is not None:
+                result['language'] = detected_language
+        finally:
+            # Notify progress listener that we are done
+            if progressListener is not None:
+                progressListener.on_finished()
         return result
 
+    def get_audio_duration(self, audio: str, config: TranscriptionConfig):
+        return get_audio_duration(audio)
+
     def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
         if (config.max_prompt_window is not None and config.max_prompt_window > 0):
             # Add segments to the current prompt window (unless it is a speech gap)
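Note: two details make the progress math here work. The base total is now progress_total_duration (the summed length of the merged speech segments) rather than max_audio_duration, and each segment's sub_task_start is shifted by progress_start_offset, so silence skipped by VAD neither stalls nor inflates the bar. The try/finally then guarantees on_finished() fires even if a segment raises. SubTaskProgressListener itself is not shown in this diff; the following is an illustrative sketch of the proportional rescaling its call site implies, not the real class, which may differ:

    class SubTaskProgressListenerSketch:
        # Maps progress of one sub task into the parent task's range.
        def __init__(self, base_listener, base_task_total, sub_task_start, sub_task_total):
            self.base_listener = base_listener
            self.base_task_total = base_task_total  # e.g. progress_total_duration
            self.sub_task_start = sub_task_start    # segment offset in the merged timeline
            self.sub_task_total = sub_task_total    # duration of this segment

        def on_progress(self, current, total):
            # Scale sub-task progress to seconds, then offset into the base task.
            sub_progress = current / total * self.sub_task_total
            self.base_listener.on_progress(self.sub_task_start + sub_progress, self.base_task_total)

        def on_finished(self):
            # The base task keeps running; only report this sub task as complete.
            self.base_listener.on_progress(self.sub_task_start + self.sub_task_total, self.base_task_total)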
src/vadParallel.py CHANGED

@@ -1,14 +1,33 @@
 import multiprocessing
+from queue import Empty
 import threading
 import time
+from src.hooks.whisperProgressHook import ProgressListener
 from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
 from src.whisperContainer import WhisperCallback
 
-from multiprocessing import Pool
+from multiprocessing import Pool, Queue
 
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union
 import os
 
+class _ProgressListenerToQueue(ProgressListener):
+    def __init__(self, progress_queue: Queue):
+        self.progress_queue = progress_queue
+        self.progress_total = 0
+        self.prev_progress = 0
+
+    def on_progress(self, current: Union[int, float], total: Union[int, float]):
+        delta = current - self.prev_progress
+        self.prev_progress = current
+        self.progress_total = total
+        self.progress_queue.put(delta)
+
+    def on_finished(self):
+        if self.progress_total > self.prev_progress:
+            delta = self.progress_total - self.prev_progress
+            self.progress_queue.put(delta)
+            self.prev_progress = self.progress_total
 
 class ParallelContext:
     def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None):
@@ -86,7 +105,8 @@ class ParallelTranscription(AbstractTranscription):
         super().__init__(sampling_rate=sampling_rate)
 
     def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
-                            cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None):
+                            cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None,
+                            progress_listener: ProgressListener = None):
         total_duration = get_audio_duration(audio)
 
         # First, get the timestamps for the original audio
@@ -108,6 +128,9 @@ class ParallelTranscription(AbstractTranscription):
         parameters = []
         segment_index = config.initial_segment_index
 
+        processing_manager = multiprocessing.Manager()
+        progress_queue = processing_manager.Queue()
+
         for i in range(len(gpu_devices)):
             # Note that device_segment_list can be empty. But we will still create a process for it,
             # as otherwise we run the risk of assigning the same device to multiple processes.
@@ -120,7 +143,8 @@ class ParallelTranscription(AbstractTranscription):
             device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
             segment_index += len(device_segment_list)
 
-            parameters.append([audio, whisperCallable, device_config]);
+            progress_listener_to_queue = _ProgressListenerToQueue(progress_queue)
+            parameters.append([audio, whisperCallable, device_config, progress_listener_to_queue]);
 
         merged = {
             'text': '',
@@ -142,7 +166,24 @@ class ParallelTranscription(AbstractTranscription):
             pool = gpu_parallel_context.get_pool()
 
             # Run the transcription in parallel
-            results = pool.starmap(self.transcribe, parameters)
+            results_async = pool.starmap_async(self.transcribe, parameters)
+            total_progress = 0
+
+            while not results_async.ready():
+                try:
+                    delta = progress_queue.get(timeout=5) # Set a timeout of 5 seconds
+                except Empty:
+                    continue
+
+                total_progress += delta
+                if progress_listener is not None:
+                    progress_listener.on_progress(total_progress, total_duration)
+
+            results = results_async.get()
+
+            # Call the finished callback
+            if progress_listener is not None:
+                progress_listener.on_finished()
 
             for result in results:
                 # Merge the results
@@ -231,11 +272,12 @@ class ParallelTranscription(AbstractTranscription):
     def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
         # Override timestamps that will be processed
         if (config.override_timestamps is not None):
-            print("Using override timestamps of size " + str(len(config.override_timestamps)))
+            print("(get_merged_timestamps) Using override timestamps of size " + str(len(config.override_timestamps)))
            return config.override_timestamps
         return super().get_merged_timestamps(timestamps, config, total_duration)
 
-    def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig):
+    def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig,
+                   progressListener: ProgressListener = None):
         # Override device ID the first time
         if (os.environ.get("INITIALIZED", None) is None):
             os.environ["INITIALIZED"] = "1"
@@ -246,7 +288,7 @@ class ParallelTranscription(AbstractTranscription):
         print("Using device " + config.device_id)
         os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
 
-        return super().transcribe(audio, whisperCallable, config)
+        return super().transcribe(audio, whisperCallable, config, progressListener)
 
     def _split(self, a, n):
         """Split a list into n approximately equal parts."""
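Note: pool workers cannot invoke a UI callback across process boundaries, so each worker's _ProgressListenerToQueue converts absolute progress into deltas and pushes them onto a manager queue (unlike a bare multiprocessing.Queue, a Manager().Queue() proxy can be passed into Pool workers). The parent swaps the blocking starmap for starmap_async and polls the queue until the pool finishes, summing deltas from all devices into a single on_progress stream against total_duration. A self-contained sketch of the same producer/consumer pattern, with hypothetical names:

    import multiprocessing
    from queue import Empty

    def work(task_size: int, progress_queue) -> int:
        # Worker: report one unit of progress per step.
        for _ in range(task_size):
            progress_queue.put(1)
        return task_size

    if __name__ == "__main__":
        manager = multiprocessing.Manager()
        progress_queue = manager.Queue()

        with multiprocessing.Pool(2) as pool:
            results_async = pool.starmap_async(work, [(30, progress_queue), (50, progress_queue)])

            total, done = 80, 0
            while not results_async.ready():
                try:
                    done += progress_queue.get(timeout=1)  # poll with a timeout, like the diff
                except Empty:
                    continue
                print(f"{done}/{total}")

            # Drain any deltas that arrived after the last poll.
            while True:
                try:
                    done += progress_queue.get_nowait()
                except Empty:
                    break

            results = results_async.get()
        print("finished:", results, f"({done}/{total})")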