Adding JSON initial prompt
By selecting "json_prompt_mode", you can customize the prompt for each segment.
For instance:
[
{"segment_index": 0, "prompt": "Hello, how are you?"},
{"segment_index": 1, "prompt": "I'm doing well, how are you?"},
{"segment_index": 2, "prompt": "{0} Fine, thank you.", "format_prompt": true}
]
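
Entries are matched to VAD segments by segment_index; a segment without an entry falls back to the prompt Whisper would have used anyway, and "format_prompt" substitutes that prompt into the {0} placeholder instead of concatenating. A minimal sketch of how the example above resolves, using the JsonPromptStrategy added in this commit:

    from src.prompts.jsonPromptStrategy import JsonPromptStrategy

    json_prompt = """[
        {"segment_index": 0, "prompt": "Hello, how are you?"},
        {"segment_index": 1, "prompt": "I'm doing well, how are you?"},
        {"segment_index": 2, "prompt": "{0} Fine, thank you.", "format_prompt": true}
    ]"""
    strategy = JsonPromptStrategy(json_prompt)

    # Segment 0: the configured prompt is concatenated with Whisper's rolling prompt.
    print(strategy.get_segment_prompt(0, "earlier text", "en"))
    # -> Hello, how are you? earlier text

    # Segment 2: format_prompt replaces "{0}" with Whisper's prompt instead.
    print(strategy.get_segment_prompt(2, "earlier text", "en"))
    # -> earlier text Fine, thank you.

    # Segment 3 has no entry, so Whisper's own prompt is used unchanged.
    print(strategy.get_segment_prompt(3, "earlier text", "en"))
    # -> earlier text
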
- app.py +16 -4
- cli.py +2 -2
- src/config.py +5 -0
- src/prompts/abstractPromptStrategy.py +73 -0
- src/prompts/jsonPromptStrategy.py +48 -0
- src/prompts/prependPromptStrategy.py +31 -0
- src/whisper/abstractWhisperContainer.py +9 -24
- src/whisper/fasterWhisperContainer.py +14 -12
- src/whisper/whisperContainer.py +18 -13
app.py
CHANGED
@@ -13,12 +13,14 @@ import numpy as np
 
 import torch
 
-from src.config import ApplicationConfig, VadInitialPromptMode
+from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
 from src.hooks.progressListener import ProgressListener
 from src.hooks.subTaskProgressListener import SubTaskProgressListener
 from src.hooks.whisperProgressHook import create_progress_listener_handle
 from src.languages import get_language_names
 from src.modelCache import ModelCache
+from src.prompts.jsonPromptStrategy import JsonPromptStrategy
+from src.prompts.prependPromptStrategy import PrependPromptStrategy
 from src.source import get_audio_source_collection
 from src.vadParallel import ParallelContext, ParallelTranscription
 
@@ -271,8 +273,18 @@ class WhisperTranscriber:
         if ('task' in decodeOptions):
             task = decodeOptions.pop('task')
 
+        if (vadOptions.vadInitialPromptMode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS or
+            vadOptions.vadInitialPromptMode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
+            # Prepend initial prompt
+            prompt_strategy = PrependPromptStrategy(initial_prompt, vadOptions.vadInitialPromptMode)
+        elif (vadOptions.vadInitialPromptMode == VadInitialPromptMode.JSON_PROMPT_MODE):
+            # Use a JSON format to specify the prompt for each segment
+            prompt_strategy = JsonPromptStrategy(initial_prompt)
+        else:
+            raise ValueError("Invalid vadInitialPromptMode: " + vadOptions.vadInitialPromptMode)
+
         # Callable for processing an audio file
-        whisperCallable = model.create_callback(language, task,
+        whisperCallable = model.create_callback(language, task, prompt_strategy=prompt_strategy, **decodeOptions)
 
         # The results
         if (vadOptions.vad == 'silero-vad'):
@@ -519,7 +531,7 @@ def create_ui(app_config: ApplicationConfig):
         *common_vad_inputs(),
         gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
         gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
-        gr.Dropdown(choices=
+        gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode"),
 
         *common_word_timestamps_inputs(),
         gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
@@ -580,7 +592,7 @@ if __name__ == '__main__':
                         help="The default model name.") # medium
    parser.add_argument("--default_vad", type=str, default=default_app_config.default_vad, \
                        help="The default VAD.") # silero-vad
-   parser.add_argument("--vad_initial_prompt_mode", type=str, default=default_app_config.vad_initial_prompt_mode, choices=
+   parser.add_argument("--vad_initial_prompt_mode", type=str, default=default_app_config.vad_initial_prompt_mode, choices=VAD_INITIAL_PROMPT_MODE_VALUES, \
                        help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
    parser.add_argument("--vad_parallel_devices", type=str, default=default_app_config.vad_parallel_devices, \
                        help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
cli.py
CHANGED
@@ -7,7 +7,7 @@ import numpy as np
 
 import torch
 from app import VadOptions, WhisperTranscriber
-from src.config import ApplicationConfig, VadInitialPromptMode
+from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
 from src.download import download_url
 from src.languages import get_language_names
 
@@ -47,7 +47,7 @@ def cli():
 
    parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
                        help="The voice activity detection algorithm to use") # silero-vad
-   parser.add_argument("--vad_initial_prompt_mode", type=str, default=app_config.vad_initial_prompt_mode, choices=
+   parser.add_argument("--vad_initial_prompt_mode", type=str, default=app_config.vad_initial_prompt_mode, choices=VAD_INITIAL_PROMPT_MODE_VALUES, \
                        help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
    parser.add_argument("--vad_merge_window", type=optional_float, default=app_config.vad_merge_window, \
                        help="The window size (in seconds) to merge voice segments")
src/config.py
CHANGED
@@ -24,9 +24,12 @@ class ModelConfig:
         self.path = path
         self.type = type
 
+VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
+
 class VadInitialPromptMode(Enum):
     PREPEND_ALL_SEGMENTS = 1
     PREPREND_FIRST_SEGMENT = 2
+    JSON_PROMPT_MODE = 3
 
     @staticmethod
     def from_string(s: str):
@@ -36,6 +39,8 @@ class VadInitialPromptMode(Enum):
             return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
         elif normalized == "prepend_first_segment":
             return VadInitialPromptMode.PREPREND_FIRST_SEGMENT
+        elif normalized == "json_prompt_mode":
+            return VadInitialPromptMode.JSON_PROMPT_MODE
         else:
             raise ValueError(f"Invalid value for VadInitialPromptMode: {s}")
 
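
A quick sanity check of the new mode value (the normalization step of from_string sits outside this hunk, but lower-case values map directly):

    from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, VadInitialPromptMode

    # Every UI/CLI choice maps onto an enum member; note the pre-existing
    # "PREPREND_FIRST_SEGMENT" spelling, which this commit leaves untouched.
    for value in VAD_INITIAL_PROMPT_MODE_VALUES:
        print(value, "->", VadInitialPromptMode.from_string(value))
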
src/prompts/abstractPromptStrategy.py
ADDED
@@ -0,0 +1,73 @@
+import abc
+
+
+class AbstractPromptStrategy:
+    """
+    Represents a strategy for generating prompts for a given audio segment.
+
+    Note that the strategy must be picklable, as it will be serialized and sent to the workers.
+    """
+
+    @abc.abstractmethod
+    def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
+        """
+        Retrieves the prompt for a given segment.
+
+        Parameters
+        ----------
+        segment_index: int
+            The index of the segment.
+        whisper_prompt: str
+            The prompt for the segment generated by Whisper. This is typically concatenated with the initial prompt.
+        detected_language: str
+            The language detected for the segment.
+        """
+        pass
+
+    @abc.abstractmethod
+    def on_segment_finished(self, segment_index: int, whisper_prompt: str, detected_language: str, result: dict):
+        """
+        Called when a segment has finished processing.
+
+        Parameters
+        ----------
+        segment_index: int
+            The index of the segment.
+        whisper_prompt: str
+            The prompt for the segment generated by Whisper. This is typically concatenated with the initial prompt.
+        detected_language: str
+            The language detected for the segment.
+        result: dict
+            The result of the segment. It has the following format:
+            {
+                "text": str,
+                "segments": [
+                    {
+                        "text": str,
+                        "start": float,
+                        "end": float,
+                        "words": [words],
+                    }
+                ],
+                "language": str,
+            }
+        """
+        pass
+
+    def _concat_prompt(self, prompt1, prompt2):
+        """
+        Concatenates two prompts.
+
+        Parameters
+        ----------
+        prompt1: str
+            The first prompt.
+        prompt2: str
+            The second prompt.
+        """
+        if (prompt1 is None):
+            return prompt2
+        elif (prompt2 is None):
+            return prompt1
+        else:
+            return prompt1 + " " + prompt2
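
For illustration only (not part of this commit), a hypothetical strategy that carries each segment's transcript forward as the next segment's seed would subclass the interface like this:

    from src.prompts.abstractPromptStrategy import AbstractPromptStrategy

    class CarryOverPromptStrategy(AbstractPromptStrategy):
        """Hypothetical: seed each segment with the previous segment's text."""
        def __init__(self, initial_prompt: str):
            self.current_prompt = initial_prompt

        def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
            # Combine our running prompt with whatever Whisper already carries.
            return self._concat_prompt(self.current_prompt, whisper_prompt)

        def on_segment_finished(self, segment_index: int, whisper_prompt: str, detected_language: str, result: dict):
            # The result dict's "text" field (see docstring above) becomes the next seed.
            self.current_prompt = result["text"]
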
src/prompts/jsonPromptStrategy.py
ADDED
@@ -0,0 +1,48 @@
+import json
+from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
+
+
+class JsonPromptSegment():
+    def __init__(self, segment_index: int, prompt: str, format_prompt: bool = False):
+        self.prompt = prompt
+        self.segment_index = segment_index
+        self.format_prompt = format_prompt
+
+class JsonPromptStrategy(AbstractPromptStrategy):
+    def __init__(self, initial_json_prompt: str):
+        """
+        Parameters
+        ----------
+        initial_json_prompt: str
+            The initial prompts for each segment in JSON form.
+
+            Format:
+            [
+                {"segment_index": 0, "prompt": "Hello, how are you?"},
+                {"segment_index": 1, "prompt": "I'm doing well, how are you?"},
+                {"segment_index": 2, "prompt": "{0} Fine, thank you.", "format_prompt": true}
+            ]
+
+        """
+        parsed_json = json.loads(initial_json_prompt)
+        self.segment_lookup = dict[str, JsonPromptSegment]()
+
+        for prompt_entry in parsed_json:
+            segment_index = prompt_entry["segment_index"]
+            prompt = prompt_entry["prompt"]
+            format_prompt = prompt_entry.get("format_prompt", False)
+            self.segment_lookup[str(segment_index)] = JsonPromptSegment(segment_index, prompt, format_prompt)
+
+    def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
+        # Lookup prompt
+        prompt = self.segment_lookup.get(str(segment_index), None)
+
+        if (prompt is None):
+            # No prompt found, return whisper prompt
+            print(f"Could not find prompt for segment {segment_index}, returning whisper prompt")
+            return whisper_prompt
+
+        if (prompt.format_prompt):
+            return prompt.prompt.format(whisper_prompt)
+        else:
+            return self._concat_prompt(prompt.prompt, whisper_prompt)
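
Two details worth noting: lookups are keyed by str(segment_index), and JsonPromptStrategy relies on the base class's no-op on_segment_finished, so the configured prompts stay fixed for the whole run:

    from src.prompts.jsonPromptStrategy import JsonPromptStrategy

    strategy = JsonPromptStrategy('[{"segment_index": 0, "prompt": "Hello."}]')
    assert "0" in strategy.segment_lookup          # keys are stored as strings

    # Inherited no-op: finishing a segment does not rewrite later prompts.
    strategy.on_segment_finished(0, "Hello.", "en", {"text": "Hello there."})
    assert strategy.get_segment_prompt(0, None, "en") == "Hello."
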
src/prompts/prependPromptStrategy.py
ADDED
@@ -0,0 +1,31 @@
+from src.config import VadInitialPromptMode
+from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
+
+class PrependPromptStrategy(AbstractPromptStrategy):
+    """
+    A simple prompt strategy that prepends a single prompt to all segments of audio, or prepends the prompt to the first segment of audio.
+    """
+    def __init__(self, initial_prompt: str, initial_prompt_mode: VadInitialPromptMode):
+        """
+        Parameters
+        ----------
+        initial_prompt: str
+            The initial prompt to use for the transcription.
+        initial_prompt_mode: VadInitialPromptMode
+            The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
+            If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
+        """
+        self.initial_prompt = initial_prompt
+        self.initial_prompt_mode = initial_prompt_mode
+
+        # This is a simple prompt strategy, so we only support these two modes
+        if initial_prompt_mode not in [VadInitialPromptMode.PREPEND_ALL_SEGMENTS, VadInitialPromptMode.PREPREND_FIRST_SEGMENT]:
+            raise ValueError(f"Unsupported initial prompt mode {initial_prompt_mode}")
+
+    def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
+        if (self.initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS):
+            return self._concat_prompt(self.initial_prompt, whisper_prompt)
+        elif (self.initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
+            return self._concat_prompt(self.initial_prompt, whisper_prompt) if segment_index == 0 else whisper_prompt
+        else:
+            raise ValueError(f"Unknown initial prompt mode {self.initial_prompt_mode}")
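
A short demonstration of the two supported modes (the enum member keeps its original "PREPREND" spelling):

    from src.config import VadInitialPromptMode
    from src.prompts.prependPromptStrategy import PrependPromptStrategy

    first_only = PrependPromptStrategy("Glossary: VAD, Whisper.", VadInitialPromptMode.PREPREND_FIRST_SEGMENT)
    print(first_only.get_segment_prompt(0, None, "en"))          # Glossary: VAD, Whisper.
    print(first_only.get_segment_prompt(1, "prior text", "en"))  # prior text

    every_segment = PrependPromptStrategy("Glossary: VAD, Whisper.", VadInitialPromptMode.PREPEND_ALL_SEGMENTS)
    print(every_segment.get_segment_prompt(1, "prior text", "en"))  # Glossary: VAD, Whisper. prior text
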
src/whisper/abstractWhisperContainer.py
CHANGED
@@ -1,11 +1,16 @@
 import abc
 from typing import List
+
 from src.config import ModelConfig, VadInitialPromptMode
 
 from src.hooks.progressListener import ProgressListener
 from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
+from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
 
 class AbstractWhisperCallback:
+    def __init__(self):
+        self.__prompt_mode_gpt = None
+
     @abc.abstractmethod
     def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
         """
@@ -24,23 +29,6 @@ class AbstractWhisperCallback:
         """
         raise NotImplementedError()
 
-    def _get_initial_prompt(self, initial_prompt: str, initial_prompt_mode: VadInitialPromptMode,
-                            prompt: str, segment_index: int):
-        if (initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS):
-            return self._concat_prompt(initial_prompt, prompt)
-        elif (initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
-            return self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt
-        else:
-            raise ValueError(f"Unknown initial prompt mode {initial_prompt_mode}")
-
-    def _concat_prompt(self, prompt1, prompt2):
-        if (prompt1 is None):
-            return prompt2
-        elif (prompt2 is None):
-            return prompt1
-        else:
-            return prompt1 + " " + prompt2
-
 class AbstractWhisperContainer:
     def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
                  download_root: str = None,
@@ -75,8 +63,8 @@ class AbstractWhisperContainer:
        pass
 
    @abc.abstractmethod
-   def create_callback(self, language: str = None, task: str = None,
-
+   def create_callback(self, language: str = None, task: str = None,
+                       prompt_strategy: AbstractPromptStrategy = None,
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        """
        Create a WhisperCallback object that can be used to transcript audio files.
@@ -87,11 +75,8 @@ class AbstractWhisperContainer:
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
-
-           The
-       initial_prompt_mode: VadInitialPromptMode
-           The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
-           If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
+       prompt_strategy: AbstractPromptStrategy
+           The prompt strategy to use for the transcription.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.
 
src/whisper/fasterWhisperContainer.py
CHANGED
@@ -6,6 +6,7 @@ from src.config import ModelConfig, VadInitialPromptMode
 from src.hooks.progressListener import ProgressListener
 from src.languages import get_language_from_name
 from src.modelCache import ModelCache
+from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
 from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
 from src.utils import format_timestamp
 
@@ -56,8 +57,8 @@ class FasterWhisperContainer(AbstractWhisperContainer):
        model = WhisperModel(model_url, device=device, compute_type=self.compute_type)
        return model
 
-   def create_callback(self, language: str = None, task: str = None,
-
+   def create_callback(self, language: str = None, task: str = None,
+                       prompt_strategy: AbstractPromptStrategy = None,
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        """
        Create a WhisperCallback object that can be used to transcript audio files.
@@ -68,11 +69,8 @@ class FasterWhisperContainer(AbstractWhisperContainer):
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
-
-           The
-       initial_prompt_mode: VadInitialPromptMode
-           The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
-           If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
+       prompt_strategy: AbstractPromptStrategy
+           The prompt strategy to use. If not specified, the prompt from Whisper will be used.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.
 
@@ -80,17 +78,16 @@ class FasterWhisperContainer(AbstractWhisperContainer):
        -------
        A WhisperCallback object.
        """
-       return FasterWhisperCallback(self, language=language, task=task,
+       return FasterWhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
 
 class FasterWhisperCallback(AbstractWhisperCallback):
    def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None,
-
+                prompt_strategy: AbstractPromptStrategy = None,
                 **decodeOptions: dict):
        self.model_container = model_container
        self.language = language
        self.task = task
-       self.
-       self.initial_prompt_mode = initial_prompt_mode
+       self.prompt_strategy = prompt_strategy
        self.decodeOptions = decodeOptions
 
        self._printed_warning = False
@@ -138,7 +135,8 @@ class FasterWhisperCallback(AbstractWhisperCallback):
        # See if supress_tokens is a string - if so, convert it to a list of ints
        decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)
 
-       initial_prompt = self.
+       initial_prompt = self.prompt_strategy.get_segment_prompt(segment_index, prompt, detected_language) \
+                        if self.prompt_strategy else prompt
 
        segments_generator, info = model.transcribe(audio, \
            language=language_code if language_code else detected_language, task=self.task, \
@@ -184,6 +182,10 @@ class FasterWhisperCallback(AbstractWhisperCallback):
            "duration": info.duration if info else None
        }
 
+       # If we have a prompt strategy, we need to increment the current prompt
+       if self.prompt_strategy:
+           self.prompt_strategy.on_segment_finished(segment_index, prompt, detected_language, result)
+
        if progress_listener is not None:
            progress_listener.on_finished()
        return result
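
Putting the pieces together, a minimal sketch of the new wiring (assuming an already-constructed FasterWhisperContainer named container; beam_size stands in for any pickleable decode option):

    from src.config import VadInitialPromptMode
    from src.prompts.prependPromptStrategy import PrependPromptStrategy

    # Build a strategy, then hand it to the container; the callback consults it
    # for every segment instead of receiving a raw initial_prompt/mode pair.
    strategy = PrependPromptStrategy("Names: Anna, Bertil.", VadInitialPromptMode.PREPEND_ALL_SEGMENTS)
    callback = container.create_callback(language="en", task="transcribe",
                                         prompt_strategy=strategy, beam_size=5)
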
src/whisper/whisperContainer.py
CHANGED
@@ -15,6 +15,7 @@ from src.config import ModelConfig, VadInitialPromptMode
 from src.hooks.whisperProgressHook import create_progress_listener_handle
 
 from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
+from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
 from src.utils import download_file
 from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
 
@@ -69,8 +70,8 @@ class WhisperContainer(AbstractWhisperContainer):
 
        return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
 
-   def create_callback(self, language: str = None, task: str = None,
-
+   def create_callback(self, language: str = None, task: str = None,
+                       prompt_strategy: AbstractPromptStrategy = None,
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        """
        Create a WhisperCallback object that can be used to transcript audio files.
@@ -81,11 +82,8 @@ class WhisperContainer(AbstractWhisperContainer):
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
-
-           The
-       initial_prompt_mode: VadInitialPromptMode
-           The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
-           If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
+       prompt_strategy: AbstractPromptStrategy
+           The prompt strategy to use. If not specified, the prompt from Whisper will be used.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.
 
@@ -93,7 +91,7 @@ class WhisperContainer(AbstractWhisperContainer):
        -------
        A WhisperCallback object.
        """
-       return WhisperCallback(self, language=language, task=task,
+       return WhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
 
    def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
        from src.conversion.hf_converter import convert_hf_whisper
@@ -162,13 +160,14 @@ class WhisperContainer(AbstractWhisperContainer):
        return model_config.path
 
 class WhisperCallback(AbstractWhisperCallback):
-   def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None,
-
+   def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None,
+                prompt_strategy: AbstractPromptStrategy = None,
+                **decodeOptions: dict):
        self.model_container = model_container
        self.language = language
        self.task = task
-       self.
-
+       self.prompt_strategy = prompt_strategy
+
        self.decodeOptions = decodeOptions
 
    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
@@ -201,11 +200,17 @@ class WhisperCallback(AbstractWhisperCallback):
        if self.model_container.compute_type in ["fp16", "float16"]:
            decodeOptions["fp16"] = True
 
-       initial_prompt = self.
+       initial_prompt = self.prompt_strategy.get_segment_prompt(segment_index, prompt, detected_language) \
+                        if self.prompt_strategy else prompt
 
        result = model.transcribe(audio, \
            language=self.language if self.language else detected_language, task=self.task, \
            initial_prompt=initial_prompt, \
            **decodeOptions
        )
+
+       # If we have a prompt strategy, we need to increment the current prompt
+       if self.prompt_strategy:
+           self.prompt_strategy.on_segment_finished(segment_index, prompt, detected_language, result)
+
        return result