Saiyaswanth007's picture
Fixing gradio RealStream
b9dea2c
raw
history blame
19.6 kB
import gradio as gr
import numpy as np
import torch
import torchaudio
import time
import os
import urllib.request
from scipy.spatial.distance import cosine
import threading
import queue
from collections import deque
import asyncio
from typing import Generator, Tuple, List, Optional
# ---- Configuration parameters (keeping original models) ----
# Transcription settings. NOTE(review): none of these is referenced
# elsewhere in this file.
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
FINAL_BEAM_SIZE = 5
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
REALTIME_BEAM_SIZE = 5
TRANSCRIPTION_LANGUAGE = "en"
# Recording settings (also not referenced elsewhere in this file).
SILERO_SENSITIVITY = 0.4
WEBRTC_SENSITIVITY = 3
MIN_LENGTH_OF_RECORDING = 0.7
PRE_RECORDING_BUFFER_DURATION = 0.35

# ---- Speaker change detection parameters ----
DEFAULT_CHANGE_THRESHOLD = 0.7   # similarity below this may trigger a speaker change
EMBEDDING_HISTORY_SIZE = 5       # number of recent embeddings the detector keeps
MIN_SEGMENT_DURATION = 1.0       # minimum seconds between two speaker changes
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 10       # hard upper bound on the configurable speaker count
SAMPLE_RATE = 16000              # Hz; incoming audio is resampled to this rate

# Display labels indexed by speaker id: "Speaker 1" ... "Speaker N".
SPEAKER_LABELS = ["Speaker {}".format(n) for n in range(1, ABSOLUTE_MAX_SPEAKERS + 1)]
class SpeechBrainEncoder:
    """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings.

    Wraps the ``speechbrain/spkrec-ecapa-voxceleb`` pretrained model. Call
    :meth:`load_model` (and check it returned True) before
    :meth:`embed_utterance`.
    """

    def __init__(self, device="cpu"):
        self.device = device
        self.model = None
        # Length of the zero vector returned when embedding extraction fails.
        self.embedding_dim = 192
        self.model_loaded = False
        # Pretrained model files are saved under ~/.cache/speechbrain.
        self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
        os.makedirs(self.cache_dir, exist_ok=True)

    @staticmethod
    def _import_encoder_classifier():
        """Return SpeechBrain's ``EncoderClassifier`` class across versions.

        SpeechBrain 1.0 moved the pretrained interfaces from
        ``speechbrain.pretrained`` to ``speechbrain.inference``; the old module
        path is kept only as a deprecated alias, so the new one is tried first.

        Raises:
            ImportError: If SpeechBrain is not installed.
        """
        try:
            from speechbrain.inference.speaker import EncoderClassifier
        except ImportError:
            from speechbrain.pretrained import EncoderClassifier
        return EncoderClassifier

    def load_model(self):
        """Load the ECAPA-TDNN model.

        Returns:
            True on success. False (after printing the error) on any failure,
            including SpeechBrain being unavailable or the model fetch failing.
        """
        try:
            EncoderClassifier = self._import_encoder_classifier()
            self.model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir=self.cache_dir,
                run_opts={"device": self.device},
            )
            self.model_loaded = True
            return True
        except Exception as e:
            print(f"Error loading ECAPA-TDNN model: {e}")
            return False

    def embed_utterance(self, audio, sr=16000):
        """Extract a speaker embedding from a waveform.

        Args:
            audio: Waveform as a NumPy array or torch tensor (assumed 1-D mono;
                a batch dimension is added here).
            sr: Sample rate of *audio*; resampled to 16 kHz when different.

        Returns:
            The squeezed embedding as a NumPy array, or
            ``np.zeros(self.embedding_dim)`` if extraction fails.

        Raises:
            ValueError: If the model has not been loaded.
        """
        if not self.model_loaded:
            raise ValueError("Model not loaded. Call load_model() first.")
        try:
            if isinstance(audio, np.ndarray):
                waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
            else:
                # Apply the same float32 conversion as for NumPy input.
                waveform = audio.to(torch.float32).unsqueeze(0)
            if sr != 16000:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
            with torch.no_grad():
                embedding = self.model.encode_batch(waveform)
            return embedding.squeeze().cpu().numpy()
        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)
class SpeakerChangeDetector:
    """Speaker change detector that supports a configurable number of speakers.

    Embeddings arrive one at a time via :meth:`add_embedding`. Each is compared
    (cosine similarity) with the running mean embedding of the current speaker.
    A change is considered only when that similarity falls below
    ``change_threshold`` and at least ``MIN_SEGMENT_DURATION`` seconds have
    passed since the last change. The embedding is then assigned to the most
    similar other known speaker, or to a fresh speaker slot if no known
    speaker matches better and fewer than ``max_speakers`` are active.
    """

    # Per-speaker cap on retained embeddings used to compute the mean.
    MAX_EMBEDDINGS_PER_SPEAKER = 30

    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.embedding_dim = embedding_dim
        self.change_threshold = change_threshold
        self.max_speakers = self._clamp_max_speakers(max_speakers)
        self.current_speaker = 0
        # Most recent embeddings of any speaker, capped at EMBEDDING_HISTORY_SIZE.
        self.previous_embeddings = []
        self.last_change_time = time.time()
        # Per-speaker mean embedding (None until the speaker has one).
        self.mean_embeddings = [None] * self.max_speakers
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.last_similarity = 0.0
        self.active_speakers = {0}

    @staticmethod
    def _clamp_max_speakers(max_speakers):
        """Coerce *max_speakers* to an int in ``[1, ABSOLUTE_MAX_SPEAKERS]``.

        The value sizes lists and drives ``range()``, both of which need an
        int; callers (e.g. UI sliders) may pass a float such as ``4.0``.
        """
        return max(1, min(int(max_speakers), ABSOLUTE_MAX_SPEAKERS))

    def set_max_speakers(self, max_speakers):
        """Update the maximum number of speakers.

        Shrinking forgets speakers whose ids no longer fit and resets the
        current speaker to 0 if it was one of them. Growing adds empty slots.
        """
        new_max = self._clamp_max_speakers(max_speakers)
        if new_max < self.max_speakers:
            for speaker_id in list(self.active_speakers):
                if speaker_id >= new_max:
                    self.active_speakers.discard(speaker_id)
            if self.current_speaker >= new_max:
                self.current_speaker = 0
        if new_max > self.max_speakers:
            extra = new_max - self.max_speakers
            self.mean_embeddings.extend([None] * extra)
            self.speaker_embeddings.extend([[] for _ in range(extra)])
        else:
            self.mean_embeddings = self.mean_embeddings[:new_max]
            self.speaker_embeddings = self.speaker_embeddings[:new_max]
        self.max_speakers = new_max

    def set_change_threshold(self, threshold):
        """Update the change threshold, clamped to ``[0.1, 0.99]``."""
        self.change_threshold = max(0.1, min(threshold, 0.99))

    def add_embedding(self, embedding, timestamp=None):
        """Add a new embedding and check whether the speaker changed.

        Args:
            embedding: Speaker embedding vector (NumPy array).
            timestamp: Time of the embedding in seconds. It is compared against
                ``last_change_time``, which starts as ``time.time()``, so it
                should use the same clock. Defaults to ``time.time()``.

        Returns:
            ``(speaker_id, similarity)``. ``similarity`` is the cosine
            similarity to the current speaker's reference embedding. It is 1.0
            for the very first embedding and 0.0 for a skipped degenerate one.
        """
        # `timestamp or time.time()` would silently discard a timestamp of 0.
        current_time = time.time() if timestamp is None else timestamp

        # An all-zero or non-finite vector (e.g. the zero fallback produced
        # when embedding extraction fails) has no direction, so its cosine
        # similarity is undefined (NaN). If stored, it could become a speaker's
        # mean and make every later comparison NaN, which blocks all speaker
        # changes. Ignore it and leave the state unchanged.
        if not np.all(np.isfinite(embedding)) or not np.any(embedding):
            return self.current_speaker, 0.0

        if not self.previous_embeddings:
            # First embedding: seed the current speaker's history and mean.
            self.previous_embeddings.append(embedding)
            self.speaker_embeddings[self.current_speaker].append(embedding)
            if self.mean_embeddings[self.current_speaker] is None:
                self.mean_embeddings[self.current_speaker] = embedding.copy()
            return self.current_speaker, 1.0

        current_mean = self.mean_embeddings[self.current_speaker]
        # Fall back to the latest embedding if the current speaker has no mean.
        reference = current_mean if current_mean is not None else self.previous_embeddings[-1]
        similarity = 1.0 - cosine(embedding, reference)
        self.last_similarity = similarity

        is_speaker_change = False
        segment_long_enough = current_time - self.last_change_time >= MIN_SEGMENT_DURATION
        if segment_long_enough and similarity < self.change_threshold:
            best_speaker = self.current_speaker
            best_similarity = similarity
            for speaker_id in range(self.max_speakers):
                speaker_mean = self.mean_embeddings[speaker_id]
                if speaker_id == self.current_speaker or speaker_mean is None:
                    continue
                speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
                if speaker_similarity > best_similarity:
                    best_similarity = speaker_similarity
                    best_speaker = speaker_id
            if best_speaker != self.current_speaker:
                is_speaker_change = True
                self.current_speaker = best_speaker
            elif len(self.active_speakers) < self.max_speakers:
                # No known speaker fits better: open the lowest free slot.
                new_id = next(
                    (i for i in range(self.max_speakers) if i not in self.active_speakers),
                    None,
                )
                if new_id is not None:
                    is_speaker_change = True
                    self.current_speaker = new_id
                    self.active_speakers.add(new_id)

        if is_speaker_change:
            self.last_change_time = current_time

        self.previous_embeddings.append(embedding)
        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
            self.previous_embeddings.pop(0)

        history = self.speaker_embeddings[self.current_speaker]
        history.append(embedding)
        self.active_speakers.add(self.current_speaker)
        if len(history) > self.MAX_EMBEDDINGS_PER_SPEAKER:
            history = history[-self.MAX_EMBEDDINGS_PER_SPEAKER:]
            self.speaker_embeddings[self.current_speaker] = history
        # Re-estimate the speaker's mean from its retained embeddings.
        self.mean_embeddings[self.current_speaker] = np.mean(history, axis=0)
        return self.current_speaker, similarity
class AudioProcessor:
    """Turn raw audio arrays into speaker embeddings.

    The wrapped encoder must provide ``embed_utterance(audio)`` and an
    ``embedding_dim`` attribute (used to size the fallback vector).
    """

    def __init__(self, encoder):
        self.encoder = encoder

    def extract_embedding(self, audio_data):
        """Return the encoder's embedding for *audio_data*.

        Input is cast to float32. If its peak magnitude exceeds 1.0 it is
        divided by that peak. Any error is printed and answered with a zero
        vector of length ``encoder.embedding_dim``.
        """
        try:
            samples = audio_data if audio_data.dtype == np.float32 else audio_data.astype(np.float32)
            peak = np.abs(samples).max()
            if peak > 1.0:
                # Peak-normalize out-of-range input into [-1, 1].
                samples = samples / peak
            return self.encoder.embed_utterance(samples)
        except Exception as exc:
            print(f"Embedding extraction error: {exc}")
            return np.zeros(self.encoder.embedding_dim)
class RealTimeSpeakerDiarization:
    """Main class for real-time speaker diarization.

    Ties together the ECAPA-TDNN encoder, the audio preprocessor and the
    speaker-change detector. It also keeps a bounded history of
    speaker-labelled transcript lines.
    """

    # Number of most recent transcript lines kept in memory.
    MAX_TRANSCRIPT_ENTRIES = 50

    def __init__(self, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.encoder = None
        self.audio_processor = None
        self.speaker_detector = None
        self.change_threshold = change_threshold
        self.max_speakers = max_speakers
        self.transcript_history = []
        self.is_initialized = False
        # Threading components (not used by the methods of this class).
        self.audio_queue = queue.Queue()
        self.processing_thread = None
        self.running = False
        # Serializes model loading, which runs on an executor thread.
        self._init_lock = threading.Lock()

    async def initialize(self):
        """Initialize the speaker diarization system.

        Loading the encoder is blocking (it may download model files). It is
        therefore run in the event loop's default executor, so the coroutine
        does not stall the loop that serves other requests.

        Returns:
            True on success or if already initialized, False on failure.
        """
        if self.is_initialized:
            return True
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._initialize_blocking)

    def _initialize_blocking(self):
        """Build encoder, preprocessor and detector; runs on a worker thread."""
        with self._init_lock:
            # A concurrent caller may have finished while we waited for the lock.
            if self.is_initialized:
                return True
            try:
                device_str = "cuda" if torch.cuda.is_available() else "cpu"
                print(f"Initializing ECAPA-TDNN model on {device_str}...")
                self.encoder = SpeechBrainEncoder(device=device_str)
                success = self.encoder.load_model()
                if not success:
                    return False
                self.audio_processor = AudioProcessor(self.encoder)
                self.speaker_detector = SpeakerChangeDetector(
                    embedding_dim=self.encoder.embedding_dim,
                    change_threshold=self.change_threshold,
                    max_speakers=self.max_speakers,
                )
                self.is_initialized = True
                print("Speaker diarization system initialized successfully!")
                return True
            except Exception as e:
                print(f"Initialization error: {e}")
                return False

    def update_settings(self, change_threshold, max_speakers):
        """Update diarization settings and forward them to the detector.

        ``max_speakers`` is stored as an int because it is later passed to
        ``SpeakerChangeDetector(...)`` by :meth:`clear_transcript`, which uses
        it for list sizing. UI sliders may deliver it as a float such as 4.0.
        """
        self.change_threshold = float(change_threshold)
        self.max_speakers = int(max_speakers)
        if self.speaker_detector:
            self.speaker_detector.set_change_threshold(self.change_threshold)
            self.speaker_detector.set_max_speakers(self.max_speakers)

    def process_audio_segment(self, audio_data: np.ndarray, text: str) -> Tuple[int, str]:
        """Process an audio segment and return ``(speaker_id, "Label: text")``.

        Before initialization, or on any error, the segment is attributed to
        speaker 0 ("Speaker 1").
        """
        if not self.is_initialized:
            return 0, text
        try:
            embedding = self.audio_processor.extract_embedding(audio_data)
            speaker_id, _similarity = self.speaker_detector.add_embedding(embedding)
            speaker_label = SPEAKER_LABELS[speaker_id]
            formatted_text = f"{speaker_label}: {text}"
            return speaker_id, formatted_text
        except Exception as e:
            print(f"Error processing audio segment: {e}")
            return 0, f"Speaker 1: {text}"

    def get_transcript_history(self):
        """Return the transcript history joined with newlines."""
        return "\n".join(self.transcript_history)

    def add_to_transcript(self, formatted_text: str):
        """Append a formatted line, keeping only the newest entries."""
        self.transcript_history.append(formatted_text)
        # Drop the oldest lines so memory stays bounded.
        if len(self.transcript_history) > self.MAX_TRANSCRIPT_ENTRIES:
            self.transcript_history = self.transcript_history[-self.MAX_TRANSCRIPT_ENTRIES:]

    def clear_transcript(self):
        """Clear transcript history and reset the speaker detector."""
        self.transcript_history = []
        if self.speaker_detector:
            self.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers,
            )
# Single module-level diarization pipeline. Every Gradio event handler below
# uses it, so its state is shared by all users of the running app.
diarization_system = RealTimeSpeakerDiarization(
    change_threshold=DEFAULT_CHANGE_THRESHOLD,
    max_speakers=DEFAULT_MAX_SPEAKERS,
)
async def initialize_system():
    """Load the diarization models and report the outcome.

    Returns:
        A human-readable success or failure message for the status textbox.
    """
    ok = await diarization_system.initialize()
    return (
        "✅ Speaker diarization system initialized successfully!"
        if ok
        else "❌ Failed to initialize speaker diarization system. Please check your setup."
    )
def _audio_to_model_input(audio_data, sample_rate):
    """Convert UI audio into mono float32 samples at ``SAMPLE_RATE``.

    Accepts either a bare sample array or the ``(sample_rate, samples)`` tuple
    produced by ``gr.Audio(type="numpy")``. For a tuple, its embedded rate
    overrides *sample_rate*.
    """
    if isinstance(audio_data, tuple):
        sample_rate, audio_data = audio_data
    audio = np.asarray(audio_data)
    if np.issubdtype(audio.dtype, np.integer):
        # Integer PCM (e.g. int16) -> float32 roughly in [-1, 1); resampling needs floats.
        audio = audio.astype(np.float32) / float(np.iinfo(audio.dtype).max + 1)
    else:
        audio = audio.astype(np.float32)
    if audio.ndim > 1:
        # Assumes (samples, channels) layout: average channels to mono.
        audio = audio.mean(axis=1)
    sample_rate = int(sample_rate)
    if sample_rate != SAMPLE_RATE:
        audio = torchaudio.functional.resample(
            torch.from_numpy(np.ascontiguousarray(audio)), sample_rate, SAMPLE_RATE
        ).numpy()
    return audio


def process_audio_with_transcript(audio_data, sample_rate, transcription_text, change_threshold, max_speakers):
    """Diarize one audio clip and append its speaker-labelled transcription.

    Args:
        audio_data: A sample array, or the ``(sample_rate, samples)`` tuple
            delivered by ``gr.Audio(type="numpy")``.
        sample_rate: Sample rate in Hz, used only for bare arrays.
        transcription_text: Text spoken in the clip.
        change_threshold: Speaker-change similarity threshold.
        max_speakers: Maximum number of distinct speakers.

    Returns:
        ``(transcript, status)`` strings for the transcript and current-speaker boxes.
    """
    if not diarization_system.is_initialized:
        return "Please initialize the system first.", ""
    # Treat a missing (None) transcription the same as an empty one.
    if audio_data is None or not (transcription_text or "").strip():
        return diarization_system.get_transcript_history(), ""
    try:
        # Sliders may deliver floats; the speaker count must be an int.
        diarization_system.update_settings(change_threshold, int(max_speakers))
        audio = _audio_to_model_input(audio_data, sample_rate)
        speaker_id, formatted_text = diarization_system.process_audio_segment(audio, transcription_text)
        diarization_system.add_to_transcript(formatted_text)
        transcript = diarization_system.get_transcript_history()
        current_speaker_info = f"Current Speaker: {SPEAKER_LABELS[speaker_id]}"
        return transcript, current_speaker_info
    except Exception as e:
        error_msg = f"Error processing audio: {str(e)}"
        return diarization_system.get_transcript_history(), error_msg
def clear_conversation():
    """Reset the shared transcript and speaker memory.

    Returns:
        ``(transcript, status)``: an empty transcript and a confirmation message.
    """
    diarization_system.clear_transcript()
    empty_transcript, status = "", "Conversation cleared."
    return empty_transcript, status
def create_gradio_interface():
    """Create and return the Gradio interface.

    Lays out the initialization, settings, audio/transcription input and
    transcript output sections. Their events are wired to the module-level
    handlers.
    """
    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎙️ Real-time Speaker Diarization with ASR")
        gr.Markdown("Upload audio with transcription to perform real-time speaker diarization.")

        # Initialization section
        with gr.Row():
            init_btn = gr.Button("🚀 Initialize System", variant="primary")
            init_status = gr.Textbox(label="Initialization Status", interactive=False)

        # Settings section
        with gr.Row():
            with gr.Column():
                # The detector switches speaker when similarity < threshold,
                # so a HIGHER threshold makes change detection more sensitive.
                change_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    step=0.05,
                    label="Speaker Change Threshold",
                    info="Higher values = more sensitive to speaker changes",
                )
            with gr.Column():
                max_speakers = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    value=DEFAULT_MAX_SPEAKERS,
                    step=1,
                    label="Maximum Number of Speakers",
                    info="Maximum number of speakers to detect",
                )

        # Audio input and transcription
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    label="Audio Input",
                    type="numpy",
                    format="wav",
                )
                transcription_input = gr.Textbox(
                    label="Transcription Text",
                    placeholder="Enter the transcription of the audio...",
                    lines=3,
                )
                process_btn = gr.Button("🎯 Process Audio", variant="secondary")
            with gr.Column():
                current_speaker = gr.Textbox(
                    label="Current Speaker",
                    interactive=False,
                )
                clear_btn = gr.Button("🗑️ Clear Conversation", variant="stop")

        # Output section
        transcript_output = gr.Textbox(
            label="Live Transcript with Speaker Labels",
            lines=15,
            max_lines=20,
            interactive=False,
            placeholder="Processed transcript will appear here...",
        )

        # Hidden sample-rate input shared by both processing events.
        sample_rate_input = gr.Number(value=SAMPLE_RATE, visible=False)
        diarize_inputs = [
            audio_input,
            sample_rate_input,
            transcription_input,
            change_threshold,
            max_speakers,
        ]
        diarize_outputs = [transcript_output, current_speaker]

        # Event handlers
        init_btn.click(fn=initialize_system, outputs=[init_status])
        process_btn.click(
            fn=process_audio_with_transcript,
            inputs=diarize_inputs,
            outputs=diarize_outputs,
        )
        clear_btn.click(fn=clear_conversation, outputs=diarize_outputs)
        # Auto-process when new audio is provided
        audio_input.change(
            fn=process_audio_with_transcript,
            inputs=diarize_inputs,
            outputs=diarize_outputs,
        )

        # Instructions
        gr.Markdown("""
## Instructions:
1. **Initialize**: Click "Initialize System" to load the speaker diarization models
2. **Upload Audio**: Upload an audio file (WAV format recommended)
3. **Add Transcription**: Enter the transcription text for the audio
4. **Adjust Settings**:
- **Speaker Change Threshold**: Higher values detect speaker changes more easily
- **Max Speakers**: Set the maximum number of speakers you expect
5. **Process**: Click "Process Audio" or the system will auto-process
6. **View Results**: See the transcript with speaker labels (Speaker 1, Speaker 2, etc.)
## Tips:
- For similar-sounding speakers, increase the threshold (0.6-0.8)
- For different-sounding speakers, lower threshold works better (0.3-0.5)
- The system maintains speaker consistency across the conversation
- Use "Clear Conversation" to reset the speaker memory
""")
    return demo
if __name__ == "__main__":
    # Build the UI, then serve it on every network interface at port 7860,
    # also requesting a public share link and surfacing errors in the browser.
    demo = create_gradio_interface()
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
    )
    demo.launch(**launch_options)