Saiyaswanth007 committed on
Commit e1bfb0a · 1 Parent(s): 68a4a19

Initial config

Files changed (1)
  1. app.py +3 -971
app.py CHANGED
@@ -1,975 +1,7 @@
-import gradio as gr
-import numpy as np
-import queue
-import torch
-import time
-import threading
-import os
-import urllib.request
-import torchaudio
-from scipy.spatial.distance import cosine
-from scipy.signal import resample
-from RealtimeSTT import AudioToTextRecorder
-from fastapi import FastAPI, APIRouter
-from fastrtc import Stream, AsyncStreamHandler
-import json
-import asyncio
-import uvicorn
-from queue import Queue
-import logging
+from fastapi import FastAPI
 
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-# Simplified configuration parameters
-SILENCE_THRESHS = [0, 0.4]
-FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
-FINAL_BEAM_SIZE = 5
-REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
-REALTIME_BEAM_SIZE = 5
-TRANSCRIPTION_LANGUAGE = "en"
-SILERO_SENSITIVITY = 0.4
-WEBRTC_SENSITIVITY = 3
-MIN_LENGTH_OF_RECORDING = 0.7
-PRE_RECORDING_BUFFER_DURATION = 0.35
-
-# Speaker change detection parameters
-DEFAULT_CHANGE_THRESHOLD = 0.65
-EMBEDDING_HISTORY_SIZE = 5
-MIN_SEGMENT_DURATION = 1.5
-DEFAULT_MAX_SPEAKERS = 4
-ABSOLUTE_MAX_SPEAKERS = 8
-
-# Global variables
-SAMPLE_RATE = 16000
-BUFFER_SIZE = 1024
-CHANNELS = 1
-
-# Speaker colors - more distinguishable colors
-SPEAKER_COLORS = [
-    "#FF6B6B",  # Red
-    "#4ECDC4",  # Teal
-    "#45B7D1",  # Blue
-    "#96CEB4",  # Green
-    "#FFEAA7",  # Yellow
-    "#DDA0DD",  # Plum
-    "#98D8C8",  # Mint
-    "#F7DC6F",  # Gold
-]
-
-SPEAKER_COLOR_NAMES = [
-    "Red", "Teal", "Blue", "Green", "Yellow", "Plum", "Mint", "Gold"
-]
-
-
-class SpeechBrainEncoder:
-    """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""
-    def __init__(self, device="cpu"):
-        self.device = device
-        self.model = None
-        self.embedding_dim = 192
-        self.model_loaded = False
-        self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
-        os.makedirs(self.cache_dir, exist_ok=True)
-
-    def load_model(self):
-        """Load the ECAPA-TDNN model"""
-        try:
-            # Import SpeechBrain
-            from speechbrain.pretrained import EncoderClassifier
-
-            # Get model path
-            model_path = self._download_model()
-
-            # Load the pre-trained model
-            self.model = EncoderClassifier.from_hparams(
-                source="speechbrain/spkrec-ecapa-voxceleb",
-                savedir=self.cache_dir,
-                run_opts={"device": self.device}
-            )
-
-            self.model_loaded = True
-            return True
-        except Exception as e:
-            print(f"Error loading ECAPA-TDNN model: {e}")
-            return False
-
-    def embed_utterance(self, audio, sr=16000):
-        """Extract speaker embedding from audio"""
-        if not self.model_loaded:
-            raise ValueError("Model not loaded. Call load_model() first.")
-
-        try:
-            if isinstance(audio, np.ndarray):
-                # Ensure audio is float32 and properly normalized
-                audio = audio.astype(np.float32)
-                if np.max(np.abs(audio)) > 1.0:
-                    audio = audio / np.max(np.abs(audio))
-                waveform = torch.tensor(audio).unsqueeze(0)
-            else:
-                waveform = audio.unsqueeze(0)
-
-            # Resample if necessary
-            if sr != 16000:
-                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
-
-            with torch.no_grad():
-                embedding = self.model.encode_batch(waveform)
-
-            return embedding.squeeze().cpu().numpy()
-        except Exception as e:
-            logger.error(f"Error extracting embedding: {e}")
-            return np.zeros(self.embedding_dim)
-
-
-class AudioProcessor:
-    """Processes audio data to extract speaker embeddings"""
-    def __init__(self, encoder):
-        self.encoder = encoder
-        self.audio_buffer = []
-        self.min_audio_length = int(SAMPLE_RATE * 1.0)  # Minimum 1 second of audio
-
-    def add_audio_chunk(self, audio_chunk):
-        """Add audio chunk to buffer"""
-        self.audio_buffer.extend(audio_chunk)
-
-        # Keep buffer from getting too large
-        max_buffer_size = int(SAMPLE_RATE * 10)  # 10 seconds max
-        if len(self.audio_buffer) > max_buffer_size:
-            self.audio_buffer = self.audio_buffer[-max_buffer_size:]
-
-    def extract_embedding_from_buffer(self):
-        """Extract embedding from current audio buffer"""
-        if len(self.audio_buffer) < self.min_audio_length:
-            return None
-
-        try:
-            # Use the last portion of the buffer for embedding
-            audio_segment = np.array(self.audio_buffer[-self.min_audio_length:], dtype=np.float32)
-
-            # Normalize audio
-            if np.max(np.abs(audio_segment)) > 0:
-                audio_segment = audio_segment / np.max(np.abs(audio_segment))
-            else:
-                return None
-
-            embedding = self.encoder.embed_utterance(audio_segment)
-            return embedding
-        except Exception as e:
-            logger.error(f"Embedding extraction error: {e}")
-            return None
-
-
-class SpeakerChangeDetector:
-    """Improved speaker change detector"""
-    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
-        self.embedding_dim = embedding_dim
-        self.change_threshold = change_threshold
-        self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
-        self.current_speaker = 0
-        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
-        self.speaker_centroids = [None] * self.max_speakers
-        self.last_change_time = time.time()
-        self.last_similarity = 1.0
-        self.active_speakers = set([0])
-        self.segment_counter = 0
-
-    def set_max_speakers(self, max_speakers):
-        """Update the maximum number of speakers"""
-        new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
-
-        if new_max < self.max_speakers:
-            # Remove speakers beyond the new limit
-            for speaker_id in list(self.active_speakers):
-                if speaker_id >= new_max:
-                    self.active_speakers.discard(speaker_id)
-
-            if self.current_speaker >= new_max:
-                self.current_speaker = 0
-
-        # Resize arrays
-        if new_max > self.max_speakers:
-            self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
-            self.speaker_centroids.extend([None] * (new_max - self.max_speakers))
-        else:
-            self.speaker_embeddings = self.speaker_embeddings[:new_max]
-            self.speaker_centroids = self.speaker_centroids[:new_max]
-
-        self.max_speakers = new_max
-
-    def set_change_threshold(self, threshold):
-        """Update the threshold for detecting speaker changes"""
-        self.change_threshold = max(0.1, min(threshold, 0.95))
-
-    def add_embedding(self, embedding, timestamp=None):
-        """Add a new embedding and detect speaker changes"""
-        current_time = timestamp or time.time()
-        self.segment_counter += 1
-
-        # Initialize first speaker
-        if not self.speaker_embeddings[0]:
-            self.speaker_embeddings[0].append(embedding)
-            self.speaker_centroids[0] = embedding.copy()
-            self.active_speakers.add(0)
-            return 0, 1.0
-
-        # Calculate similarity with current speaker
-        current_centroid = self.speaker_centroids[self.current_speaker]
-        if current_centroid is not None:
-            similarity = 1.0 - cosine(embedding, current_centroid)
-        else:
-            similarity = 0.5
-
-        self.last_similarity = similarity
-
-        # Check for speaker change
-        time_since_last_change = current_time - self.last_change_time
-        speaker_changed = False
-
-        if time_since_last_change >= MIN_SEGMENT_DURATION and similarity < self.change_threshold:
-            # Find best matching speaker
-            best_speaker = self.current_speaker
-            best_similarity = similarity
-
-            for speaker_id in self.active_speakers:
-                if speaker_id == self.current_speaker:
-                    continue
-
-                centroid = self.speaker_centroids[speaker_id]
-                if centroid is not None:
-                    speaker_similarity = 1.0 - cosine(embedding, centroid)
-                    if speaker_similarity > best_similarity and speaker_similarity > self.change_threshold:
-                        best_similarity = speaker_similarity
-                        best_speaker = speaker_id
-
-            # If no good match found and we can add a new speaker
-            if best_speaker == self.current_speaker and len(self.active_speakers) < self.max_speakers:
-                for new_id in range(self.max_speakers):
-                    if new_id not in self.active_speakers:
-                        best_speaker = new_id
-                        self.active_speakers.add(new_id)
-                        break
-
-            if best_speaker != self.current_speaker:
-                self.current_speaker = best_speaker
-                self.last_change_time = current_time
-                speaker_changed = True
-
-        # Update speaker embeddings and centroids
-        self.speaker_embeddings[self.current_speaker].append(embedding)
-
-        # Keep only recent embeddings (sliding window)
-        max_embeddings = 20
-        if len(self.speaker_embeddings[self.current_speaker]) > max_embeddings:
-            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-max_embeddings:]
-
-        # Update centroid
-        if self.speaker_embeddings[self.current_speaker]:
-            self.speaker_centroids[self.current_speaker] = np.mean(
-                self.speaker_embeddings[self.current_speaker], axis=0
-            )
-
-        return self.current_speaker, similarity
-
-    def get_color_for_speaker(self, speaker_id):
-        """Return color for speaker ID"""
-        if 0 <= speaker_id < len(SPEAKER_COLORS):
-            return SPEAKER_COLORS[speaker_id]
-        return "#FFFFFF"
-
-    def get_status_info(self):
-        """Return status information"""
-        speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
-
-        return {
-            "current_speaker": self.current_speaker,
-            "speaker_counts": speaker_counts,
-            "active_speakers": len(self.active_speakers),
-            "max_speakers": self.max_speakers,
-            "last_similarity": self.last_similarity,
-            "threshold": self.change_threshold,
-            "segment_counter": self.segment_counter
-        }
-
-
-class RealtimeSpeakerDiarization:
-    def __init__(self):
-        self.encoder = None
-        self.audio_processor = None
-        self.speaker_detector = None
-        self.recorder = None
-        self.sentence_queue = queue.Queue()
-        self.full_sentences = []
-        self.sentence_speakers = []
-        self.pending_sentences = []
-        self.current_conversation = ""
-        self.is_running = False
-        self.change_threshold = DEFAULT_CHANGE_THRESHOLD
-        self.max_speakers = DEFAULT_MAX_SPEAKERS
-        self.last_transcription = ""
-        self.transcription_lock = threading.Lock()
-
-    def initialize_models(self):
-        """Initialize the speaker encoder model"""
-        try:
-            device_str = "cuda" if torch.cuda.is_available() else "cpu"
-            logger.info(f"Using device: {device_str}")
-
-            self.encoder = SpeechBrainEncoder(device=device_str)
-            success = self.encoder.load_model()
-
-            if success:
-                self.audio_processor = AudioProcessor(self.encoder)
-                self.speaker_detector = SpeakerChangeDetector(
-                    embedding_dim=self.encoder.embedding_dim,
-                    change_threshold=self.change_threshold,
-                    max_speakers=self.max_speakers
-                )
-                logger.info("Models initialized successfully!")
-                return True
-            else:
-                logger.error("Failed to load models")
-                return False
-        except Exception as e:
-            logger.error(f"Model initialization error: {e}")
-            return False
-
-    def live_text_detected(self, text):
-        """Callback for real-time transcription updates"""
-        with self.transcription_lock:
-            self.last_transcription = text.strip()
-
-    def process_final_text(self, text):
-        """Process final transcribed text with speaker embedding"""
-        text = text.strip()
-        if text:
-            try:
-                # Get audio data for this transcription
-                audio_bytes = getattr(self.recorder, 'last_transcription_bytes', None)
-                if audio_bytes:
-                    self.sentence_queue.put((text, audio_bytes))
-                else:
-                    # If no audio bytes, use current speaker
-                    self.sentence_queue.put((text, None))
-
-            except Exception as e:
-                logger.error(f"Error processing final text: {e}")
-
-    def process_sentence_queue(self):
-        """Process sentences in the queue for speaker detection"""
-        while self.is_running:
-            try:
-                text, audio_bytes = self.sentence_queue.get(timeout=1)
-
-                current_speaker = self.speaker_detector.current_speaker
-
-                if audio_bytes:
-                    # Convert audio data and extract embedding
-                    audio_int16 = np.frombuffer(audio_bytes, dtype=np.int16)
-                    audio_float = audio_int16.astype(np.float32) / 32768.0
-
-                    # Extract embedding
-                    embedding = self.audio_processor.encoder.embed_utterance(audio_float)
-                    if embedding is not None:
-                        current_speaker, similarity = self.speaker_detector.add_embedding(embedding)
-
-                # Store sentence with speaker
-                with self.transcription_lock:
-                    self.full_sentences.append((text, current_speaker))
-                    self.update_conversation_display()
-
-            except queue.Empty:
-                continue
-            except Exception as e:
-                logger.error(f"Error processing sentence: {e}")
-
-    def update_conversation_display(self):
-        """Update the conversation display"""
-        try:
-            sentences_with_style = []
-
-            for sentence_text, speaker_id in self.full_sentences:
-                color = self.speaker_detector.get_color_for_speaker(speaker_id)
-                speaker_name = f"Speaker {speaker_id + 1}"
-                sentences_with_style.append(
-                    f'<span style="color:{color}; font-weight: bold;">{speaker_name}:</span> '
-                    f'<span style="color:#333333;">{sentence_text}</span>'
-                )
-
-            # Add current transcription if available
-            if self.last_transcription:
-                current_color = self.speaker_detector.get_color_for_speaker(self.speaker_detector.current_speaker)
-                current_speaker = f"Speaker {self.speaker_detector.current_speaker + 1}"
-                sentences_with_style.append(
-                    f'<span style="color:{current_color}; font-weight: bold; opacity: 0.7;">{current_speaker}:</span> '
-                    f'<span style="color:#666666; font-style: italic;">{self.last_transcription}...</span>'
-                )
-
-            if sentences_with_style:
-                self.current_conversation = "<br><br>".join(sentences_with_style)
-            else:
-                self.current_conversation = "<i>Waiting for speech input...</i>"
-
-        except Exception as e:
-            logger.error(f"Error updating conversation display: {e}")
-            self.current_conversation = f"<i>Error: {str(e)}</i>"
-
-    def start_recording(self):
-        """Start the recording and transcription process"""
-        if self.encoder is None:
-            return "Please initialize models first!"
-
-        try:
-            # Setup recorder configuration
-            recorder_config = {
-                'spinner': False,
-                'use_microphone': False,  # Using FastRTC for audio input
-                'model': FINAL_TRANSCRIPTION_MODEL,
-                'language': TRANSCRIPTION_LANGUAGE,
-                'silero_sensitivity': SILERO_SENSITIVITY,
-                'webrtc_sensitivity': WEBRTC_SENSITIVITY,
-                'post_speech_silence_duration': SILENCE_THRESHS[1],
-                'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
-                'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
-                'min_gap_between_recordings': 0,
-                'enable_realtime_transcription': True,
-                'realtime_processing_pause': 0.1,
-                'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
-                'on_realtime_transcription_update': self.live_text_detected,
-                'beam_size': FINAL_BEAM_SIZE,
-                'beam_size_realtime': REALTIME_BEAM_SIZE,
-                'sample_rate': SAMPLE_RATE,
-            }
-
-            self.recorder = AudioToTextRecorder(**recorder_config)
-
-            # Start processing threads
-            self.is_running = True
-            self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
-            self.sentence_thread.start()
-
-            self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
-            self.transcription_thread.start()
-
-            return "Recording started successfully!"
-
-        except Exception as e:
-            logger.error(f"Error starting recording: {e}")
-            return f"Error starting recording: {e}"
-
-    def run_transcription(self):
-        """Run the transcription loop"""
-        try:
-            logger.info("Starting transcription thread")
-            while self.is_running:
-                # Just check for final text from recorder, audio is fed externally via FastRTC
-                text = self.recorder.text(self.process_final_text)
-                time.sleep(0.01)  # Small sleep to prevent CPU hogging
-        except Exception as e:
-            logger.error(f"Transcription error: {e}")
-
-    def stop_recording(self):
-        """Stop the recording process"""
-        self.is_running = False
-        if self.recorder:
-            self.recorder.stop()
-        return "Recording stopped!"
-
-    def clear_conversation(self):
-        """Clear all conversation data"""
-        with self.transcription_lock:
-            self.full_sentences = []
-            self.last_transcription = ""
-            self.current_conversation = "Conversation cleared!"
-
-        if self.speaker_detector:
-            self.speaker_detector = SpeakerChangeDetector(
-                embedding_dim=self.encoder.embedding_dim,
-                change_threshold=self.change_threshold,
-                max_speakers=self.max_speakers
-            )
-
-        return "Conversation cleared!"
-
-    def update_settings(self, threshold, max_speakers):
-        """Update speaker detection settings"""
-        self.change_threshold = threshold
-        self.max_speakers = max_speakers
-
-        if self.speaker_detector:
-            self.speaker_detector.set_change_threshold(threshold)
-            self.speaker_detector.set_max_speakers(max_speakers)
-
-        return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"
-
-    def get_formatted_conversation(self):
-        """Get the formatted conversation"""
-        return self.current_conversation
-
-    def get_status_info(self):
-        """Get current status information"""
-        if not self.speaker_detector:
-            return "Speaker detector not initialized"
-
-        try:
-            status = self.speaker_detector.get_status_info()
-
-            status_lines = [
-                f"**Current Speaker:** {status['current_speaker'] + 1}",
-                f"**Active Speakers:** {status['active_speakers']} of {status['max_speakers']}",
-                f"**Last Similarity:** {status['last_similarity']:.3f}",
-                f"**Change Threshold:** {status['threshold']:.2f}",
-                f"**Total Sentences:** {len(self.full_sentences)}",
-                f"**Segments Processed:** {status['segment_counter']}",
-                "",
-                "**Speaker Activity:**"
-            ]
-
-            for i in range(status['max_speakers']):
-                color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
-                count = status['speaker_counts'][i]
-                active = "🟢" if count > 0 else "⚫"
-                status_lines.append(f"{active} Speaker {i+1} ({color_name}): {count} segments")
-
-            return "\n".join(status_lines)
-
-        except Exception as e:
-            return f"Error getting status: {e}"
-
-    def process_audio_chunk(self, audio_data, sample_rate=16000):
-        """Process audio chunk from FastRTC input"""
-        if not self.is_running or self.audio_processor is None:
-            return
-
-        try:
-            # Ensure audio is float32
-            if isinstance(audio_data, np.ndarray):
-                if audio_data.dtype != np.float32:
-                    audio_data = audio_data.astype(np.float32)
-            else:
-                audio_data = np.array(audio_data, dtype=np.float32)
-
-            # Ensure mono
-            if len(audio_data.shape) > 1:
-                audio_data = np.mean(audio_data, axis=1) if audio_data.shape[1] > 1 else audio_data.flatten()
-
-            # Normalize if needed
-            if np.max(np.abs(audio_data)) > 1.0:
-                audio_data = audio_data / np.max(np.abs(audio_data))
-
-            # Add to audio processor buffer for speaker detection
-            self.audio_processor.add_audio_chunk(audio_data)
-
-            # Periodically extract embeddings for speaker detection
-            if len(self.audio_processor.audio_buffer) % (SAMPLE_RATE // 2) == 0:  # Every 0.5 seconds
-                embedding = self.audio_processor.extract_embedding_from_buffer()
-                if embedding is not None:
-                    self.speaker_detector.add_embedding(embedding)
-
-            # Feed audio to RealtimeSTT recorder
-            if self.recorder and self.is_running:
-                # Convert float32 [-1.0, 1.0] to int16 for RealtimeSTT
-                int16_data = (audio_data * 32768.0).astype(np.int16).tobytes()
-                if sample_rate != 16000:
-                    int16_data = self.resample_audio(int16_data, sample_rate, 16000)
-                self.recorder.feed_audio(int16_data)
-
-        except Exception as e:
-            logger.error(f"Error processing audio chunk: {e}")
-
-    def resample_audio(self, audio_bytes, from_rate, to_rate):
-        """Resample audio to target sample rate"""
-        try:
-            audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
-            num_samples = len(audio_np)
-            num_target_samples = int(num_samples * to_rate / from_rate)
-
-            resampled = resample(audio_np, num_target_samples)
-
-            return resampled.astype(np.int16).tobytes()
-        except Exception as e:
-            logger.error(f"Error resampling audio: {e}")
-            return audio_bytes
-
-
-# FastRTC Audio Handler
-class DiarizationHandler(AsyncStreamHandler):
-    def __init__(self, diarization_system):
-        super().__init__()
-        self.diarization_system = diarization_system
-        self.audio_buffer = []
-        self.buffer_size = BUFFER_SIZE
-
-    def copy(self):
-        """Return a fresh handler for each new stream connection"""
-        return DiarizationHandler(self.diarization_system)
-
-    async def emit(self):
-        """Not used - we only receive audio"""
-        return None
-
-    async def receive(self, frame):
-        """Receive audio data from FastRTC"""
-        try:
-            if not self.diarization_system.is_running:
-                return
-
-            # Extract audio data
-            audio_data = getattr(frame, 'data', frame)
-
-            # Convert to numpy array
-            if isinstance(audio_data, bytes):
-                audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
-            elif isinstance(audio_data, (list, tuple)):
-                sample_rate, audio_array = audio_data
-                if isinstance(audio_array, (list, tuple)):
-                    audio_array = np.array(audio_array, dtype=np.float32)
-            else:
-                audio_array = np.array(audio_data, dtype=np.float32)
-
-            # Ensure 1D
-            if len(audio_array.shape) > 1:
-                audio_array = audio_array.flatten()
-
-            # Buffer audio chunks
-            self.audio_buffer.extend(audio_array)
-
-            # Process in chunks
-            while len(self.audio_buffer) >= self.buffer_size:
-                chunk = np.array(self.audio_buffer[:self.buffer_size])
-                self.audio_buffer = self.audio_buffer[self.buffer_size:]
-
-                # Process asynchronously
-                await self.process_audio_async(chunk)
-
-        except Exception as e:
-            logger.error(f"Error in FastRTC receive: {e}")
-
-    async def process_audio_async(self, audio_data):
-        """Process audio data asynchronously"""
-        try:
-            loop = asyncio.get_event_loop()
-            await loop.run_in_executor(
-                None,
-                self.diarization_system.process_audio_chunk,
-                audio_data,
-                SAMPLE_RATE
-            )
-        except Exception as e:
-            logger.error(f"Error in async audio processing: {e}")
-
-
-# Global instances
-diarization_system = RealtimeSpeakerDiarization()
-audio_handler = None
-
-def initialize_system():
-    """Initialize the diarization system"""
-    global audio_handler
-    try:
-        success = diarization_system.initialize_models()
-        if success:
-            audio_handler = DiarizationHandler(diarization_system)
-            return "✅ System initialized successfully!"
-        else:
-            return "❌ Failed to initialize system. Check logs for details."
-    except Exception as e:
-        logger.error(f"Initialization error: {e}")
-        return f"❌ Initialization error: {str(e)}"
-
-def start_recording():
-    """Start recording and transcription"""
-    try:
-        result = diarization_system.start_recording()
-        return f"🎙️ {result}"
-    except Exception as e:
-        return f"❌ Failed to start recording: {str(e)}"
-
-def stop_recording():
-    """Stop recording and transcription"""
-    try:
-        result = diarization_system.stop_recording()
-        return f"⏹️ {result}"
-    except Exception as e:
-        return f"❌ Failed to stop recording: {str(e)}"
-
-def clear_conversation():
-    """Clear the conversation"""
-    try:
-        result = diarization_system.clear_conversation()
-        return f"🗑️ {result}"
-    except Exception as e:
-        return f"❌ Failed to clear conversation: {str(e)}"
-
-def update_settings(threshold, max_speakers):
-    """Update system settings"""
-    try:
-        result = diarization_system.update_settings(threshold, max_speakers)
-        return f"⚙️ {result}"
-    except Exception as e:
-        return f"❌ Failed to update settings: {str(e)}"
-
-def get_conversation():
-    """Get the current conversation"""
-    try:
-        return diarization_system.get_formatted_conversation()
-    except Exception as e:
-        return f"<i>Error getting conversation: {str(e)}</i>"
-
-def get_status():
-    """Get system status"""
-    try:
-        return diarization_system.get_status_info()
-    except Exception as e:
-        return f"Error getting status: {str(e)}"
-
-# Create Gradio interface
-def create_interface():
-    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as interface:
-        gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
-        gr.Markdown("Live transcription with automatic speaker identification using FastRTC audio streaming.")
-
-        with gr.Row():
-            with gr.Column(scale=2):
-                # Conversation display
-                conversation_output = gr.HTML(
-                    value="<div style='padding: 20px; background: #f8f9fa; border-radius: 10px; min-height: 300px;'><i>Click 'Initialize System' to start...</i></div>",
-                    label="Live Conversation"
-                )
-
-                # Control buttons
-                with gr.Row():
-                    init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
-                    start_btn = gr.Button("🎙️ Start", variant="primary", size="lg", interactive=False)
-                    stop_btn = gr.Button("⏹️ Stop", variant="stop", size="lg", interactive=False)
-                    clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg", interactive=False)
-
-                # Status display
-                status_output = gr.Textbox(
-                    label="System Status",
-                    value="Ready to initialize...",
-                    lines=8,
-                    interactive=False
-                )
-
-            with gr.Column(scale=1):
-                # Settings
-                gr.Markdown("## ⚙️ Settings")
-
-                threshold_slider = gr.Slider(
-                    minimum=0.3,
-                    maximum=0.9,
-                    step=0.05,
-                    value=DEFAULT_CHANGE_THRESHOLD,
-                    label="Speaker Change Sensitivity",
-                    info="Lower = more sensitive"
-                )
-
-                max_speakers_slider = gr.Slider(
-                    minimum=2,
-                    maximum=ABSOLUTE_MAX_SPEAKERS,
-                    step=1,
-                    value=DEFAULT_MAX_SPEAKERS,
-                    label="Maximum Speakers"
-                )
-
-                update_btn = gr.Button("Update Settings", variant="secondary")
-
-                # Instructions
-                gr.Markdown("""
-                ## 📋 Instructions
-                1. **Initialize** the system (loads AI models)
-                2. **Start** recording
-                3. **Speak** - system will transcribe and identify speakers
-                4. **Monitor** real-time results below
-
-                ## 🎨 Speaker Colors
-                - 🔴 Speaker 1 (Red)
-                - 🟢 Speaker 2 (Teal)
-                - 🔵 Speaker 3 (Blue)
-                - 🟡 Speaker 4 (Green)
-                - 🟣 Speaker 5 (Yellow)
-                - 🟤 Speaker 6 (Plum)
-                - 🟫 Speaker 7 (Mint)
-                - 🟨 Speaker 8 (Gold)
-                """)
-
-        # Event handlers
-        def on_initialize():
-            result = initialize_system()
-            if "✅" in result:
-                return result, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
-            else:
-                return result, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)
-
-        def on_start():
-            result = start_recording()
-            return result, gr.update(interactive=False), gr.update(interactive=True)
-
-        def on_stop():
-            result = stop_recording()
-            return result, gr.update(interactive=True), gr.update(interactive=False)
-
-        def on_clear():
-            result = clear_conversation()
-            return result
-
-        def on_update_settings(threshold, max_speakers):
-            result = update_settings(threshold, int(max_speakers))
-            return result
-
-        def refresh_conversation():
-            return get_conversation()
-
-        def refresh_status():
-            return get_status()
-
-        # Button click handlers
-        init_btn.click(
-            fn=on_initialize,
-            outputs=[status_output, start_btn, stop_btn, clear_btn]
-        )
-
-        start_btn.click(
-            fn=on_start,
-            outputs=[status_output, start_btn, stop_btn]
-        )
-
-        stop_btn.click(
-            fn=on_stop,
-            outputs=[status_output, start_btn, stop_btn]
-        )
-
-        clear_btn.click(
-            fn=on_clear,
-            outputs=[status_output]
-        )
-
-        update_btn.click(
-            fn=on_update_settings,
-            inputs=[threshold_slider, max_speakers_slider],
-            outputs=[status_output]
-        )
-
-        # Auto-refresh conversation display every 1 second
-        conversation_timer = gr.Timer(1)
-        conversation_timer.tick(refresh_conversation, outputs=[conversation_output])
-
-        # Auto-refresh status every 2 seconds
-        status_timer = gr.Timer(2)
-        status_timer.tick(refresh_status, outputs=[status_output])
-
-    return interface
-
-
-# FastAPI setup for FastRTC integration
 app = FastAPI()
 
 @app.get("/")
-async def root():
-    return {"message": "Real-time Speaker Diarization API"}
-
-@app.get("/health")
-async def health_check():
-    return {"status": "healthy", "system_running": diarization_system.is_running}
-
-@app.post("/initialize")
-async def api_initialize():
-    result = initialize_system()
-    return {"result": result, "success": "✅" in result}
-
-@app.post("/start")
-async def api_start():
-    result = start_recording()
-    return {"result": result, "success": "🎙️" in result}
-
-@app.post("/stop")
-async def api_stop():
-    result = stop_recording()
-    return {"result": result, "success": "⏹️" in result}
-
-@app.post("/clear")
-async def api_clear():
-    result = clear_conversation()
-    return {"result": result}
-
-@app.get("/conversation")
-async def api_get_conversation():
-    return {"conversation": get_conversation()}
-
-@app.get("/status")
-async def api_get_status():
-    return {"status": get_status()}
-
-@app.post("/settings")
-async def api_update_settings(threshold: float, max_speakers: int):
-    result = update_settings(threshold, max_speakers)
-    return {"result": result}
-
-# FastRTC Stream setup
-if audio_handler:
-    stream = Stream(handler=audio_handler)
-    app.include_router(stream.router, prefix="/stream")
-
-
-# Main execution
-if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser(description="Real-time Speaker Diarization System")
-    parser.add_argument("--mode", choices=["gradio", "api", "both"], default="gradio",
-                        help="Run mode: gradio interface, API only, or both")
-    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
-    parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
-    parser.add_argument("--api-port", type=int, default=8000, help="API port (when running both)")
-
-    args = parser.parse_args()
-
-    if args.mode == "gradio":
-        # Run Gradio interface only
-        interface = create_interface()
-        interface.launch(
-            server_name=args.host,
-            server_port=args.port,
-            share=True,
-            show_error=True
-        )
-
-    elif args.mode == "api":
-        # Run FastAPI only
-        uvicorn.run(
-            app,
-            host=args.host,
-            port=args.port,
-            log_level="info"
-        )
-
-    elif args.mode == "both":
-        # Run both Gradio and FastAPI
-        import multiprocessing
-        import threading
-
-        def run_gradio():
-            interface = create_interface()
-            interface.launch(
-                server_name=args.host,
-                server_port=args.port,
-                share=True,
-                show_error=True
-            )
-
-        def run_fastapi():
-            uvicorn.run(
-                app,
-                host=args.host,
-                port=args.api_port,
-                log_level="info"
-            )
-
-        # Start FastAPI in a separate thread
-        api_thread = threading.Thread(target=run_fastapi, daemon=True)
-        api_thread.start()
-
-        # Start Gradio in main thread
-        run_gradio()
+def greet_json():
+    return {"Hello": "World!"}
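This commit pares app.py down to a minimal FastAPI stub. For reference, a sketch of serving the new stub locally, assuming uvicorn is installed alongside fastapi (on Hugging Face Spaces the runtime launches the app itself, so this helper file is hypothetical):

    # run_local.py - hypothetical helper, not part of this commit
    import uvicorn

    if __name__ == "__main__":
        # Serve the `app` object from app.py on the port the removed Gradio UI used (7860)
        uvicorn.run("app:app", host="0.0.0.0", port=7860)

A GET on / should then return {"Hello": "World!"} from greet_json.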