import gradio as gr
import numpy as np
import torch
import torchaudio
import time
import os
import queue
from collections import deque
from scipy.spatial.distance import cosine
from transformers import pipeline
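# Assumed runtime dependencies (not pinned here): gradio, torch, torchaudio, numpy,
# scipy, transformers, and speechbrain; model checkpoints are fetched from the
# Hugging Face Hub on first use.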
# Configuration parameters (keeping original models)
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
FINAL_BEAM_SIZE = 5
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
REALTIME_BEAM_SIZE = 5
TRANSCRIPTION_LANGUAGE = "en"
SILERO_SENSITIVITY = 0.4
WEBRTC_SENSITIVITY = 3
MIN_LENGTH_OF_RECORDING = 0.7
PRE_RECORDING_BUFFER_DURATION = 0.35
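# Note: FINAL_TRANSCRIPTION_MODEL, the beam-size settings, SILERO_SENSITIVITY,
# WEBRTC_SENSITIVITY and PRE_RECORDING_BUFFER_DURATION are carried over from the
# original configuration but are not referenced by the code below.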
# Speaker change detection parameters
DEFAULT_CHANGE_THRESHOLD = 0.7
EMBEDDING_HISTORY_SIZE = 5
MIN_SEGMENT_DURATION = 1.0
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 10
SAMPLE_RATE = 16000
CHUNK_DURATION = 2.0 # Process audio in 2-second chunks
# Speaker labels
SPEAKER_LABELS = [f"Speaker {i+1}" for i in range(ABSOLUTE_MAX_SPEAKERS)]
class SpeechBrainEncoder:
"""ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""
def __init__(self, device="cpu"):
self.device = device
self.model = None
self.embedding_dim = 192
self.model_loaded = False
self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
os.makedirs(self.cache_dir, exist_ok=True)
def load_model(self):
"""Load the ECAPA-TDNN model"""
try:
            # SpeechBrain >= 1.0 moved the pretrained interfaces to
            # speechbrain.inference; keep the legacy import as a fallback
            try:
                from speechbrain.inference.speaker import EncoderClassifier
            except ImportError:
                from speechbrain.pretrained import EncoderClassifier
self.model = EncoderClassifier.from_hparams(
source="speechbrain/spkrec-ecapa-voxceleb",
savedir=self.cache_dir,
run_opts={"device": self.device}
)
self.model_loaded = True
return True
except Exception as e:
print(f"Error loading ECAPA-TDNN model: {e}")
return False
def embed_utterance(self, audio, sr=16000):
"""Extract speaker embedding from audio"""
if not self.model_loaded:
raise ValueError("Model not loaded. Call load_model() first.")
try:
if isinstance(audio, np.ndarray):
waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
else:
waveform = audio.unsqueeze(0)
if sr != 16000:
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
with torch.no_grad():
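                # encode_batch returns a tensor of shape (batch, 1, embedding_dim);
                # squeeze() below flattens it to a 192-dim vector for the cosine
                # comparisons downstream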
embedding = self.model.encode_batch(waveform)
return embedding.squeeze().cpu().numpy()
except Exception as e:
print(f"Error extracting embedding: {e}")
return np.zeros(self.embedding_dim)
class SpeakerChangeDetector:
"""Speaker change detector that supports configurable number of speakers"""
def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
self.embedding_dim = embedding_dim
self.change_threshold = change_threshold
self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
self.current_speaker = 0
self.previous_embeddings = []
self.last_change_time = time.time()
self.mean_embeddings = [None] * self.max_speakers
self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
self.last_similarity = 0.0
self.active_speakers = set([0])
def set_max_speakers(self, max_speakers):
"""Update the maximum number of speakers"""
new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
if new_max < self.max_speakers:
for speaker_id in list(self.active_speakers):
if speaker_id >= new_max:
self.active_speakers.discard(speaker_id)
if self.current_speaker >= new_max:
self.current_speaker = 0
if new_max > self.max_speakers:
self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
else:
self.mean_embeddings = self.mean_embeddings[:new_max]
self.speaker_embeddings = self.speaker_embeddings[:new_max]
self.max_speakers = new_max
def set_change_threshold(self, threshold):
"""Update the threshold for detecting speaker changes"""
self.change_threshold = max(0.1, min(threshold, 0.99))
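    # How add_embedding assigns a speaker, in brief: similarity is 1 - cosine
    # distance, so 1.0 means identical direction and 0.0 orthogonal. Each new
    # embedding is compared to the running mean of the current speaker; if the
    # similarity falls below change_threshold (and MIN_SEGMENT_DURATION has
    # elapsed since the last change), the detector switches to the best-matching
    # known speaker, or allocates a new slot if none matches better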
def add_embedding(self, embedding, timestamp=None):
"""Add a new embedding and check if there's a speaker change"""
current_time = timestamp or time.time()
if not self.previous_embeddings:
self.previous_embeddings.append(embedding)
self.speaker_embeddings[self.current_speaker].append(embedding)
if self.mean_embeddings[self.current_speaker] is None:
self.mean_embeddings[self.current_speaker] = embedding.copy()
return self.current_speaker, 1.0
current_mean = self.mean_embeddings[self.current_speaker]
if current_mean is not None:
similarity = 1.0 - cosine(embedding, current_mean)
else:
similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
self.last_similarity = similarity
time_since_last_change = current_time - self.last_change_time
is_speaker_change = False
if time_since_last_change >= MIN_SEGMENT_DURATION:
if similarity < self.change_threshold:
best_speaker = self.current_speaker
best_similarity = similarity
for speaker_id in range(self.max_speakers):
if speaker_id == self.current_speaker:
continue
speaker_mean = self.mean_embeddings[speaker_id]
if speaker_mean is not None:
speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
if speaker_similarity > best_similarity:
best_similarity = speaker_similarity
best_speaker = speaker_id
if best_speaker != self.current_speaker:
is_speaker_change = True
self.current_speaker = best_speaker
elif len(self.active_speakers) < self.max_speakers:
for new_id in range(self.max_speakers):
if new_id not in self.active_speakers:
is_speaker_change = True
self.current_speaker = new_id
self.active_speakers.add(new_id)
break
if is_speaker_change:
self.last_change_time = current_time
self.previous_embeddings.append(embedding)
if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
self.previous_embeddings.pop(0)
self.speaker_embeddings[self.current_speaker].append(embedding)
self.active_speakers.add(self.current_speaker)
if len(self.speaker_embeddings[self.current_speaker]) > 30:
self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]
if self.speaker_embeddings[self.current_speaker]:
self.mean_embeddings[self.current_speaker] = np.mean(
self.speaker_embeddings[self.current_speaker], axis=0
)
return self.current_speaker, similarity
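# Minimal usage sketch for SpeakerChangeDetector (illustrative values only):
#   detector = SpeakerChangeDetector(embedding_dim=192)
#   spk, sim = detector.add_embedding(np.random.randn(192))  # first call -> (0, 1.0)
#   spk, sim = detector.add_embedding(np.random.randn(192))  # compared to speaker 0's mean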
class AudioProcessor:
"""Processes audio data to extract speaker embeddings"""
def __init__(self, encoder):
self.encoder = encoder
def extract_embedding(self, audio_data):
try:
# Ensure audio is float32 and normalized
if audio_data.dtype != np.float32:
audio_data = audio_data.astype(np.float32)
# Normalize if needed
if np.abs(audio_data).max() > 1.0:
audio_data = audio_data / np.abs(audio_data).max()
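            # e.g. int16 microphone samples span [-32768, 32767]; dividing by the
            # peak maps the signal into [-1.0, 1.0] as the encoder expects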
# Extract embedding using the loaded encoder
embedding = self.encoder.embed_utterance(audio_data)
return embedding
except Exception as e:
print(f"Embedding extraction error: {e}")
return np.zeros(self.encoder.embedding_dim)
class RealTimeSpeakerDiarization:
"""Main class for real-time speaker diarization with FastRTC"""
def __init__(self, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
self.encoder = None
self.audio_processor = None
self.speaker_detector = None
self.transcription_pipeline = None
self.change_threshold = change_threshold
self.max_speakers = max_speakers
self.transcript_history = []
self.is_initialized = False
# Audio processing
self.audio_buffer = deque(maxlen=int(SAMPLE_RATE * 10)) # 10 second buffer
        self.processing_queue = queue.Queue()  # reserved for threaded processing; not consumed below
self.last_processed_time = 0
self.current_transcript = ""
def initialize(self):
"""Initialize the speaker diarization system"""
if self.is_initialized:
return True
try:
device_str = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Initializing models on {device_str}...")
# Initialize speaker encoder
self.encoder = SpeechBrainEncoder(device=device_str)
success = self.encoder.load_model()
if not success:
return False
            # Initialize transcription pipeline. The distil-whisper checkpoints
            # are published under the "distil-whisper" organization on the Hub
            # ("openai/whisper-distil-small.en" does not exist)
            self.transcription_pipeline = pipeline(
                "automatic-speech-recognition",
                model=f"distil-whisper/{REALTIME_TRANSCRIPTION_MODEL}",
                device=0 if torch.cuda.is_available() else -1,
                return_timestamps=True
            )
self.audio_processor = AudioProcessor(self.encoder)
self.speaker_detector = SpeakerChangeDetector(
embedding_dim=self.encoder.embedding_dim,
change_threshold=self.change_threshold,
max_speakers=self.max_speakers
)
self.is_initialized = True
print("Speaker diarization system initialized successfully!")
return True
except Exception as e:
print(f"Initialization error: {e}")
return False
def update_settings(self, change_threshold, max_speakers):
"""Update diarization settings"""
self.change_threshold = change_threshold
self.max_speakers = max_speakers
if self.speaker_detector:
self.speaker_detector.set_change_threshold(change_threshold)
self.speaker_detector.set_max_speakers(max_speakers)
def process_audio_stream(self, audio_chunk, sample_rate):
"""Process real-time audio stream from FastRTC"""
if not self.is_initialized:
return self.get_current_transcript(), "System not initialized"
try:
# Convert to numpy array if needed
if hasattr(audio_chunk, 'numpy'):
audio_data = audio_chunk.numpy()
else:
audio_data = np.array(audio_chunk)
            # Handle different audio formats
            if len(audio_data.shape) > 1:
                audio_data = audio_data.mean(axis=1)  # Convert to mono
            # Convert to float32 in [-1, 1]; browser microphones typically deliver
            # int16 PCM, which Whisper and the encoder do not expect raw
            audio_data = audio_data.astype(np.float32)
            if np.abs(audio_data).max() > 1.0:
                audio_data = audio_data / 32768.0
            # Resample if needed (resample expects a floating-point tensor)
            if sample_rate != SAMPLE_RATE:
                audio_data = torchaudio.functional.resample(
                    torch.from_numpy(audio_data), sample_rate, SAMPLE_RATE
                ).numpy()
# Add to buffer
self.audio_buffer.extend(audio_data)
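            # The deque was created with maxlen = 10 s of samples, so extend()
            # silently drops the oldest audio once the buffer is full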
# Process if we have enough audio
current_time = time.time()
if (current_time - self.last_processed_time) >= CHUNK_DURATION:
self.process_buffered_audio()
self.last_processed_time = current_time
return self.get_current_transcript(), f"Processing... Buffer: {len(self.audio_buffer)} samples"
except Exception as e:
error_msg = f"Error processing audio stream: {str(e)}"
print(error_msg)
return self.get_current_transcript(), error_msg
def process_buffered_audio(self):
"""Process buffered audio for transcription and speaker diarization"""
if len(self.audio_buffer) < int(SAMPLE_RATE * MIN_LENGTH_OF_RECORDING):
return
try:
# Get audio data from buffer
audio_data = np.array(list(self.audio_buffer))
            # Transcribe audio
            if len(audio_data) > 0:
                # Pass the sampling rate explicitly; the buffer is kept at 16 kHz
                call_kwargs = {}
                if not REALTIME_TRANSCRIPTION_MODEL.endswith(".en"):
                    # English-only ".en" Whisper checkpoints reject a language
                    # hint, so only pass one for multilingual models
                    call_kwargs["generate_kwargs"] = {"language": TRANSCRIPTION_LANGUAGE}
                result = self.transcription_pipeline(
                    {"raw": audio_data, "sampling_rate": SAMPLE_RATE},
                    **call_kwargs
                )
                transcription = result["text"].strip()
                if transcription:
# Extract speaker embedding
embedding = self.audio_processor.extract_embedding(audio_data)
# Detect speaker
speaker_id, similarity = self.speaker_detector.add_embedding(embedding)
# Format text with speaker label
speaker_label = SPEAKER_LABELS[speaker_id]
formatted_text = f"{speaker_label}: {transcription}"
# Add to transcript
self.add_to_transcript(formatted_text)
print(f"Transcribed: {formatted_text} (Similarity: {similarity:.3f})")
            # Trim the buffer to bound memory and latency: once more than 5 s
            # has accumulated, keep only the most recent 3 s
            if len(self.audio_buffer) > SAMPLE_RATE * 5:
                self.audio_buffer = deque(list(self.audio_buffer)[-SAMPLE_RATE * 3:], maxlen=int(SAMPLE_RATE * 10))
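            # Note: because up to 3 s of audio is retained and re-transcribed on
            # the next pass, consecutive transcript entries can overlap or repeat;
            # a fuller fix would track which samples have already been emitted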
except Exception as e:
print(f"Error in process_buffered_audio: {e}")
def get_current_transcript(self):
"""Get the current transcript"""
return "\n".join(self.transcript_history) if self.transcript_history else "Listening..."
def add_to_transcript(self, formatted_text: str):
"""Add formatted text to transcript history"""
self.transcript_history.append(formatted_text)
# Keep only last 50 entries to prevent memory issues
if len(self.transcript_history) > 50:
self.transcript_history = self.transcript_history[-50:]
def clear_transcript(self):
"""Clear transcript history and reset speaker detector"""
self.transcript_history = []
self.audio_buffer.clear()
if self.speaker_detector:
self.speaker_detector = SpeakerChangeDetector(
embedding_dim=self.encoder.embedding_dim,
change_threshold=self.change_threshold,
max_speakers=self.max_speakers
)
def get_status(self):
"""Get current system status"""
if not self.is_initialized:
return "System not initialized"
if self.speaker_detector:
active_speakers = len(self.speaker_detector.active_speakers)
current_speaker = self.speaker_detector.current_speaker + 1
similarity = self.speaker_detector.last_similarity
return f"Active: {active_speakers} speakers | Current: Speaker {current_speaker} | Similarity: {similarity:.3f}"
return "Ready"
# Global instance
diarization_system = RealTimeSpeakerDiarization()
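# Note: this module-level instance is shared by every connected Gradio session,
# so concurrent users would interleave into a single transcript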
def initialize_system():
"""Initialize the diarization system"""
success = diarization_system.initialize()
if success:
return "βœ… Speaker diarization system initialized successfully!"
else:
return "❌ Failed to initialize speaker diarization system. Please check your setup."
def process_realtime_audio(audio_stream, change_threshold, max_speakers):
    """Process a streamed microphone chunk from the Gradio audio input"""
    if not diarization_system.is_initialized:
        return "Please initialize the system first.", "System not ready"
    # Update settings
    diarization_system.update_settings(change_threshold, max_speakers)
    if audio_stream is None:
        return diarization_system.get_current_transcript(), diarization_system.get_status()
    # A streaming gr.Audio input yields (sample_rate, numpy_array) tuples
    sample_rate, audio_data = audio_stream
    transcript, status = diarization_system.process_audio_stream(audio_data, sample_rate)
    return transcript, diarization_system.get_status()
def clear_conversation():
"""Clear the conversation transcript"""
diarization_system.clear_transcript()
return "Conversation cleared. Listening...", "Ready"
def create_gradio_interface():
    """Create and return the Gradio interface with a streaming microphone input"""
    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎙️ Real-time Speaker Diarization")
        gr.Markdown("Speak into your microphone for real-time speaker diarization and transcription.")
# Initialization section
with gr.Row():
            init_btn = gr.Button("🚀 Initialize System", variant="primary", scale=1)
init_status = gr.Textbox(label="System Status", interactive=False, scale=2)
# Settings section
with gr.Row():
with gr.Column():
                change_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    step=0.05,
                    label="Speaker Change Threshold",
                    info="A change is flagged when similarity drops below this value, so higher = more sensitive"
                )
with gr.Column():
max_speakers = gr.Slider(
minimum=2,
maximum=ABSOLUTE_MAX_SPEAKERS,
value=DEFAULT_MAX_SPEAKERS,
step=1,
label="Maximum Number of Speakers",
info="Maximum number of speakers to detect"
)
        # Streaming audio input
        with gr.Row():
            with gr.Column():
                # Gradio has no built-in "FastRTC" component (FastRTC is a separate
                # library); the standard streaming microphone Audio input is used instead
                audio_input = gr.Audio(
                    sources=["microphone"],
                    streaming=True,
                    type="numpy",
                    label="🎤 Real-time Audio Input"
                )
                clear_btn = gr.Button("🗑️ Clear Conversation", variant="stop")
with gr.Column():
current_status = gr.Textbox(
label="Current Status",
interactive=False,
value="Click Initialize to start"
)
# Output section
transcript_output = gr.Textbox(
label="πŸ”΄ Live Transcript with Speaker Labels",
lines=15,
max_lines=25,
interactive=False,
value="Click Initialize, then start speaking...",
autoscroll=True
)
# Event handlers
init_btn.click(
fn=initialize_system,
outputs=[init_status]
)
        # Stream microphone chunks to the processing function
        # (stream_every/time_limit assume a recent Gradio with streaming events)
        audio_input.stream(
            fn=process_realtime_audio,
            inputs=[audio_input, change_threshold, max_speakers],
            outputs=[transcript_output, current_status],
            stream_every=0.5,  # deliver roughly half-second chunks
            time_limit=30      # cap each streaming session at 30 seconds
        )
clear_btn.click(
fn=clear_conversation,
outputs=[transcript_output, current_status]
)
# Instructions
        with gr.Accordion("📋 Instructions", open=False):
            gr.Markdown("""
            ## How to Use:
            1. **Initialize**: Click "🚀 Initialize System" to load the AI models (this may take a moment)
            2. **Allow Microphone**: Your browser will ask for microphone permission - please allow it
            3. **Adjust Settings**:
               - **Speaker Change Threshold**:
                 - Lower (0.3-0.5) for speakers with distinct voices
                 - Higher (0.6-0.8) for speakers with similar voices
               - **Max Speakers**: Set the expected number of speakers (2-10)
            4. **Start Speaking**: The system automatically transcribes and identifies speakers
            5. **View Results**: See the real-time transcript with speaker labels (Speaker 1, Speaker 2, etc.)
            6. **Clear**: Use "Clear Conversation" to reset and start fresh
            ## Features:
            - ✅ Real-time audio processing from the browser microphone
            - ✅ Automatic speech recognition with distil-whisper
            - ✅ Speaker diarization with ECAPA-TDNN
            - ✅ Live transcript with speaker labels
            - ✅ Configurable sensitivity settings
            - ✅ Support for up to 10 speakers
            ## Tips:
            - Speak clearly and allow brief pauses between speakers
            - The system learns speaker characteristics over time
            - Results are better with distinct speaker voices
            - Ensure good microphone quality for best performance
            """)
return demo
if __name__ == "__main__":
# Create and launch the Gradio interface
demo = create_gradio_interface()
demo.launch(
share=True,
server_name="0.0.0.0",
server_port=7860,
show_error=True
)