import gradio as gr
import numpy as np
import torch
import torchaudio
import time
import os
import urllib.request
import queue
import threading
from scipy.spatial.distance import cosine
from RealtimeSTT import AudioToTextRecorder

# Configuration parameters (kept same as original)
SILENCE_THRESHS = [0, 0.4]
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
FINAL_BEAM_SIZE = 5
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
REALTIME_BEAM_SIZE = 5
TRANSCRIPTION_LANGUAGE = "en"
SILERO_SENSITIVITY = 0.4
WEBRTC_SENSITIVITY = 3
MIN_LENGTH_OF_RECORDING = 0.7
PRE_RECORDING_BUFFER_DURATION = 0.35

# Speaker change detection parameters
DEFAULT_CHANGE_THRESHOLD = 0.7
EMBEDDING_HISTORY_SIZE = 5
MIN_SEGMENT_DURATION = 1.0
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 10

# Audio parameters
FAST_SENTENCE_END = True
SAMPLE_RATE = 16000
BUFFER_SIZE = 512
CHANNELS = 1

# Speaker colors for HTML display
SPEAKER_COLORS = [
    "#FFFF00", "#FF0000", "#00FF00", "#00FFFF", "#FF00FF",
    "#0000FF", "#FF8000", "#00FF80", "#8000FF", "#FFFFFF"
]
SPEAKER_COLOR_NAMES = [
    "Yellow", "Red", "Green", "Cyan", "Magenta",
    "Blue", "Orange", "Spring Green", "Purple", "White"
]

class SpeechBrainEncoder:
    """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""

    def __init__(self, device="cpu"):
        self.device = device
        self.model = None
        self.embedding_dim = 192
        self.model_loaded = False
        self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
        os.makedirs(self.cache_dir, exist_ok=True)

    def _download_model(self):
        """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
        model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
        if not os.path.exists(model_path):
            print(f"Downloading ECAPA-TDNN model to {model_path}...")
            urllib.request.urlretrieve(model_url, model_path)
        return model_path

    def load_model(self):
        """Load the ECAPA-TDNN model"""
        try:
            from speechbrain.pretrained import EncoderClassifier
            model_path = self._download_model()
            self.model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir=self.cache_dir,
                run_opts={"device": self.device}
            )
            self.model_loaded = True
            return True
        except Exception as e:
            print(f"Error loading ECAPA-TDNN model: {e}")
            return False

    def embed_utterance(self, audio, sr=16000):
        """Extract speaker embedding from audio"""
        if not self.model_loaded:
            raise ValueError("Model not loaded. Call load_model() first.")
        try:
            if isinstance(audio, np.ndarray):
                waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
            else:
                waveform = audio.unsqueeze(0)
            if sr != 16000:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
            with torch.no_grad():
                embedding = self.model.encode_batch(waveform)
            return embedding.squeeze().cpu().numpy()
        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)

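# Minimal usage sketch for the encoder above (illustrative only; assumes SpeechBrain is
# installed and the pretrained model can be fetched). embed_utterance expects mono 16 kHz
# float audio roughly in [-1, 1] and returns a 192-dimensional NumPy vector:
#
#   encoder = SpeechBrainEncoder(device="cpu")
#   if encoder.load_model():
#       one_second = np.random.uniform(-0.5, 0.5, 16000).astype(np.float32)
#       print(encoder.embed_utterance(one_second).shape)  # (192,)
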
class AudioProcessor:
    """Processes audio data to extract speaker embeddings"""

    def __init__(self, encoder):
        self.encoder = encoder

    def extract_embedding(self, audio_int16):
        try:
            float_audio = audio_int16.astype(np.float32) / 32768.0
            if np.abs(float_audio).max() > 1.0:
                float_audio = float_audio / np.abs(float_audio).max()
            embedding = self.encoder.embed_utterance(float_audio)
            return embedding
        except Exception as e:
            print(f"Embedding extraction error: {e}")
            return np.zeros(self.encoder.embedding_dim)

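# Note on the scaling above: 16-bit PCM samples span [-32768, 32767], so dividing by
# 32768.0 maps them into [-1.0, 1.0), the range the SpeechBrain encoder expects; the
# extra renormalization guards against any samples that still exceed 1.0 in magnitude.
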
class SpeakerChangeDetector:
    """Speaker change detector with configurable number of speakers"""

    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.embedding_dim = embedding_dim
        self.change_threshold = change_threshold
        self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        self.current_speaker = 0
        self.previous_embeddings = []
        self.last_change_time = time.time()
        self.mean_embeddings = [None] * self.max_speakers
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.last_similarity = 0.0
        self.active_speakers = set([0])

    def set_max_speakers(self, max_speakers):
        """Update the maximum number of speakers"""
        new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        if new_max < self.max_speakers:
            for speaker_id in list(self.active_speakers):
                if speaker_id >= new_max:
                    self.active_speakers.discard(speaker_id)
            if self.current_speaker >= new_max:
                self.current_speaker = 0
        if new_max > self.max_speakers:
            self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
            self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
        else:
            self.mean_embeddings = self.mean_embeddings[:new_max]
            self.speaker_embeddings = self.speaker_embeddings[:new_max]
        self.max_speakers = new_max

    def set_change_threshold(self, threshold):
        """Update the threshold for detecting speaker changes"""
        self.change_threshold = max(0.1, min(threshold, 0.99))

    def add_embedding(self, embedding, timestamp=None):
        """Add a new embedding and check if there's a speaker change"""
        current_time = timestamp or time.time()
        if not self.previous_embeddings:
            self.previous_embeddings.append(embedding)
            self.speaker_embeddings[self.current_speaker].append(embedding)
            if self.mean_embeddings[self.current_speaker] is None:
                self.mean_embeddings[self.current_speaker] = embedding.copy()
            return self.current_speaker, 1.0

        current_mean = self.mean_embeddings[self.current_speaker]
        if current_mean is not None:
            similarity = 1.0 - cosine(embedding, current_mean)
        else:
            similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
        self.last_similarity = similarity

        time_since_last_change = current_time - self.last_change_time
        is_speaker_change = False
        if time_since_last_change >= MIN_SEGMENT_DURATION:
            if similarity < self.change_threshold:
                best_speaker = self.current_speaker
                best_similarity = similarity
                for speaker_id in range(self.max_speakers):
                    if speaker_id == self.current_speaker:
                        continue
                    speaker_mean = self.mean_embeddings[speaker_id]
                    if speaker_mean is not None:
                        speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
                        if speaker_similarity > best_similarity:
                            best_similarity = speaker_similarity
                            best_speaker = speaker_id
                if best_speaker != self.current_speaker:
                    is_speaker_change = True
                    self.current_speaker = best_speaker
                elif len(self.active_speakers) < self.max_speakers:
                    for new_id in range(self.max_speakers):
                        if new_id not in self.active_speakers:
                            is_speaker_change = True
                            self.current_speaker = new_id
                            self.active_speakers.add(new_id)
                            break
        if is_speaker_change:
            self.last_change_time = current_time

        self.previous_embeddings.append(embedding)
        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
            self.previous_embeddings.pop(0)
        self.speaker_embeddings[self.current_speaker].append(embedding)
        self.active_speakers.add(self.current_speaker)
        if len(self.speaker_embeddings[self.current_speaker]) > 30:
            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]
        if self.speaker_embeddings[self.current_speaker]:
            self.mean_embeddings[self.current_speaker] = np.mean(
                self.speaker_embeddings[self.current_speaker], axis=0
            )
        return self.current_speaker, similarity

    def get_color_for_speaker(self, speaker_id):
        """Return color for speaker ID"""
        if 0 <= speaker_id < len(SPEAKER_COLORS):
            return SPEAKER_COLORS[speaker_id]
        return "#FFFFFF"

class RealtimeASRDiarization:
    """Main class for real-time ASR with speaker diarization"""

    def __init__(self):
        self.encoder = None
        self.audio_processor = None
        self.speaker_detector = None
        self.recorder = None
        self.is_recording = False
        self.full_sentences = []
        self.sentence_speakers = []
        self.pending_sentences = []
        self.last_realtime_text = ""
        self.sentence_queue = queue.Queue()
        self.change_threshold = DEFAULT_CHANGE_THRESHOLD
        self.max_speakers = DEFAULT_MAX_SPEAKERS
        # Initialize model
        self.initialize_model()

    def initialize_model(self):
        """Initialize the speaker encoder model"""
        try:
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Using device: {device_str}")
            self.encoder = SpeechBrainEncoder(device=device_str)
            success = self.encoder.load_model()
            if success:
                print("ECAPA-TDNN model loaded successfully!")
                self.audio_processor = AudioProcessor(self.encoder)
                self.speaker_detector = SpeakerChangeDetector(
                    embedding_dim=self.encoder.embedding_dim,
                    change_threshold=self.change_threshold,
                    max_speakers=self.max_speakers
                )
                # Start sentence processing thread
                self.sentence_thread = threading.Thread(target=self.process_sentences, daemon=True)
                self.sentence_thread.start()
            else:
                print("Failed to load ECAPA-TDNN model")
        except Exception as e:
            print(f"Model initialization error: {e}")

    def process_sentences(self):
        """Process sentences in background thread"""
        while True:
            try:
                text, audio_bytes = self.sentence_queue.get(timeout=1)
                self.process_sentence(text, audio_bytes)
            except queue.Empty:
                continue

    def process_sentence(self, text, audio_bytes):
        """Process a sentence with speaker diarization"""
        if self.audio_processor is None or self.speaker_detector is None:
            return
        try:
            # Convert audio data to int16
            audio_int16 = np.int16(audio_bytes * 32767)
            # Extract speaker embedding
            speaker_embedding = self.audio_processor.extract_embedding(audio_int16)
            # Store sentence and embedding
            self.full_sentences.append((text, speaker_embedding))
            # Fill in any missing speaker assignments
            while len(self.sentence_speakers) < len(self.full_sentences) - 1:
                self.sentence_speakers.append(0)
            # Detect speaker changes
            speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
            self.sentence_speakers.append(speaker_id)
            # Remove from pending
            if text in self.pending_sentences:
                self.pending_sentences.remove(text)
        except Exception as e:
            print(f"Error processing sentence: {e}")

    def setup_recorder(self):
        """Setup the audio recorder"""
        try:
            recorder_config = {
                'spinner': False,
                'use_microphone': False,
                'model': FINAL_TRANSCRIPTION_MODEL,
                'language': TRANSCRIPTION_LANGUAGE,
                'silero_sensitivity': SILERO_SENSITIVITY,
                'webrtc_sensitivity': WEBRTC_SENSITIVITY,
                'post_speech_silence_duration': SILENCE_THRESHS[1],
                'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
                'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
                'min_gap_between_recordings': 0,
                'enable_realtime_transcription': True,
                'realtime_processing_pause': 0,
                'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
                'on_realtime_transcription_update': self.live_text_detected,
                'beam_size': FINAL_BEAM_SIZE,
                'beam_size_realtime': REALTIME_BEAM_SIZE,
                'buffer_size': BUFFER_SIZE,
                'sample_rate': SAMPLE_RATE,
            }
            self.recorder = AudioToTextRecorder(**recorder_config)
            return True
        except Exception as e:
            print(f"Error setting up recorder: {e}")
            return False

    def live_text_detected(self, text):
        """Handle live text detection"""
        text = text.strip()
        if not text:
            return
        sentence_delimiters = '.?!。'
        prob_sentence_end = (
            len(self.last_realtime_text) > 0
            and text[-1] in sentence_delimiters
            and self.last_realtime_text[-1] in sentence_delimiters
        )
        self.last_realtime_text = text
        if prob_sentence_end:
            if FAST_SENTENCE_END:
                self.recorder.stop()
            else:
                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
        else:
            self.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]

    def process_audio_chunk(self, audio_chunk):
        """Process incoming audio chunk from FastRTC"""
        if self.recorder is None:
            if not self.setup_recorder():
                return "Failed to setup recorder"
        try:
            # Convert audio to the format expected by the recorder
            if isinstance(audio_chunk, tuple):
                sample_rate, audio_data = audio_chunk
            else:
                audio_data = audio_chunk
                sample_rate = SAMPLE_RATE
            # Ensure audio is in the right format
            if audio_data.dtype != np.int16:
                if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
                    audio_data = (audio_data * 32767).astype(np.int16)
                else:
                    audio_data = audio_data.astype(np.int16)
            # Convert to bytes and feed to recorder
            audio_bytes = audio_data.tobytes()
            self.recorder.feed_audio(audio_bytes)

            # Process final text if available
            def process_final_text(text):
                text = text.strip()
                if text:
                    self.pending_sentences.append(text)
                    audio_bytes = self.recorder.last_transcription_bytes
                    self.sentence_queue.put((text, audio_bytes))

            # Get transcription
            self.recorder.text(process_final_text)
            return self.get_formatted_transcript()
        except Exception as e:
            print(f"Error processing audio: {e}")
            return f"Error: {e}"

    def get_formatted_transcript(self):
        """Get formatted transcript with speaker labels"""
        try:
            transcript_parts = []
            # Add completed sentences with speaker labels
            for i, (sentence_text, _) in enumerate(self.full_sentences):
                if i < len(self.sentence_speakers):
                    speaker_id = self.sentence_speakers[i]
                    speaker_label = f"Speaker {speaker_id + 1}"
                    transcript_parts.append(f"{speaker_label}: {sentence_text}")
            # Add pending sentences
            for pending in self.pending_sentences:
                transcript_parts.append(f"[Processing]: {pending}")
            # Add current live text
            if self.last_realtime_text:
                transcript_parts.append(f"[Live]: {self.last_realtime_text}")
            return "\n".join(transcript_parts)
        except Exception as e:
            print(f"Error formatting transcript: {e}")
            return "Error formatting transcript"

    def update_settings(self, change_threshold, max_speakers):
        """Update diarization settings"""
        self.change_threshold = change_threshold
        self.max_speakers = max_speakers
        if self.speaker_detector:
            self.speaker_detector.set_change_threshold(change_threshold)
            self.speaker_detector.set_max_speakers(max_speakers)

    def clear_transcript(self):
        """Clear all transcript data"""
        self.full_sentences = []
        self.sentence_speakers = []
        self.pending_sentences = []
        self.last_realtime_text = ""
        if self.speaker_detector:
            self.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers
            )

# Global instance
asr_diarization = RealtimeASRDiarization()


def process_audio_stream(audio_chunk, change_threshold, max_speakers):
    """Process audio stream and return transcript"""
    # Update settings if changed
    asr_diarization.update_settings(change_threshold, max_speakers)
    # Process audio
    transcript = asr_diarization.process_audio_chunk(audio_chunk)
    return transcript


def clear_transcript():
    """Clear the transcript"""
    asr_diarization.clear_transcript()
    return "Transcript cleared. Ready for new input..."

def create_interface():
    """Create Gradio interface with FastRTC"""
    with gr.Blocks(title="Real-time Speaker Diarization") as iface:
        gr.Markdown("# Real-time ASR with Speaker Diarization")
        gr.Markdown("Speak into your microphone to see real-time transcription with speaker labels!")

        with gr.Row():
            with gr.Column(scale=3):
                # Audio input with FastRTC
                audio_input = gr.Audio(
                    sources=["microphone"],
                    streaming=True,
                    label="Microphone Input"
                )
                # Transcript output
                transcript_output = gr.Textbox(
                    label="Live Transcript with Speaker Labels",
                    lines=15,
                    max_lines=20,
                    value="Ready to start transcription...",
                    interactive=False
                )
            with gr.Column(scale=1):
                gr.Markdown("### Settings")
                # Speaker change threshold
                change_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.95,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    step=0.05,
                    label="Speaker Change Threshold",
                    info="Lower values = more sensitive to speaker changes"
                )
                # Max speakers
                max_speakers = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    value=DEFAULT_MAX_SPEAKERS,
                    step=1,
                    label="Maximum Speakers",
                    info="Maximum number of speakers to detect"
                )
                # Clear button
                clear_btn = gr.Button("Clear Transcript", variant="secondary")
                gr.Markdown("### Speaker Colors")
                # Blank-line separated so each speaker renders on its own line in Markdown
                color_info = "\n\n".join([
                    f"Speaker {i+1}: {SPEAKER_COLOR_NAMES[i]}"
                    for i in range(min(DEFAULT_MAX_SPEAKERS, len(SPEAKER_COLOR_NAMES)))
                ])
                gr.Markdown(color_info)

        # Set up streaming
        audio_input.stream(
            fn=process_audio_stream,
            inputs=[audio_input, change_threshold, max_speakers],
            outputs=[transcript_output],
            show_progress=False
        )

        # Clear button functionality
        clear_btn.click(
            fn=clear_transcript,
            outputs=[transcript_output]
        )

        gr.Markdown("""
        ### Instructions:
        1. Allow microphone access when prompted
        2. Start speaking - transcription will appear in real-time
        3. Different speakers will be automatically detected and labeled
        4. Adjust the threshold if speaker changes aren't detected properly
        5. Use the clear button to reset the transcript

        ### Notes:
        - The system works best with clear audio and distinct speakers
        - It may take a moment to load the speaker recognition model on first use
        - Lower threshold values make the system more sensitive to speaker changes
        """)

    return iface

if __name__ == "__main__":
    # Create and launch the interface
    iface = create_interface()
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )