Spaces:
Build error
Build error
Add an extra interface for performing diarization
Browse files
app.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from datetime import datetime
|
| 2 |
import json
|
| 3 |
import math
|
| 4 |
-
from typing import Iterator, Union
|
| 5 |
import argparse
|
| 6 |
|
| 7 |
from io import StringIO
|
|
@@ -16,14 +16,14 @@ import torch
|
|
| 16 |
from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
|
| 17 |
from src.diarization.diarization import Diarization
|
| 18 |
from src.diarization.diarizationContainer import DiarizationContainer
|
|
|
|
| 19 |
from src.hooks.progressListener import ProgressListener
|
| 20 |
from src.hooks.subTaskProgressListener import SubTaskProgressListener
|
| 21 |
-
from src.hooks.whisperProgressHook import create_progress_listener_handle
|
| 22 |
from src.languages import get_language_names
|
| 23 |
from src.modelCache import ModelCache
|
| 24 |
from src.prompts.jsonPromptStrategy import JsonPromptStrategy
|
| 25 |
from src.prompts.prependPromptStrategy import PrependPromptStrategy
|
| 26 |
-
from src.source import get_audio_source_collection
|
| 27 |
from src.vadParallel import ParallelContext, ParallelTranscription
|
| 28 |
|
| 29 |
# External programs
|
|
@@ -101,7 +101,8 @@ class WhisperTranscriber:
|
|
| 101 |
self.diarization_kwargs = kwargs
|
| 102 |
|
| 103 |
def unset_diarization(self):
|
| 104 |
-
self.diarization
|
|
|
|
| 105 |
self.diarization_kwargs = None
|
| 106 |
|
| 107 |
# Entry function for the simple tab
|
|
@@ -185,19 +186,59 @@ class WhisperTranscriber:
|
|
| 185 |
word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
|
| 186 |
progress=progress)
|
| 187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
| 189 |
vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
|
|
|
|
| 190 |
**decodeOptions: dict):
|
| 191 |
try:
|
| 192 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
| 193 |
|
|
|
|
|
|
|
|
|
|
| 194 |
try:
|
| 195 |
selectedLanguage = languageName.lower() if len(languageName) > 0 else None
|
| 196 |
selectedModel = modelName if modelName is not None else "base"
|
| 197 |
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
# Result
|
| 203 |
download = []
|
|
@@ -234,8 +275,12 @@ class WhisperTranscriber:
|
|
| 234 |
sub_task_start=current_progress,
|
| 235 |
sub_task_total=source_audio_duration)
|
| 236 |
|
| 237 |
-
# Transcribe
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
|
| 240 |
|
| 241 |
# Update progress
|
|
@@ -363,6 +408,10 @@ class WhisperTranscriber:
|
|
| 363 |
result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
|
| 364 |
|
| 365 |
# Diarization
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
if self.diarization and self.diarization_kwargs:
|
| 367 |
print("Diarizing ", audio_path)
|
| 368 |
diarization_result = list(self.diarization.run(audio_path, **self.diarization_kwargs))
|
|
@@ -373,9 +422,9 @@ class WhisperTranscriber:
|
|
| 373 |
print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
|
| 374 |
|
| 375 |
# Add speakers to result
|
| 376 |
-
|
| 377 |
|
| 378 |
-
return
|
| 379 |
|
| 380 |
def _create_progress_listener(self, progress: gr.Progress):
|
| 381 |
if (progress is None):
|
|
@@ -449,7 +498,7 @@ class WhisperTranscriber:
|
|
| 449 |
os.makedirs(output_dir)
|
| 450 |
|
| 451 |
text = result["text"]
|
| 452 |
-
language = result["language"]
|
| 453 |
languageMaxLineWidth = self.__get_max_line_width(language)
|
| 454 |
|
| 455 |
print("Max line width " + str(languageMaxLineWidth))
|
|
@@ -635,7 +684,25 @@ def create_ui(app_config: ApplicationConfig):
|
|
| 635 |
gr.Text(label="Segments")
|
| 636 |
])
|
| 637 |
|
| 638 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 639 |
|
| 640 |
# Queue up the demo
|
| 641 |
if is_queue_mode:
|
|
|
|
| 1 |
from datetime import datetime
|
| 2 |
import json
|
| 3 |
import math
|
| 4 |
+
from typing import Callable, Iterator, Union
|
| 5 |
import argparse
|
| 6 |
|
| 7 |
from io import StringIO
|
|
|
|
| 16 |
from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
|
| 17 |
from src.diarization.diarization import Diarization
|
| 18 |
from src.diarization.diarizationContainer import DiarizationContainer
|
| 19 |
+
from src.diarization.transcriptLoader import load_transcript
|
| 20 |
from src.hooks.progressListener import ProgressListener
|
| 21 |
from src.hooks.subTaskProgressListener import SubTaskProgressListener
|
|
|
|
| 22 |
from src.languages import get_language_names
|
| 23 |
from src.modelCache import ModelCache
|
| 24 |
from src.prompts.jsonPromptStrategy import JsonPromptStrategy
|
| 25 |
from src.prompts.prependPromptStrategy import PrependPromptStrategy
|
| 26 |
+
from src.source import AudioSource, get_audio_source_collection
|
| 27 |
from src.vadParallel import ParallelContext, ParallelTranscription
|
| 28 |
|
| 29 |
# External programs
|
|
|
|
| 101 |
self.diarization_kwargs = kwargs
|
| 102 |
|
| 103 |
def unset_diarization(self):
|
| 104 |
+
if self.diarization is not None:
|
| 105 |
+
self.diarization.cleanup()
|
| 106 |
self.diarization_kwargs = None
|
| 107 |
|
| 108 |
# Entry function for the simple tab
|
|
|
|
| 186 |
word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
|
| 187 |
progress=progress)
|
| 188 |
|
| 189 |
+
# Perform diarization given a specific input audio file and whisper file
|
| 190 |
+
def perform_extra(self, languageName, urlData, singleFile, whisper_file: str,
|
| 191 |
+
highlight_words: bool = False,
|
| 192 |
+
diarization: bool = False, diarization_speakers: int = 2, diarization_min_speakers = 1, diarization_max_speakers = 5, progress=gr.Progress()):
|
| 193 |
+
|
| 194 |
+
if whisper_file is None:
|
| 195 |
+
raise ValueError("whisper_file is required")
|
| 196 |
+
|
| 197 |
+
# Set diarization
|
| 198 |
+
if diarization:
|
| 199 |
+
self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers,
|
| 200 |
+
min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
|
| 201 |
+
else:
|
| 202 |
+
self.unset_diarization()
|
| 203 |
+
|
| 204 |
+
def custom_transcribe_file(source: AudioSource):
|
| 205 |
+
result = load_transcript(whisper_file.name)
|
| 206 |
+
|
| 207 |
+
# Set language if not set
|
| 208 |
+
if not "language" in result:
|
| 209 |
+
result["language"] = languageName
|
| 210 |
+
|
| 211 |
+
# Mark speakers
|
| 212 |
+
result = self._handle_diarization(source.source_path, result)
|
| 213 |
+
return result
|
| 214 |
+
|
| 215 |
+
multipleFiles = [singleFile] if singleFile else None
|
| 216 |
+
|
| 217 |
+
# Will return download, text, vtt
|
| 218 |
+
return self.transcribe_webui("base", "", urlData, multipleFiles, None, None, None,
|
| 219 |
+
progress=progress,highlight_words=highlight_words,
|
| 220 |
+
override_transcribe_file=custom_transcribe_file, override_max_sources=1)
|
| 221 |
+
|
| 222 |
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
| 223 |
vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
|
| 224 |
+
override_transcribe_file: Callable[[AudioSource], dict] = None, override_max_sources = None,
|
| 225 |
**decodeOptions: dict):
|
| 226 |
try:
|
| 227 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
| 228 |
|
| 229 |
+
if override_max_sources is not None and len(sources) > override_max_sources:
|
| 230 |
+
raise ValueError("Maximum number of sources is " + str(override_max_sources) + ", but " + str(len(sources)) + " were provided")
|
| 231 |
+
|
| 232 |
try:
|
| 233 |
selectedLanguage = languageName.lower() if len(languageName) > 0 else None
|
| 234 |
selectedModel = modelName if modelName is not None else "base"
|
| 235 |
|
| 236 |
+
if override_transcribe_file is None:
|
| 237 |
+
model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation,
|
| 238 |
+
model_name=selectedModel, compute_type=self.app_config.compute_type,
|
| 239 |
+
cache=self.model_cache, models=self.app_config.models)
|
| 240 |
+
else:
|
| 241 |
+
model = None
|
| 242 |
|
| 243 |
# Result
|
| 244 |
download = []
|
|
|
|
| 275 |
sub_task_start=current_progress,
|
| 276 |
sub_task_total=source_audio_duration)
|
| 277 |
|
| 278 |
+
# Transcribe using the override function if specified
|
| 279 |
+
if override_transcribe_file is None:
|
| 280 |
+
result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vadOptions, scaled_progress_listener, **decodeOptions)
|
| 281 |
+
else:
|
| 282 |
+
result = override_transcribe_file(source)
|
| 283 |
+
|
| 284 |
filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
|
| 285 |
|
| 286 |
# Update progress
|
|
|
|
| 408 |
result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
|
| 409 |
|
| 410 |
# Diarization
|
| 411 |
+
result = self._handle_diarization(audio_path, result)
|
| 412 |
+
return result
|
| 413 |
+
|
| 414 |
+
def _handle_diarization(self, audio_path: str, input: dict):
|
| 415 |
if self.diarization and self.diarization_kwargs:
|
| 416 |
print("Diarizing ", audio_path)
|
| 417 |
diarization_result = list(self.diarization.run(audio_path, **self.diarization_kwargs))
|
|
|
|
| 422 |
print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
|
| 423 |
|
| 424 |
# Add speakers to result
|
| 425 |
+
input = self.diarization.mark_speakers(diarization_result, input)
|
| 426 |
|
| 427 |
+
return input
|
| 428 |
|
| 429 |
def _create_progress_listener(self, progress: gr.Progress):
|
| 430 |
if (progress is None):
|
|
|
|
| 498 |
os.makedirs(output_dir)
|
| 499 |
|
| 500 |
text = result["text"]
|
| 501 |
+
language = result["language"] if "language" in result else None
|
| 502 |
languageMaxLineWidth = self.__get_max_line_width(language)
|
| 503 |
|
| 504 |
print("Max line width " + str(languageMaxLineWidth))
|
|
|
|
| 684 |
gr.Text(label="Segments")
|
| 685 |
])
|
| 686 |
|
| 687 |
+
perform_extra_interface = gr.Interface(fn=ui.perform_extra,
|
| 688 |
+
description="Perform additional processing on a given JSON or SRT file", article=ui_article, inputs=[
|
| 689 |
+
gr.Dropdown(choices=sorted(get_language_names()), label="Language", value=app_config.language),
|
| 690 |
+
gr.Text(label="URL (YouTube, etc.)"),
|
| 691 |
+
gr.File(label="Upload Audio File", file_count="single"),
|
| 692 |
+
gr.File(label="Upload JSON/SRT File", file_count="single"),
|
| 693 |
+
gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
|
| 694 |
+
|
| 695 |
+
*common_diarization_inputs(),
|
| 696 |
+
gr.Number(label="Diarization - Min Speakers", precision=0, value=app_config.diarization_min_speakers, interactive=has_diarization_libs),
|
| 697 |
+
gr.Number(label="Diarization - Max Speakers", precision=0, value=app_config.diarization_max_speakers, interactive=has_diarization_libs),
|
| 698 |
+
|
| 699 |
+
], outputs=[
|
| 700 |
+
gr.File(label="Download"),
|
| 701 |
+
gr.Text(label="Transcription"),
|
| 702 |
+
gr.Text(label="Segments")
|
| 703 |
+
])
|
| 704 |
+
|
| 705 |
+
demo = gr.TabbedInterface([simple_transcribe, full_transcribe, perform_extra_interface], tab_names=["Simple", "Full", "Extra"])
|
| 706 |
|
| 707 |
# Queue up the demo
|
| 708 |
if is_queue_mode:
|
cli.py
CHANGED
|
@@ -108,12 +108,12 @@ def cli():
|
|
| 108 |
help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
| 109 |
|
| 110 |
# Diarization
|
| 111 |
-
parser.add_argument('--auth_token', type=str, default=
|
| 112 |
parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
|
| 113 |
help="whether to perform speaker diarization")
|
| 114 |
-
parser.add_argument("--diarization_num_speakers", type=int, default=
|
| 115 |
-
parser.add_argument("--diarization_min_speakers", type=int, default=
|
| 116 |
-
parser.add_argument("--diarization_max_speakers", type=int, default=
|
| 117 |
|
| 118 |
args = parser.parse_args().__dict__
|
| 119 |
model_name: str = args.pop("model")
|
|
|
|
| 108 |
help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
| 109 |
|
| 110 |
# Diarization
|
| 111 |
+
parser.add_argument('--auth_token', type=str, default=app_config.auth_token, help='HuggingFace API Token (optional)')
|
| 112 |
parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
|
| 113 |
help="whether to perform speaker diarization")
|
| 114 |
+
parser.add_argument("--diarization_num_speakers", type=int, default=app_config.diarization_speakers, help="Number of speakers")
|
| 115 |
+
parser.add_argument("--diarization_min_speakers", type=int, default=app_config.diarization_min_speakers, help="Minimum number of speakers")
|
| 116 |
+
parser.add_argument("--diarization_max_speakers", type=int, default=app_config.diarization_max_speakers, help="Maximum number of speakers")
|
| 117 |
|
| 118 |
args = parser.parse_args().__dict__
|
| 119 |
model_name: str = args.pop("model")
|