import gradio as gr
import numpy as np
import torch
import torchaudio
import time
import os
import queue
from scipy.spatial.distance import cosine
from typing import Tuple
# Configuration parameters (keeping original models)
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
FINAL_BEAM_SIZE = 5
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
REALTIME_BEAM_SIZE = 5
TRANSCRIPTION_LANGUAGE = "en"
SILERO_SENSITIVITY = 0.4
WEBRTC_SENSITIVITY = 3
MIN_LENGTH_OF_RECORDING = 0.7
PRE_RECORDING_BUFFER_DURATION = 0.35

# Speaker change detection parameters
DEFAULT_CHANGE_THRESHOLD = 0.7
EMBEDDING_HISTORY_SIZE = 5
MIN_SEGMENT_DURATION = 1.0
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 10
SAMPLE_RATE = 16000

# Speaker labels
SPEAKER_LABELS = [f"Speaker {i + 1}" for i in range(ABSOLUTE_MAX_SPEAKERS)]
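
# Note: the transcription/VAD constants above (FINAL_*, REALTIME_*, SILERO_*,
# WEBRTC_*, MIN_LENGTH_OF_RECORDING, PRE_RECORDING_BUFFER_DURATION) are kept
# for parity with the original real-time ASR setup; only the speaker change
# detection parameters below them are used directly in this app.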

class SpeechBrainEncoder:
    """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""

    def __init__(self, device="cpu"):
        self.device = device
        self.model = None
        self.embedding_dim = 192  # ECAPA-TDNN output dimension
        self.model_loaded = False
        self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
        os.makedirs(self.cache_dir, exist_ok=True)

    def load_model(self):
        """Load the ECAPA-TDNN model"""
        try:
            try:
                # SpeechBrain >= 1.0 moved the pretrained interfaces
                from speechbrain.inference import EncoderClassifier
            except ImportError:
                from speechbrain.pretrained import EncoderClassifier
            self.model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir=self.cache_dir,
                run_opts={"device": self.device},
            )
            self.model_loaded = True
            return True
        except Exception as e:
            print(f"Error loading ECAPA-TDNN model: {e}")
            return False

    def embed_utterance(self, audio, sr=16000):
        """Extract a speaker embedding from audio"""
        if not self.model_loaded:
            raise ValueError("Model not loaded. Call load_model() first.")
        try:
            if isinstance(audio, np.ndarray):
                waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
            else:
                waveform = audio.unsqueeze(0)
            if sr != 16000:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
            with torch.no_grad():
                embedding = self.model.encode_batch(waveform)
            return embedding.squeeze().cpu().numpy()
        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)
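
# Speaker change detection works on cosine similarity in embedding space:
# each incoming embedding is compared against the running mean embedding of
# the current speaker. If the similarity drops below the change threshold
# (and the current segment is at least MIN_SEGMENT_DURATION long), the
# embedding is reassigned to the best-matching known speaker, or to a fresh
# slot if one is still available under max_speakers.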

class SpeakerChangeDetector:
    """Speaker change detector that supports a configurable number of speakers"""

    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.embedding_dim = embedding_dim
        self.change_threshold = change_threshold
        self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        self.current_speaker = 0
        self.previous_embeddings = []
        self.last_change_time = time.time()
        self.mean_embeddings = [None] * self.max_speakers
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.last_similarity = 0.0
        self.active_speakers = {0}

    def set_max_speakers(self, max_speakers):
        """Update the maximum number of speakers"""
        new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        if new_max < self.max_speakers:
            # Drop speakers that no longer fit under the new limit
            for speaker_id in list(self.active_speakers):
                if speaker_id >= new_max:
                    self.active_speakers.discard(speaker_id)
            if self.current_speaker >= new_max:
                self.current_speaker = 0
        if new_max > self.max_speakers:
            self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
            self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
        else:
            self.mean_embeddings = self.mean_embeddings[:new_max]
            self.speaker_embeddings = self.speaker_embeddings[:new_max]
        self.max_speakers = new_max

    def set_change_threshold(self, threshold):
        """Update the threshold for detecting speaker changes"""
        self.change_threshold = max(0.1, min(threshold, 0.99))

    def add_embedding(self, embedding, timestamp=None):
        """Add a new embedding and check whether the speaker has changed"""
        current_time = timestamp or time.time()
        # First embedding: assign it to speaker 0
        if not self.previous_embeddings:
            self.previous_embeddings.append(embedding)
            self.speaker_embeddings[self.current_speaker].append(embedding)
            if self.mean_embeddings[self.current_speaker] is None:
                self.mean_embeddings[self.current_speaker] = embedding.copy()
            return self.current_speaker, 1.0
        # Compare against the current speaker's running mean (fall back to the
        # most recent embedding if no mean is available yet)
        current_mean = self.mean_embeddings[self.current_speaker]
        if current_mean is not None:
            similarity = 1.0 - cosine(embedding, current_mean)
        else:
            similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
        self.last_similarity = similarity
        time_since_last_change = current_time - self.last_change_time
        is_speaker_change = False
        if time_since_last_change >= MIN_SEGMENT_DURATION:
            if similarity < self.change_threshold:
                # Look for a better-matching known speaker
                best_speaker = self.current_speaker
                best_similarity = similarity
                for speaker_id in range(self.max_speakers):
                    if speaker_id == self.current_speaker:
                        continue
                    speaker_mean = self.mean_embeddings[speaker_id]
                    if speaker_mean is not None:
                        speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
                        if speaker_similarity > best_similarity:
                            best_similarity = speaker_similarity
                            best_speaker = speaker_id
                if best_speaker != self.current_speaker:
                    is_speaker_change = True
                    self.current_speaker = best_speaker
                elif len(self.active_speakers) < self.max_speakers:
                    # No known speaker matches well; open a new slot
                    for new_id in range(self.max_speakers):
                        if new_id not in self.active_speakers:
                            is_speaker_change = True
                            self.current_speaker = new_id
                            self.active_speakers.add(new_id)
                            break
        if is_speaker_change:
            self.last_change_time = current_time
        # Update the history and the current speaker's running mean
        self.previous_embeddings.append(embedding)
        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
            self.previous_embeddings.pop(0)
        self.speaker_embeddings[self.current_speaker].append(embedding)
        self.active_speakers.add(self.current_speaker)
        if len(self.speaker_embeddings[self.current_speaker]) > 30:
            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]
        if self.speaker_embeddings[self.current_speaker]:
            self.mean_embeddings[self.current_speaker] = np.mean(
                self.speaker_embeddings[self.current_speaker], axis=0
            )
        return self.current_speaker, similarity
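
# Minimal usage sketch for the detector (illustrative only; in this app the
# embeddings come from SpeechBrainEncoder, not random vectors):
#
#   detector = SpeakerChangeDetector(max_speakers=2)
#   rng = np.random.default_rng(0)
#   emb_a, emb_b = rng.normal(size=192), rng.normal(size=192)
#   detector.add_embedding(emb_a)                                  # -> (0, 1.0)
#   speaker, sim = detector.add_embedding(emb_b, timestamp=time.time() + 2)
#   # speaker flips to 1 when sim < change_threshold and a slot is free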

class AudioProcessor:
    """Processes audio data to extract speaker embeddings"""

    def __init__(self, encoder):
        self.encoder = encoder

    def extract_embedding(self, audio_data):
        try:
            # Ensure audio is float32
            if audio_data.dtype != np.float32:
                audio_data = audio_data.astype(np.float32)
            # Normalize to [-1, 1] if needed
            if np.abs(audio_data).max() > 1.0:
                audio_data = audio_data / np.abs(audio_data).max()
            # Extract embedding using the loaded encoder
            embedding = self.encoder.embed_utterance(audio_data)
            return embedding
        except Exception as e:
            print(f"Embedding extraction error: {e}")
            return np.zeros(self.encoder.embedding_dim)

class RealTimeSpeakerDiarization:
    """Main class for real-time speaker diarization"""

    def __init__(self, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.encoder = None
        self.audio_processor = None
        self.speaker_detector = None
        self.change_threshold = change_threshold
        self.max_speakers = max_speakers
        self.transcript_history = []
        self.is_initialized = False
        # Threading components
        self.audio_queue = queue.Queue()
        self.processing_thread = None
        self.running = False

    async def initialize(self):
        """Initialize the speaker diarization system"""
        if self.is_initialized:
            return True
        try:
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Initializing ECAPA-TDNN model on {device_str}...")
            self.encoder = SpeechBrainEncoder(device=device_str)
            success = self.encoder.load_model()
            if not success:
                return False
            self.audio_processor = AudioProcessor(self.encoder)
            self.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers,
            )
            self.is_initialized = True
            print("Speaker diarization system initialized successfully!")
            return True
        except Exception as e:
            print(f"Initialization error: {e}")
            return False

    def update_settings(self, change_threshold, max_speakers):
        """Update diarization settings"""
        self.change_threshold = change_threshold
        self.max_speakers = max_speakers
        if self.speaker_detector:
            self.speaker_detector.set_change_threshold(change_threshold)
            self.speaker_detector.set_max_speakers(max_speakers)

    def process_audio_segment(self, audio_data: np.ndarray, text: str) -> Tuple[int, str]:
        """Process an audio segment and return the speaker ID and formatted text"""
        if not self.is_initialized:
            return 0, text
        try:
            # Extract speaker embedding
            embedding = self.audio_processor.extract_embedding(audio_data)
            # Detect speaker
            speaker_id, similarity = self.speaker_detector.add_embedding(embedding)
            # Format text with speaker label
            speaker_label = SPEAKER_LABELS[speaker_id]
            formatted_text = f"{speaker_label}: {text}"
            return speaker_id, formatted_text
        except Exception as e:
            print(f"Error processing audio segment: {e}")
            return 0, f"Speaker 1: {text}"

    def get_transcript_history(self):
        """Get the formatted transcript history"""
        return "\n".join(self.transcript_history)

    def add_to_transcript(self, formatted_text: str):
        """Add formatted text to the transcript history"""
        self.transcript_history.append(formatted_text)
        # Keep only the last 50 entries to prevent memory issues
        if len(self.transcript_history) > 50:
            self.transcript_history = self.transcript_history[-50:]

    def clear_transcript(self):
        """Clear the transcript history and reset the speaker detector"""
        self.transcript_history = []
        if self.speaker_detector:
            self.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers,
            )
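
# Headless usage sketch (illustrative; `audio_chunk` is an assumed 16 kHz mono
# float32 numpy array and the text an assumed transcript for that chunk):
#
#   import asyncio
#   system = RealTimeSpeakerDiarization()
#   asyncio.run(system.initialize())
#   speaker_id, line = system.process_audio_segment(audio_chunk, "hello there")
#   system.add_to_transcript(line)
#   print(system.get_transcript_history())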

# Global instance
diarization_system = RealTimeSpeakerDiarization()


async def initialize_system():
    """Initialize the diarization system"""
    success = await diarization_system.initialize()
    if success:
        return "✅ Speaker diarization system initialized successfully!"
    else:
        return "❌ Failed to initialize speaker diarization system. Please check your setup."

def process_audio_with_transcript(audio_data, sample_rate, transcription_text, change_threshold, max_speakers):
    """Process audio with its transcription for speaker diarization"""
    if not diarization_system.is_initialized:
        return "Please initialize the system first.", ""
    if audio_data is None or transcription_text.strip() == "":
        return diarization_system.get_transcript_history(), ""
    try:
        # Update settings
        diarization_system.update_settings(change_threshold, max_speakers)
        # gr.Audio(type="numpy") yields a (sample_rate, data) tuple; unpack it
        if isinstance(audio_data, tuple):
            sample_rate, audio_data = audio_data
        # Convert to float32 mono
        audio_data = np.asarray(audio_data, dtype=np.float32)
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        # Resample if needed
        if sample_rate != SAMPLE_RATE:
            audio_data = torchaudio.functional.resample(
                torch.tensor(audio_data), sample_rate, SAMPLE_RATE
            ).numpy()
        # Process the audio segment
        speaker_id, formatted_text = diarization_system.process_audio_segment(audio_data, transcription_text)
        # Add to transcript
        diarization_system.add_to_transcript(formatted_text)
        # Return the updated transcript and current speaker info
        transcript = diarization_system.get_transcript_history()
        current_speaker_info = f"Current Speaker: {SPEAKER_LABELS[speaker_id]}"
        return transcript, current_speaker_info
    except Exception as e:
        error_msg = f"Error processing audio: {str(e)}"
        return diarization_system.get_transcript_history(), error_msg

def clear_conversation():
    """Clear the conversation transcript"""
    diarization_system.clear_transcript()
    return "", "Conversation cleared."

def create_gradio_interface():
    """Create and return the Gradio interface"""
    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎙️ Real-time Speaker Diarization with ASR")
        gr.Markdown("Upload audio with its transcription to perform real-time speaker diarization.")

        # Initialization section
        with gr.Row():
            init_btn = gr.Button("🚀 Initialize System", variant="primary")
            init_status = gr.Textbox(label="Initialization Status", interactive=False)

        # Settings section
        with gr.Row():
            with gr.Column():
                change_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    step=0.05,
                    label="Speaker Change Threshold",
                    info="Higher values = more sensitive to speaker changes"
                )
            with gr.Column():
                max_speakers = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    value=DEFAULT_MAX_SPEAKERS,
                    step=1,
                    label="Maximum Number of Speakers",
                    info="Maximum number of speakers to detect"
                )

        # Audio input and transcription
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    label="Audio Input",
                    type="numpy",
                    format="wav"
                )
                transcription_input = gr.Textbox(
                    label="Transcription Text",
                    placeholder="Enter the transcription of the audio...",
                    lines=3
                )
                process_btn = gr.Button("🎯 Process Audio", variant="secondary")
            with gr.Column():
                current_speaker = gr.Textbox(
                    label="Current Speaker",
                    interactive=False
                )
                clear_btn = gr.Button("🗑️ Clear Conversation", variant="stop")

        # Output section
        transcript_output = gr.Textbox(
            label="Live Transcript with Speaker Labels",
            lines=15,
            max_lines=20,
            interactive=False,
            placeholder="Processed transcript will appear here..."
        )

        # Hidden fallback sample rate for the processing handler
        sample_rate_input = gr.Number(value=SAMPLE_RATE, visible=False)

        # Event handlers
        init_btn.click(
            fn=initialize_system,
            outputs=[init_status]
        )
        process_btn.click(
            fn=process_audio_with_transcript,
            inputs=[
                audio_input,
                sample_rate_input,
                transcription_input,
                change_threshold,
                max_speakers
            ],
            outputs=[transcript_output, current_speaker]
        )
        clear_btn.click(
            fn=clear_conversation,
            outputs=[transcript_output, current_speaker]
        )
        # Auto-process when the audio input changes
        audio_input.change(
            fn=process_audio_with_transcript,
            inputs=[
                audio_input,
                sample_rate_input,
                transcription_input,
                change_threshold,
                max_speakers
            ],
            outputs=[transcript_output, current_speaker]
        )

        # Instructions
        gr.Markdown("""
        ## Instructions:
        1. **Initialize**: Click "Initialize System" to load the speaker diarization models
        2. **Upload Audio**: Upload an audio file (WAV format recommended)
        3. **Add Transcription**: Enter the transcription text for the audio
        4. **Adjust Settings**:
           - **Speaker Change Threshold**: Higher values detect speaker changes more easily
           - **Max Speakers**: Set the maximum number of speakers you expect
        5. **Process**: Click "Process Audio", or let the system auto-process on upload
        6. **View Results**: See the transcript with speaker labels (Speaker 1, Speaker 2, etc.)

        ## Tips:
        - For similar-sounding speakers, raise the threshold (0.6-0.8) so changes are still caught
        - For clearly different-sounding speakers, a lower threshold (0.3-0.5) avoids false changes
        - The system maintains speaker consistency across the conversation
        - Use "Clear Conversation" to reset the speaker memory
        """)
    return demo

if __name__ == "__main__":
    # Create and launch the Gradio interface
    demo = create_gradio_interface()
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )