Concat first prompt with initial prompt
Browse files
- app.py +19 -3
- src/vad.py +5 -2
    	
        app.py
    CHANGED
    
    | @@ -89,9 +89,17 @@ class WhisperTranscriber: | |
| 89 |  | 
| 90 | 
             
                def transcribe_file(self, model: whisper.Whisper, audio_path: str, language: str, task: str = None, vad: str = None, 
         | 
| 91 | 
             
                                    vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 92 | 
             
                    # Callable for processing an audio file
         | 
| 93 | 
            -
                    whisperCallable = lambda audio, prompt, detected_language : model.transcribe(audio, \
         | 
| 94 | 
            -
                             language=language if language else detected_language, task=task,  | 
|  | |
|  | |
| 95 |  | 
| 96 | 
             
                    # The results
         | 
| 97 | 
             
                    if (vad == 'silero-vad'):
         | 
| @@ -113,10 +121,18 @@ class WhisperTranscriber: | |
| 113 | 
             
                        result = periodic_vad.transcribe(audio_path, whisperCallable, PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow))
         | 
| 114 | 
             
                    else:
         | 
| 115 | 
             
                        # Default VAD
         | 
| 116 | 
            -
                        result = whisperCallable(audio_path, None, None)
         | 
| 117 |  | 
| 118 | 
             
                    return result
         | 
| 119 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 120 | 
             
                def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
         | 
| 121 | 
             
                    # Use Silero VAD 
         | 
| 122 | 
             
                    if (self.vad_model is None):
         | 
|  | |
| 89 |  | 
| 90 | 
             
                def transcribe_file(self, model: whisper.Whisper, audio_path: str, language: str, task: str = None, vad: str = None, 
         | 
| 91 | 
             
                                    vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
         | 
| 92 | 
            +
                    
         | 
| 93 | 
            +
                    initial_prompt = decodeOptions.pop('initial_prompt', None)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                    if ('task' in decodeOptions):
         | 
| 96 | 
            +
                        task = decodeOptions.pop('task')
         | 
| 97 | 
            +
             | 
| 98 | 
             
                    # Callable for processing an audio file
         | 
| 99 | 
            +
                    whisperCallable = lambda audio, segment_index, prompt, detected_language : model.transcribe(audio, \
         | 
| 100 | 
            +
                             language=language if language else detected_language, task=task, \
         | 
| 101 | 
            +
                             initial_prompt=self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt, \
         | 
| 102 | 
            +
                             **decodeOptions)
         | 
| 103 |  | 
| 104 | 
             
                    # The results
         | 
| 105 | 
             
                    if (vad == 'silero-vad'):
         | 
|  | |
| 121 | 
             
                        result = periodic_vad.transcribe(audio_path, whisperCallable, PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow))
         | 
| 122 | 
             
                    else:
         | 
| 123 | 
             
                        # Default VAD
         | 
| 124 | 
            +
                        result = whisperCallable(audio_path, 0, None, None)
         | 
| 125 |  | 
| 126 | 
             
                    return result
         | 
| 127 |  | 
| 128 | 
            +
                def _concat_prompt(self, prompt1, prompt2):
         | 
| 129 | 
            +
                    if (prompt1 is None):
         | 
| 130 | 
            +
                        return prompt2
         | 
| 131 | 
            +
                    elif (prompt2 is None):
         | 
| 132 | 
            +
                        return prompt1
         | 
| 133 | 
            +
                    else:
         | 
| 134 | 
            +
                        return prompt1 + " " + prompt2
         | 
| 135 | 
            +
             | 
| 136 | 
             
                def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
         | 
| 137 | 
             
                    # Use Silero VAD 
         | 
| 138 | 
             
                    if (self.vad_model is None):
         | 
    	
        src/vad.py
    CHANGED
    
    | @@ -100,7 +100,7 @@ class AbstractTranscription(ABC): | |
| 100 | 
             
                    audio: str
         | 
| 101 | 
             
                        The audio file.
         | 
| 102 |  | 
| 103 | 
            -
                    whisperCallable: Callable[[Union[str, np.ndarray, torch.Tensor], str, str], dict[str, Union[dict, Any]]]
         | 
| 104 | 
             
                        The callback that is used to invoke Whisper on an audio file/buffer. The first parameter is the audio file/buffer, 
         | 
| 105 | 
             
                        the second parameter is an optional text prompt, and the last is the current detected language. The return value is the result of the Whisper call.
         | 
| 106 |  | 
| @@ -147,8 +147,11 @@ class AbstractTranscription(ABC): | |
| 147 | 
             
                    languageCounter = Counter()
         | 
| 148 | 
             
                    detected_language = None
         | 
| 149 |  | 
|  | |
|  | |
| 150 | 
             
                    # For each time segment, run whisper
         | 
| 151 | 
             
                    for segment in merged:
         | 
|  | |
| 152 | 
             
                        segment_start = segment['start']
         | 
| 153 | 
             
                        segment_end = segment['end']
         | 
| 154 | 
             
                        segment_expand_amount = segment.get('expand_amount', 0)
         | 
| @@ -169,7 +172,7 @@ class AbstractTranscription(ABC): | |
| 169 |  | 
| 170 | 
             
                        print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ", 
         | 
| 171 | 
             
                              segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
         | 
| 172 | 
            -
                        segment_result = whisperCallable(segment_audio, segment_prompt, detected_language)
         | 
| 173 |  | 
| 174 | 
             
                        adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
         | 
| 175 |  | 
|  | |
| 100 | 
             
                    audio: str
         | 
| 101 | 
             
                        The audio file.
         | 
| 102 |  | 
| 103 | 
            +
                    whisperCallable: Callable[[Union[str, np.ndarray, torch.Tensor], int, str, str], dict[str, Union[dict, Any]]]
         | 
| 104 | 
             
                        The callback that is used to invoke Whisper on an audio file/buffer. The first parameter is the audio file/buffer, 
         | 
| 105 | 
             
                        the second parameter is the zero-based index of the segment being transcribed, the third parameter is an optional text prompt, and the last is the current detected language. The return value is the result of the Whisper call.
         | 
| 106 |  | 
|  | |
| 147 | 
             
                    languageCounter = Counter()
         | 
| 148 | 
             
                    detected_language = None
         | 
| 149 |  | 
| 150 | 
            +
                    segment_index = -1
         | 
| 151 | 
            +
             | 
| 152 | 
             
                    # For each time segment, run whisper
         | 
| 153 | 
             
                    for segment in merged:
         | 
| 154 | 
            +
                        segment_index += 1
         | 
| 155 | 
             
                        segment_start = segment['start']
         | 
| 156 | 
             
                        segment_end = segment['end']
         | 
| 157 | 
             
                        segment_expand_amount = segment.get('expand_amount', 0)
         | 
|  | |
| 172 |  | 
| 173 | 
             
                        print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ", 
         | 
| 174 | 
             
                              segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
         | 
| 175 | 
            +
                        segment_result = whisperCallable(segment_audio, segment_index, segment_prompt, detected_language)
         | 
| 176 |  | 
| 177 | 
             
                        adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
         | 
| 178 |  |