Saiyaswanth007 committed on
Commit b9dea2c · 1 Parent(s): d65b6e8

Fix Gradio RealStream: replace the FastRTC microphone streaming pipeline with an upload-based audio + transcription diarization flow.

Files changed (1):
  1. app.py +241 -554

app.py CHANGED
@@ -1,24 +1,18 @@
 import gradio as gr
 import numpy as np
-import queue
 import torch
+import torchaudio
 import time
-import threading
 import os
 import urllib.request
-import torchaudio
 from scipy.spatial.distance import cosine
-import json
+import threading
+import queue
+from collections import deque
 import asyncio
-from typing import Iterator
-import logging
+from typing import Generator, Tuple, List, Optional
 
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-# Simplified configuration parameters
-SILENCE_THRESHS = [0, 0.4]
+# Configuration parameters (keeping original models)
 FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
 FINAL_BEAM_SIZE = 5
 REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
@@ -35,33 +29,10 @@ EMBEDDING_HISTORY_SIZE = 5
 MIN_SEGMENT_DURATION = 1.0
 DEFAULT_MAX_SPEAKERS = 4
 ABSOLUTE_MAX_SPEAKERS = 10
-
-# Global variables
-FAST_SENTENCE_END = True
 SAMPLE_RATE = 16000
-BUFFER_SIZE = 1024
-CHANNELS = 1
-CHUNK_DURATION_MS = 100  # 100ms chunks for FastRTC
-
-# Speaker colors
-SPEAKER_COLORS = [
-    "#FFFF00",  # Yellow
-    "#FF0000",  # Red
-    "#00FF00",  # Green
-    "#00FFFF",  # Cyan
-    "#FF00FF",  # Magenta
-    "#0000FF",  # Blue
-    "#FF8000",  # Orange
-    "#00FF80",  # Spring Green
-    "#8000FF",  # Purple
-    "#FFFFFF",  # White
-]
-
-SPEAKER_COLOR_NAMES = [
-    "Yellow", "Red", "Green", "Cyan", "Magenta",
-    "Blue", "Orange", "Spring Green", "Purple", "White"
-]
 
+# Speaker labels
+SPEAKER_LABELS = [f"Speaker {i+1}" for i in range(ABSOLUTE_MAX_SPEAKERS)]
 
 class SpeechBrainEncoder:
     """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""
@@ -73,24 +44,11 @@ class SpeechBrainEncoder:
         self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
         os.makedirs(self.cache_dir, exist_ok=True)
 
-    def _download_model(self):
-        """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
-        model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
-        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
-
-        if not os.path.exists(model_path):
-            logger.info(f"Downloading ECAPA-TDNN model to {model_path}...")
-            urllib.request.urlretrieve(model_url, model_path)
-
-        return model_path
-
     def load_model(self):
         """Load the ECAPA-TDNN model"""
         try:
             from speechbrain.pretrained import EncoderClassifier
 
-            model_path = self._download_model()
-
             self.model = EncoderClassifier.from_hparams(
                 source="speechbrain/spkrec-ecapa-voxceleb",
                 savedir=self.cache_dir,
@@ -100,7 +58,7 @@ class SpeechBrainEncoder:
             self.model_loaded = True
             return True
         except Exception as e:
-            logger.error(f"Error loading ECAPA-TDNN model: {e}")
+            print(f"Error loading ECAPA-TDNN model: {e}")
             return False
 
     def embed_utterance(self, audio, sr=16000):
@@ -122,31 +80,12 @@ class SpeechBrainEncoder:
 
             return embedding.squeeze().cpu().numpy()
         except Exception as e:
-            logger.error(f"Error extracting embedding: {e}")
+            print(f"Error extracting embedding: {e}")
             return np.zeros(self.embedding_dim)
 
 
-class AudioProcessor:
-    """Processes audio data to extract speaker embeddings"""
-    def __init__(self, encoder):
-        self.encoder = encoder
-
-    def extract_embedding(self, audio_float):
-        try:
-            # Ensure audio is in the right format
-            if np.abs(audio_float).max() > 1.0:
-                audio_float = audio_float / np.abs(audio_float).max()
-
-            embedding = self.encoder.embed_utterance(audio_float)
-
-            return embedding
-        except Exception as e:
-            logger.error(f"Embedding extraction error: {e}")
-            return np.zeros(self.encoder.embedding_dim)
-
-
 class SpeakerChangeDetector:
-    """Speaker change detector that supports a configurable number of speakers"""
+    """Speaker change detector that supports configurable number of speakers"""
     def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
         self.embedding_dim = embedding_dim
         self.change_threshold = change_threshold
@@ -254,569 +193,317 @@ class SpeakerChangeDetector:
         )
 
         return self.current_speaker, similarity
-
-    def get_color_for_speaker(self, speaker_id):
-        """Return color for speaker ID"""
-        if 0 <= speaker_id < len(SPEAKER_COLORS):
-            return SPEAKER_COLORS[speaker_id]
-        return "#FFFFFF"
-
-    def get_status_info(self):
-        """Return status information about the speaker change detector"""
-        speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
-
-        return {
-            "current_speaker": self.current_speaker,
-            "speaker_counts": speaker_counts,
-            "active_speakers": len(self.active_speakers),
-            "max_speakers": self.max_speakers,
-            "last_similarity": self.last_similarity,
-            "threshold": self.change_threshold
-        }
 
 
-class WhisperTranscriber:
-    """Whisper transcriber using transformers with FastRTC optimization"""
-    def __init__(self, model_name="distil-large-v3"):
-        self.model = None
-        self.processor = None
-        self.model_name = model_name
-        self.model_loaded = False
-
-    def load_model(self):
-        """Load Whisper model"""
-        try:
-            from transformers import WhisperProcessor, WhisperForConditionalGeneration
-
-            model_id = f"distil-whisper/distil-{self.model_name}" if "distil" in self.model_name else f"openai/whisper-{self.model_name}"
-
-            self.processor = WhisperProcessor.from_pretrained(model_id)
-            self.model = WhisperForConditionalGeneration.from_pretrained(
-                model_id,
-                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                low_cpu_mem_usage=True,
-                use_safetensors=True
-            )
-
-            if torch.cuda.is_available():
-                self.model = self.model.cuda()
-
-            self.model_loaded = True
-            return True
-        except Exception as e:
-            logger.error(f"Error loading Whisper model: {e}")
-            return False
+class AudioProcessor:
+    """Processes audio data to extract speaker embeddings"""
+    def __init__(self, encoder):
+        self.encoder = encoder
 
-    def transcribe(self, audio_array, sample_rate=16000):
-        """Transcribe audio array"""
-        if not self.model_loaded:
-            return ""
-
+    def extract_embedding(self, audio_data):
         try:
-            # Ensure audio is the right length and format
-            if len(audio_array) < 1600:  # Less than 0.1 seconds
-                return ""
-
-            # Resample if needed
-            if sample_rate != 16000:
-                import torchaudio.functional as F
-                audio_tensor = torch.tensor(audio_array, dtype=torch.float32)
-                audio_array = F.resample(audio_tensor, sample_rate, 16000).numpy()
-
-            # Process with Whisper
-            inputs = self.processor(
-                audio_array,
-                sampling_rate=16000,
-                return_tensors="pt",
-                truncation=False,
-                padding=True
-            )
+            # Ensure audio is float32 and normalized
+            if audio_data.dtype != np.float32:
+                audio_data = audio_data.astype(np.float32)
 
-            if torch.cuda.is_available():
-                inputs = {k: v.cuda() for k, v in inputs.items()}
+            # Normalize if needed
+            if np.abs(audio_data).max() > 1.0:
+                audio_data = audio_data / np.abs(audio_data).max()
 
-            with torch.no_grad():
-                predicted_ids = self.model.generate(
-                    inputs["input_features"],
-                    max_length=448,
-                    num_beams=1,
-                    do_sample=False,
-                    use_cache=True
-                )
-
-                transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+            # Extract embedding using the loaded encoder
+            embedding = self.encoder.embed_utterance(audio_data)
 
-            return transcription.strip()
+            return embedding
         except Exception as e:
-            logger.error(f"Transcription error: {e}")
-            return ""
+            print(f"Embedding extraction error: {e}")
+            return np.zeros(self.encoder.embedding_dim)
 
 
-class FastRTCSpeakerDiarization:
-    def __init__(self):
+class RealTimeSpeakerDiarization:
+    """Main class for real-time speaker diarization"""
+    def __init__(self, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
         self.encoder = None
         self.audio_processor = None
         self.speaker_detector = None
-        self.transcriber = None
-        self.audio_queue = queue.Queue(maxsize=100)
+        self.change_threshold = change_threshold
+        self.max_speakers = max_speakers
+        self.transcript_history = []
+        self.is_initialized = False
+
+        # Threading components
+        self.audio_queue = queue.Queue()
         self.processing_thread = None
-        self.full_sentences = []
-        self.sentence_speakers = []
-        self.is_running = False
-        self.change_threshold = DEFAULT_CHANGE_THRESHOLD
-        self.max_speakers = DEFAULT_MAX_SPEAKERS
-        self.audio_buffer = []
-        self.buffer_duration = 3.0  # seconds
-        self.last_transcription_time = time.time()
-        self.chunk_size = int(SAMPLE_RATE * CHUNK_DURATION_MS / 1000)
-
-    def initialize_models(self):
-        """Initialize the speaker encoder and transcription models"""
+        self.running = False
+
+    async def initialize(self):
+        """Initialize the speaker diarization system"""
+        if self.is_initialized:
+            return True
+
         try:
             device_str = "cuda" if torch.cuda.is_available() else "cpu"
-            logger.info(f"Using device: {device_str}")
+            print(f"Initializing ECAPA-TDNN model on {device_str}...")
 
-            # Initialize speaker encoder
             self.encoder = SpeechBrainEncoder(device=device_str)
-            encoder_success = self.encoder.load_model()
-
-            # Initialize transcriber
-            self.transcriber = WhisperTranscriber(FINAL_TRANSCRIPTION_MODEL)
-            transcriber_success = self.transcriber.load_model()
+            success = self.encoder.load_model()
 
-            if encoder_success and transcriber_success:
-                self.audio_processor = AudioProcessor(self.encoder)
-                self.speaker_detector = SpeakerChangeDetector(
-                    embedding_dim=self.encoder.embedding_dim,
-                    change_threshold=self.change_threshold,
-                    max_speakers=self.max_speakers
-                )
-                logger.info("Models loaded successfully!")
-                return True
-            else:
-                logger.error("Failed to load models")
+            if not success:
                 return False
+
+            self.audio_processor = AudioProcessor(self.encoder)
+            self.speaker_detector = SpeakerChangeDetector(
+                embedding_dim=self.encoder.embedding_dim,
+                change_threshold=self.change_threshold,
+                max_speakers=self.max_speakers
+            )
+
+            self.is_initialized = True
+            print("Speaker diarization system initialized successfully!")
+            return True
+
         except Exception as e:
-            logger.error(f"Model initialization error: {e}")
+            print(f"Initialization error: {e}")
             return False
 
-    def process_audio_chunk(self, audio_chunk: np.ndarray, sample_rate: int):
-        """Process individual audio chunk from FastRTC"""
-        if not self.is_running or audio_chunk is None:
-            return
-
-        try:
-            # Ensure audio chunk is in correct format
-            if isinstance(audio_chunk, np.ndarray):
-                # Ensure mono audio
-                if len(audio_chunk.shape) > 1:
-                    audio_chunk = audio_chunk.mean(axis=1)
-
-                # Normalize audio
-                if audio_chunk.dtype != np.float32:
-                    audio_chunk = audio_chunk.astype(np.float32)
-
-                if np.abs(audio_chunk).max() > 1.0:
-                    audio_chunk = audio_chunk / np.abs(audio_chunk).max()
-
-                # Add to buffer
-                self.audio_buffer.extend(audio_chunk)
-
-                # Keep buffer to specified duration
-                max_buffer_length = int(self.buffer_duration * sample_rate)
-                if len(self.audio_buffer) > max_buffer_length:
-                    self.audio_buffer = self.audio_buffer[-max_buffer_length:]
-
-                # Process if enough audio accumulated and enough time passed
-                current_time = time.time()
-                if (current_time - self.last_transcription_time > 1.5 and
-                    len(self.audio_buffer) > sample_rate * 0.8):  # At least 0.8 seconds
-
-                    if not self.audio_queue.full():
-                        self.audio_queue.put((np.array(self.audio_buffer[-int(sample_rate * 2):]), sample_rate))
-                        self.last_transcription_time = current_time
-
-        except Exception as e:
-            logger.error(f"Audio chunk processing error: {e}")
-
-    def process_audio_queue(self):
-        """Process audio from the queue"""
-        while self.is_running:
-            try:
-                audio_data, sample_rate = self.audio_queue.get(timeout=1)
-
-                if len(audio_data) < 1600:  # Skip very short audio
-                    continue
-
-                # Transcribe audio
-                transcription = self.transcriber.transcribe(audio_data, sample_rate)
-
-                if transcription and len(transcription.strip()) > 0:
-                    # Extract speaker embedding
-                    speaker_embedding = self.audio_processor.extract_embedding(audio_data)
-
-                    # Detect speaker
-                    speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
-
-                    # Store results
-                    self.full_sentences.append(transcription.strip())
-                    self.sentence_speakers.append(speaker_id)
-
-                    logger.info(f"Processed: Speaker {speaker_id + 1}: {transcription.strip()[:50]}...")
-
-            except queue.Empty:
-                continue
-            except Exception as e:
-                logger.error(f"Error processing audio queue: {e}")
+    def update_settings(self, change_threshold, max_speakers):
+        """Update diarization settings"""
+        self.change_threshold = change_threshold
+        self.max_speakers = max_speakers
+
+        if self.speaker_detector:
+            self.speaker_detector.set_change_threshold(change_threshold)
+            self.speaker_detector.set_max_speakers(max_speakers)
 
-    def start_recording(self):
-        """Start the recording and processing"""
-        if self.encoder is None or self.transcriber is None:
-            return "Please initialize models first!"
-
+    def process_audio_segment(self, audio_data: np.ndarray, text: str) -> Tuple[int, str]:
+        """Process an audio segment and return speaker ID and formatted text"""
+        if not self.is_initialized:
+            return 0, text
+
         try:
-            self.is_running = True
-            self.audio_buffer = []
-            self.last_transcription_time = time.time()
+            # Extract speaker embedding
+            embedding = self.audio_processor.extract_embedding(audio_data)
 
-            # Clear the queue
-            while not self.audio_queue.empty():
-                try:
-                    self.audio_queue.get_nowait()
-                except queue.Empty:
-                    break
+            # Detect speaker
+            speaker_id, similarity = self.speaker_detector.add_embedding(embedding)
 
-            # Start processing thread
-            self.processing_thread = threading.Thread(target=self.process_audio_queue, daemon=True)
-            self.processing_thread.start()
+            # Format text with speaker label
+            speaker_label = SPEAKER_LABELS[speaker_id]
+            formatted_text = f"{speaker_label}: {text}"
 
-            logger.info("Recording started successfully!")
-            return "Recording started successfully!"
+            return speaker_id, formatted_text
 
         except Exception as e:
-            logger.error(f"Error starting recording: {e}")
-            return f"Error starting recording: {e}"
+            print(f"Error processing audio segment: {e}")
+            return 0, f"Speaker 1: {text}"
 
-    def stop_recording(self):
-        """Stop the recording process"""
-        self.is_running = False
-        logger.info("Recording stopped!")
-        return "Recording stopped!"
-
-    def clear_conversation(self):
-        """Clear all conversation data"""
-        self.full_sentences = []
-        self.sentence_speakers = []
-        self.audio_buffer = []
-
-        # Clear the queue
-        while not self.audio_queue.empty():
-            try:
-                self.audio_queue.get_nowait()
-            except queue.Empty:
-                break
+    def get_transcript_history(self):
+        """Get the formatted transcript history"""
+        return "\n".join(self.transcript_history)
+
+    def add_to_transcript(self, formatted_text: str):
+        """Add formatted text to transcript history"""
+        self.transcript_history.append(formatted_text)
+
+        # Keep only last 50 entries to prevent memory issues
+        if len(self.transcript_history) > 50:
+            self.transcript_history = self.transcript_history[-50:]
+
+    def clear_transcript(self):
+        """Clear transcript history and reset speaker detector"""
+        self.transcript_history = []
 
         if self.speaker_detector:
             self.speaker_detector = SpeakerChangeDetector(
                 embedding_dim=self.encoder.embedding_dim,
                 change_threshold=self.change_threshold,
                 max_speakers=self.max_speakers
             )
-
-        return "Conversation cleared!"
-
-    def update_settings(self, threshold, max_speakers):
-        """Update speaker detection settings"""
-        self.change_threshold = threshold
-        self.max_speakers = max_speakers
-
-        if self.speaker_detector:
-            self.speaker_detector.set_change_threshold(threshold)
-            self.speaker_detector.set_max_speakers(max_speakers)
-
-        return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"
-
-    def get_formatted_conversation(self):
-        """Get the formatted conversation with speaker colors"""
-        try:
-            if not self.full_sentences:
-                return "Waiting for speech input... 🎤"
-
-            sentences_with_style = []
-
-            for i, sentence in enumerate(self.full_sentences[-10:]):  # Show last 10 sentences
-                if i >= len(self.sentence_speakers):
-                    color = "#FFFFFF"
-                    speaker_name = "Unknown"
-                else:
-                    speaker_id = self.sentence_speakers[-(10-i) if len(self.sentence_speakers) >= 10 else i]
-                    color = self.speaker_detector.get_color_for_speaker(speaker_id)
-                    speaker_name = f"Speaker {speaker_id + 1}"
-
-                sentences_with_style.append(
-                    f'<p><span style="color:{color}; font-weight: bold;">{speaker_name}:</span> {sentence}</p>')
-
-            return "".join(sentences_with_style)
-
-        except Exception as e:
-            return f"Error formatting conversation: {e}"
-
-    def get_status_info(self):
-        """Get current status information"""
-        if not self.speaker_detector:
-            return "Speaker detector not initialized"
-
-        try:
-            status = self.speaker_detector.get_status_info()
-            queue_size = self.audio_queue.qsize()
-
-            status_lines = [
-                f"**Current Speaker:** {status['current_speaker'] + 1}",
-                f"**Active Speakers:** {status['active_speakers']} of {status['max_speakers']}",
-                f"**Last Similarity:** {status['last_similarity']:.3f}",
-                f"**Change Threshold:** {status['threshold']:.2f}",
-                f"**Total Sentences:** {len(self.full_sentences)}",
-                f"**Buffer Length:** {len(self.audio_buffer)} samples",
-                f"**Queue Size:** {queue_size}",
-                "",
-                "**Speaker Segment Counts:**"
-            ]
-
-            for i in range(status['max_speakers']):
-                color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
-                status_lines.append(f"Speaker {i+1} ({color_name}): {status['speaker_counts'][i]}")
-
-            return "\n".join(status_lines)
-
-        except Exception as e:
-            return f"Error getting status: {e}"
 
 
 # Global instance
-diarization_system = FastRTCSpeakerDiarization()
+diarization_system = RealTimeSpeakerDiarization()
 
 
-def initialize_system():
+async def initialize_system():
     """Initialize the diarization system"""
-    success = diarization_system.initialize_models()
+    success = await diarization_system.initialize()
     if success:
-        return "✅ System initialized successfully! Models loaded."
+        return "✅ Speaker diarization system initialized successfully!"
     else:
-        return "❌ Failed to initialize system. Please check the logs."
-
+        return "❌ Failed to initialize speaker diarization system. Please check your setup."
 
-def start_recording():
-    """Start recording and transcription"""
-    return diarization_system.start_recording()
 
-
-def stop_recording():
-    """Stop recording and transcription"""
-    return diarization_system.stop_recording()
+def process_audio_with_transcript(audio_data, sample_rate, transcription_text, change_threshold, max_speakers):
+    """Process audio with transcription for speaker diarization"""
+    if not diarization_system.is_initialized:
+        return "Please initialize the system first.", ""
+
+    if audio_data is None or transcription_text.strip() == "":
+        return diarization_system.get_transcript_history(), ""
+
+    try:
+        # Update settings
+        diarization_system.update_settings(change_threshold, max_speakers)
+
+        # Convert audio to the right format
+        if len(audio_data.shape) > 1:
+            audio_data = audio_data.mean(axis=1)  # Convert to mono
+
+        # Resample if needed
+        if sample_rate != SAMPLE_RATE:
+            audio_data = torchaudio.functional.resample(
+                torch.tensor(audio_data), sample_rate, SAMPLE_RATE
+            ).numpy()
+
+        # Process the audio segment
+        speaker_id, formatted_text = diarization_system.process_audio_segment(audio_data, transcription_text)
+
+        # Add to transcript
+        diarization_system.add_to_transcript(formatted_text)
+
+        # Return updated transcript and current speaker info
+        transcript = diarization_system.get_transcript_history()
+        current_speaker_info = f"Current Speaker: {SPEAKER_LABELS[speaker_id]}"
+
+        return transcript, current_speaker_info
+
+    except Exception as e:
+        error_msg = f"Error processing audio: {str(e)}"
+        return diarization_system.get_transcript_history(), error_msg
 
 
 def clear_conversation():
-    """Clear the conversation"""
-    return diarization_system.clear_conversation()
-
-
-def update_settings(threshold, max_speakers):
-    """Update system settings"""
-    return diarization_system.update_settings(threshold, max_speakers)
-
-
-def get_conversation():
-    """Get the current conversation"""
-    return diarization_system.get_formatted_conversation()
+    """Clear the conversation transcript"""
+    diarization_system.clear_transcript()
+    return "", "Conversation cleared."
 
 
-def get_status():
-    """Get system status"""
-    return diarization_system.get_status_info()
-
-
-def process_audio_stream(audio_stream):
-    """Process streaming audio from FastRTC"""
-    if audio_stream is not None and diarization_system.is_running:
-        sample_rate, audio_data = audio_stream
-        diarization_system.process_audio_chunk(audio_data, sample_rate)
-
-    return get_conversation(), get_status()
-
-
-# Create Gradio interface with FastRTC
-def create_interface():
-    with gr.Blocks(title="FastRTC Real-time Speaker Diarization", theme=gr.themes.Soft()) as app:
-        gr.Markdown("# 🎤 FastRTC Real-time Speech Recognition with Speaker Diarization")
-        gr.Markdown("This app uses Hugging Face FastRTC for real-time audio streaming with automatic speaker identification and color-coding.")
+def create_gradio_interface():
+    """Create and return the Gradio interface"""
+    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as demo:
+        gr.Markdown("# 🎙️ Real-time Speaker Diarization with ASR")
+        gr.Markdown("Upload audio with transcription to perform real-time speaker diarization.")
 
+        # Initialization section
         with gr.Row():
-            with gr.Column(scale=2):
-                # FastRTC Audio input for real-time streaming
-                audio_input = gr.Audio(
-                    sources=["microphone"],
-                    type="numpy",
-                    streaming=True,
-                    label="🎙️ FastRTC Microphone Input",
-                    format="wav",
-                    show_download_button=False,
-                    container=True,
-                    elem_id="fastrtc_audio"
-                )
-
-                # Main conversation display
-                conversation_output = gr.HTML(
-                    value="<i>Click 'Initialize System' and then 'Start Recording' to begin...</i>",
-                    label="Live Conversation",
-                    elem_id="conversation_display"
-                )
-
-                # Control buttons
-                with gr.Row():
-                    init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
-                    start_btn = gr.Button("🎙️ Start Recording", variant="primary", interactive=False, size="lg")
-                    stop_btn = gr.Button("⏹️ Stop Recording", variant="stop", interactive=False, size="lg")
-                    clear_btn = gr.Button("🗑️ Clear", interactive=False, size="lg")
-
-                # Status display
-                status_output = gr.Textbox(
-                    label="System Status",
-                    value="System not initialized",
-                    lines=10,
-                    interactive=False,
-                    show_copy_button=True
-                )
-
-            with gr.Column(scale=1):
-                # Settings panel
-                gr.Markdown("## ⚙️ Settings")
-
-                threshold_slider = gr.Slider(
-                    minimum=0.1,
-                    maximum=0.95,
-                    step=0.05,
+            init_btn = gr.Button("🚀 Initialize System", variant="primary")
+            init_status = gr.Textbox(label="Initialization Status", interactive=False)
+
+        # Settings section
+        with gr.Row():
+            with gr.Column():
+                change_threshold = gr.Slider(
+                    minimum=0.1,
+                    maximum=0.9,
                     value=DEFAULT_CHANGE_THRESHOLD,
-                    label="Speaker Change Sensitivity",
-                    info="Lower = more sensitive to changes"
+                    step=0.05,
+                    label="Speaker Change Threshold",
+                    info="Lower values = more sensitive to speaker changes"
                 )
-
-                max_speakers_slider = gr.Slider(
+            with gr.Column():
+                max_speakers = gr.Slider(
                     minimum=2,
                     maximum=ABSOLUTE_MAX_SPEAKERS,
-                    step=1,
                     value=DEFAULT_MAX_SPEAKERS,
-                    label="Maximum Number of Speakers"
+                    step=1,
+                    label="Maximum Number of Speakers",
+                    info="Maximum number of speakers to detect"
                 )
-
-                update_settings_btn = gr.Button("Update Settings", variant="secondary")
-
-                # Speaker color legend
-                gr.Markdown("## 🎨 Speaker Colors")
-                color_info = []
-                for i, (color, name) in enumerate(zip(SPEAKER_COLORS, SPEAKER_COLOR_NAMES)):
-                    color_info.append(f'<span style="color:{color}; font-size: 16px;">●</span> Speaker {i+1} ({name})')
-
-                gr.HTML("<br>".join(color_info[:DEFAULT_MAX_SPEAKERS]))
-
-                # Performance info
-                gr.Markdown("## 📊 Performance")
-                gr.Markdown("""
-                - **FastRTC**: Low-latency audio streaming
-                - **Whisper**: distil-large-v3 for transcription
-                - **ECAPA-TDNN**: Speaker embeddings
-                - **Real-time**: ~100ms processing chunks
-                """)
 
-        # Event handlers
-        def on_initialize():
-            result = initialize_system()
-            if "successfully" in result:
-                return (
-                    result,  # status_output
-                    gr.update(interactive=True),  # start_btn
-                    gr.update(interactive=True),  # clear_btn
-                    get_conversation(),  # conversation_output
-                    get_status()  # status_output update
+        # Audio input and transcription
+        with gr.Row():
+            with gr.Column():
+                audio_input = gr.Audio(
+                    label="Audio Input",
+                    type="numpy",
+                    format="wav"
                 )
-            else:
-                return (
-                    result,  # status_output
-                    gr.update(interactive=False),  # start_btn
-                    gr.update(interactive=False),  # clear_btn
-                    get_conversation(),  # conversation_output
-                    get_status()  # status_output update
+                transcription_input = gr.Textbox(
+                    label="Transcription Text",
+                    placeholder="Enter the transcription of the audio...",
+                    lines=3
                 )
-
-        def on_start():
-            result = start_recording()
-            return (
-                result,  # status_output
-                gr.update(interactive=False),  # start_btn
-                gr.update(interactive=True),  # stop_btn
-            )
-
-        def on_stop():
-            result = stop_recording()
-            return (
-                result,  # status_output
-                gr.update(interactive=True),  # start_btn
-                gr.update(interactive=False),  # stop_btn
-            )
-
-        # Auto-refresh function
-        def refresh_display():
-            return get_conversation(), get_status()
-
-        # Connect event handlers
-        init_btn.click(
-            on_initialize,
-            outputs=[status_output, start_btn, clear_btn, conversation_output, status_output]
+                process_btn = gr.Button("🎯 Process Audio", variant="secondary")
+
+            with gr.Column():
+                current_speaker = gr.Textbox(
+                    label="Current Speaker",
+                    interactive=False
+                )
+                clear_btn = gr.Button("🗑️ Clear Conversation", variant="stop")
+
+        # Output section
+        transcript_output = gr.Textbox(
+            label="Live Transcript with Speaker Labels",
+            lines=15,
+            max_lines=20,
+            interactive=False,
+            placeholder="Processed transcript will appear here..."
         )
 
-        start_btn.click(
-            on_start,
-            outputs=[status_output, start_btn, stop_btn]
+        # Event handlers
+        init_btn.click(
+            fn=initialize_system,
+            outputs=[init_status]
         )
 
-        stop_btn.click(
-            on_stop,
-            outputs=[status_output, start_btn, stop_btn]
+        process_btn.click(
+            fn=process_audio_with_transcript,
+            inputs=[
+                audio_input,
+                gr.Number(value=SAMPLE_RATE, visible=False),  # Hidden sample rate
+                transcription_input,
+                change_threshold,
+                max_speakers
+            ],
+            outputs=[transcript_output, current_speaker]
        )
 
         clear_btn.click(
-            clear_conversation,
-            outputs=[status_output]
+            fn=clear_conversation,
+            outputs=[transcript_output, current_speaker]
         )
 
-        update_settings_btn.click(
-            update_settings,
-            inputs=[threshold_slider, max_speakers_slider],
-            outputs=[status_output]
+        # Auto-process when audio and transcription are provided
+        audio_input.change(
+            fn=process_audio_with_transcript,
+            inputs=[
+                audio_input,
+                gr.Number(value=SAMPLE_RATE, visible=False),
+                transcription_input,
+                change_threshold,
+                max_speakers
+            ],
+            outputs=[transcript_output, current_speaker]
         )
 
-        # FastRTC streaming audio processing
-        audio_input.stream(
-            process_audio_stream,
-            inputs=[audio_input],
-            outputs=[conversation_output, status_output],
-            stream_every=0.1,  # Process every 100ms
-            time_limit=None
-        )
-
-        # Auto-refresh timer
-        refresh_timer = gr.Timer(2.0)
-        refresh_timer.tick(
-            refresh_display,
-            outputs=[conversation_output, status_output]
-        )
+        # Instructions
+        gr.Markdown("""
+        ## Instructions:
+        1. **Initialize**: Click "Initialize System" to load the speaker diarization models
+        2. **Upload Audio**: Upload an audio file (WAV format recommended)
+        3. **Add Transcription**: Enter the transcription text for the audio
+        4. **Adjust Settings**:
+           - **Speaker Change Threshold**: Lower values detect speaker changes more easily
+           - **Max Speakers**: Set the maximum number of speakers you expect
+        5. **Process**: Click "Process Audio" or the system will auto-process
+        6. **View Results**: See the transcript with speaker labels (Speaker 1, Speaker 2, etc.)
+
+        ## Tips:
+        - For similar-sounding speakers, increase the threshold (0.6-0.8)
+        - For different-sounding speakers, lower threshold works better (0.3-0.5)
+        - The system maintains speaker consistency across the conversation
+        - Use "Clear Conversation" to reset the speaker memory
+        """)
 
-    return app
+    return demo
 
 
 if __name__ == "__main__":
-    app = create_interface()
-    app.launch(
+    # Create and launch the Gradio interface
+    demo = create_gradio_interface()
+    demo.launch(
+        share=True,
         server_name="0.0.0.0",
         server_port=7860,
-        share=True
+        show_error=True
     )
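Note on the removed path: Gradio's streaming microphone input delivers (sample_rate, numpy_array) chunks to a callback on a fixed cadence, which is the mechanism the deleted create_interface() relied on. A condensed sketch of that wiring, reusing only the calls that appear in the removed code (gr.Audio(streaming=True) and .stream(..., stream_every=0.1)); the echo callback is a placeholder for illustration, not part of the commit:

    import gradio as gr

    def echo_chunk(chunk):
        # With type="numpy", a streaming gr.Audio emits (sample_rate, samples) tuples
        sr, samples = chunk
        return f"received {len(samples)} samples at {sr} Hz"

    with gr.Blocks() as demo:
        mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True)
        status = gr.Textbox(label="Chunk status")
        # Same wiring the removed code used: invoke the callback every ~100 ms of audio
        mic.stream(echo_chunk, inputs=[mic], outputs=[status], stream_every=0.1)

    demo.launch()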
 
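The new upload-based flow can also be exercised without the UI. A minimal headless sketch, assuming the committed RealTimeSpeakerDiarization API and SAMPLE_RATE constant are importable from app.py; the silent one-second buffer is a stand-in for real speech, so the segment should resolve to Speaker 1:

    import asyncio
    import numpy as np
    from app import RealTimeSpeakerDiarization, SAMPLE_RATE  # assumes this script sits next to app.py

    async def main():
        system = RealTimeSpeakerDiarization(change_threshold=0.65, max_speakers=2)
        if not await system.initialize():  # loads the ECAPA-TDNN encoder
            raise SystemExit("initialization failed")

        segment = np.zeros(SAMPLE_RATE, dtype=np.float32)  # 1 s of placeholder audio
        speaker_id, line = system.process_audio_segment(segment, "hello there")
        system.add_to_transcript(line)
        print(system.get_transcript_history())  # e.g. "Speaker 1: hello there"

    asyncio.run(main())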