Spaces:
Runtime error
Fix diarization in CLI
Browse files
app.py
CHANGED
|
@@ -240,19 +240,6 @@ class WhisperTranscriber:
|
|
| 240 |
# Update progress
|
| 241 |
current_progress += source_audio_duration
|
| 242 |
|
| 243 |
-
# Diarization
|
| 244 |
-
if self.diarization and self.diarization_kwargs:
|
| 245 |
-
print("Diarizing ", source.source_path)
|
| 246 |
-
diarization_result = list(self.diarization.run(source.source_path, **self.diarization_kwargs))
|
| 247 |
-
|
| 248 |
-
# Print result
|
| 249 |
-
print("Diarization result: ")
|
| 250 |
-
for entry in diarization_result:
|
| 251 |
-
print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
|
| 252 |
-
|
| 253 |
-
# Add speakers to result
|
| 254 |
-
result = self.diarization.mark_speakers(diarization_result, result)
|
| 255 |
-
|
| 256 |
source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
|
| 257 |
|
| 258 |
if len(sources) > 1:
|
|
@@ -373,6 +360,19 @@ class WhisperTranscriber:
|
|
| 373 |
else:
|
| 374 |
# Default VAD
|
| 375 |
result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
|
| 377 |
return result
|
| 378 |
|
|
|
|
| 240 |
# Update progress
|
| 241 |
current_progress += source_audio_duration
|
| 242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
|
| 244 |
|
| 245 |
if len(sources) > 1:
|
|
|
|
| 360 |
else:
|
| 361 |
# Default VAD
|
| 362 |
result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
|
| 363 |
+
|
| 364 |
+
# Diarization
|
| 365 |
+
if self.diarization and self.diarization_kwargs:
|
| 366 |
+
print("Diarizing ", audio_path)
|
| 367 |
+
diarization_result = list(self.diarization.run(audio_path, **self.diarization_kwargs))
|
| 368 |
+
|
| 369 |
+
# Print result
|
| 370 |
+
print("Diarization result: ")
|
| 371 |
+
for entry in diarization_result:
|
| 372 |
+
print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
|
| 373 |
+
|
| 374 |
+
# Add speakers to result
|
| 375 |
+
result = self.diarization.mark_speakers(diarization_result, result)
|
| 376 |
|
| 377 |
return result
|
| 378 |
|
cli.py
CHANGED
|
@@ -111,9 +111,9 @@ def cli():
|
|
| 111 |
parser.add_argument('--auth_token', type=str, default=None, help='HuggingFace API Token (optional)')
|
| 112 |
parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
|
| 113 |
help="whether to perform speaker diarization")
|
| 114 |
-
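parser.add_argument("--… [remainder of removed line truncated in extraction]
|
| 115 |
-
parser.add_argument("--… [remainder of removed line truncated in extraction]
|
| 116 |
-
parser.add_argument("--… [remainder of removed line truncated in extraction]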
|
| 117 |
|
| 118 |
args = parser.parse_args().__dict__
|
| 119 |
model_name: str = args.pop("model")
|
|
@@ -151,11 +151,11 @@ def cli():
|
|
| 151 |
compute_type = args.pop("compute_type")
|
| 152 |
highlight_words = args.pop("highlight_words")
|
| 153 |
|
| 154 |
-
diarization = args.pop("diarization")
|
| 155 |
auth_token = args.pop("auth_token")
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
| 159 |
|
| 160 |
transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
|
| 161 |
transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
|
|
|
|
| 111 |
parser.add_argument('--auth_token', type=str, default=None, help='HuggingFace API Token (optional)')
|
| 112 |
parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
|
| 113 |
help="whether to perform speaker diarization")
|
| 114 |
+
parser.add_argument("--diarization_num_speakers", type=int, default=None, help="Number of speakers")
|
| 115 |
+
parser.add_argument("--diarization_min_speakers", type=int, default=None, help="Minimum number of speakers")
|
| 116 |
+
parser.add_argument("--diarization_max_speakers", type=int, default=None, help="Maximum number of speakers")
|
| 117 |
|
| 118 |
args = parser.parse_args().__dict__
|
| 119 |
model_name: str = args.pop("model")
|
|
|
|
| 151 |
compute_type = args.pop("compute_type")
|
| 152 |
highlight_words = args.pop("highlight_words")
|
| 153 |
|
|
|
|
| 154 |
auth_token = args.pop("auth_token")
|
| 155 |
+
diarization = args.pop("diarization")
|
| 156 |
+
num_speakers = args.pop("diarization_num_speakers")
|
| 157 |
+
min_speakers = args.pop("diarization_min_speakers")
|
| 158 |
+
max_speakers = args.pop("diarization_max_speakers")
|
| 159 |
|
| 160 |
transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
|
| 161 |
transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
|