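"""Gradio demo for simple speaker diarization and transcription.

Speaker embeddings are derived from MFCC statistics, speaker changes are
detected with cosine similarity, and speech is transcribed with OpenAI
Whisper via the Hugging Face transformers pipeline.
"""
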
import os
import tempfile
import warnings

import gradio as gr
import numpy as np
import torch
import torchaudio
from scipy.spatial.distance import cosine

warnings.filterwarnings("ignore", category=UserWarning)

try:
    from transformers import pipeline
except ImportError:
    pipeline = None
    print("transformers not found. Install with: pip install transformers")


class Config:
    SAMPLE_RATE = 16000

    # Speaker change detection
    CHANGE_THRESHOLD = 0.65
    MAX_SPEAKERS = 4
    MIN_SEGMENT_DURATION = 1.0
    EMBEDDING_HISTORY_SIZE = 3
    SPEAKER_MEMORY_SIZE = 20


# Colors used to highlight each speaker in the diarization output
SPEAKER_COLORS = [
    "#FFD700",
    "#FF6B6B",
    "#4ECDC4",
    "#45B7D1",
    "#96CEB4",
    "#FFEAA7",
    "#DDA0DD",
    "#98D8C8",
]


class SpeakerEncoder:
    """Simplified speaker encoder using torchaudio transforms"""

    def __init__(self, device="cpu"):
        self.device = device
        self.embedding_dim = 128
        self.model_loaded = False
        self._setup_model()

    def _setup_model(self):
        """Setup a simple MFCC-based feature extractor"""
        try:
            self.mfcc_transform = torchaudio.transforms.MFCC(
                sample_rate=Config.SAMPLE_RATE,
                n_mfcc=13,
                melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 23},
            ).to(self.device)
            self.model_loaded = True
            print("Simple MFCC-based encoder initialized")
        except Exception as e:
            print(f"Error setting up encoder: {e}")
            self.model_loaded = False

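    # The embedding is built from per-coefficient statistics (mean, std, max,
    # min) of 13 MFCCs, i.e. 52 values, zero-padded to the fixed 128-dim size.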
    def extract_embedding(self, audio):
        """Extract speaker embedding from audio"""
        if not self.model_loaded:
            return np.zeros(self.embedding_dim)

        try:
            if isinstance(audio, np.ndarray):
                audio = torch.from_numpy(audio).float()

            # Peak-normalize to reduce loudness differences between chunks
            if audio.abs().max() > 0:
                audio = audio / audio.abs().max()

            # The MFCC transform expects a (channels, samples) tensor
            if audio.dim() == 1:
                audio = audio.unsqueeze(0)

            with torch.no_grad():
                mfcc = self.mfcc_transform(audio)

                embedding = torch.cat([
                    mfcc.mean(dim=2).flatten(),
                    mfcc.std(dim=2).flatten(),
                    mfcc.max(dim=2)[0].flatten(),
                    mfcc.min(dim=2)[0].flatten(),
                ])

            # Truncate or zero-pad to the fixed embedding size
            if embedding.size(0) > self.embedding_dim:
                embedding = embedding[:self.embedding_dim]
            elif embedding.size(0) < self.embedding_dim:
                padding = torch.zeros(self.embedding_dim - embedding.size(0))
                embedding = torch.cat([embedding, padding])

            return embedding.cpu().numpy()

        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)


class SpeakerDetector:
    """Speaker change detection using embeddings"""

    def __init__(self, threshold=Config.CHANGE_THRESHOLD, max_speakers=Config.MAX_SPEAKERS):
        self.threshold = threshold
        self.max_speakers = max_speakers
        self.current_speaker = 0
        self.speaker_embeddings = [[] for _ in range(max_speakers)]
        self.speaker_centroids = [None] * max_speakers
        self.active_speakers = {0}

    def reset(self):
        """Reset speaker detection state"""
        self.current_speaker = 0
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.speaker_centroids = [None] * self.max_speakers
        self.active_speakers = {0}

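    # Detection strategy: compare each new embedding to the current speaker's
    # centroid; if similarity drops below the threshold, switch to the
    # best-matching known speaker or open a new speaker slot (up to
    # max_speakers), then update that speaker's running model.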
    def detect_speaker(self, embedding):
        """Detect current speaker from embedding"""
        # First embedding seen: assign it to speaker 0
        if not self.speaker_embeddings[0]:
            self.speaker_embeddings[0].append(embedding)
            self.speaker_centroids[0] = embedding.copy()
            return 0, 1.0

        # Similarity to the current speaker's centroid
        current_centroid = self.speaker_centroids[self.current_speaker]
        if current_centroid is not None:
            similarity = 1.0 - cosine(embedding, current_centroid)
        else:
            similarity = 0.0

        if similarity < self.threshold:
            # Look for a better match among the other active speakers
            best_speaker = self.current_speaker
            best_similarity = similarity

            for speaker_id in self.active_speakers:
                if speaker_id == self.current_speaker:
                    continue

                centroid = self.speaker_centroids[speaker_id]
                if centroid is not None:
                    sim = 1.0 - cosine(embedding, centroid)
                    if sim > best_similarity and sim > self.threshold:
                        best_similarity = sim
                        best_speaker = speaker_id

            # No existing speaker matched well enough: start a new one
            if (best_speaker == self.current_speaker and
                    len(self.active_speakers) < self.max_speakers):
                for new_id in range(self.max_speakers):
                    if new_id not in self.active_speakers:
                        best_speaker = new_id
                        best_similarity = 0.0
                        self.active_speakers.add(new_id)
                        break

            if best_speaker != self.current_speaker:
                self.current_speaker = best_speaker
                similarity = best_similarity

        self._update_speaker_model(self.current_speaker, embedding)
        return self.current_speaker, similarity

    def _update_speaker_model(self, speaker_id, embedding):
        """Update speaker model with new embedding"""
        self.speaker_embeddings[speaker_id].append(embedding)

        # Keep only the most recent embeddings for this speaker
        if len(self.speaker_embeddings[speaker_id]) > Config.SPEAKER_MEMORY_SIZE:
            self.speaker_embeddings[speaker_id] = \
                self.speaker_embeddings[speaker_id][-Config.SPEAKER_MEMORY_SIZE:]

        # Recompute the centroid from the retained embeddings
        if self.speaker_embeddings[speaker_id]:
            self.speaker_centroids[speaker_id] = np.mean(
                self.speaker_embeddings[speaker_id], axis=0
            )


class AudioProcessor:
    """Handles audio processing and transcription"""

    def __init__(self):
        self.encoder = SpeakerEncoder()
        self.detector = SpeakerDetector()

        if pipeline is None:
            print("Transcription disabled: transformers is not installed")
            self.transcriber = None
        else:
            try:
                self.transcriber = pipeline(
                    "automatic-speech-recognition",
                    model="openai/whisper-base",
                    chunk_length_s=30,
                    device=0 if torch.cuda.is_available() else -1,
                )
                print("Whisper model loaded successfully")
            except Exception as e:
                print(f"Error loading Whisper model: {e}")
                self.transcriber = None

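    # Audio is analyzed in 3-second windows with 50% overlap; each window gets
    # an embedding, a speaker label, and its own Whisper transcription.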
    def process_audio_file(self, audio_file):
        """Process uploaded audio file"""
        if audio_file is None:
            return "Please upload an audio file.", ""

        try:
            self.detector.reset()

            waveform, sample_rate = torchaudio.load(audio_file)

            # Convert to mono
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            # Resample to the rate expected by the MFCC transform
            if sample_rate != Config.SAMPLE_RATE:
                resampler = torchaudio.transforms.Resample(sample_rate, Config.SAMPLE_RATE)
                waveform = resampler(waveform)

            audio_data = waveform.squeeze().numpy()

            # Full-file transcription
            if self.transcriber:
                transcription_result = self.transcriber(audio_file)
                full_transcription = transcription_result['text']
            else:
                full_transcription = "Transcription service unavailable"

            # Analyze overlapping chunks for speaker changes
            chunk_duration = 3.0
            chunk_samples = int(chunk_duration * Config.SAMPLE_RATE)
            results = []

            for i in range(0, len(audio_data), chunk_samples // 2):
                chunk = audio_data[i:i + chunk_samples]

                # Skip chunks shorter than one second
                if len(chunk) < Config.SAMPLE_RATE:
                    continue

                embedding = self.encoder.extract_embedding(chunk)
                speaker_id, similarity = self.detector.detect_speaker(embedding)

                start_time = i / Config.SAMPLE_RATE
                end_time = (i + len(chunk)) / Config.SAMPLE_RATE

                # Transcribe the chunk via a temporary WAV file
                if self.transcriber and len(chunk) > Config.SAMPLE_RATE:
                    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
                        torchaudio.save(tmp_file.name, torch.tensor(chunk).unsqueeze(0), Config.SAMPLE_RATE)
                        chunk_result = self.transcriber(tmp_file.name)
                        chunk_text = chunk_result['text'].strip()
                    os.unlink(tmp_file.name)
                else:
                    chunk_text = ""

                if chunk_text:
                    results.append({
                        'speaker_id': speaker_id,
                        'start_time': start_time,
                        'end_time': end_time,
                        'text': chunk_text,
                        'similarity': similarity,
                    })

            formatted_output = self._format_results(results)
            return formatted_output, full_transcription

        except Exception as e:
            return f"Error processing audio: {str(e)}", ""

    def _format_results(self, results):
        """Format results with speaker colors"""
        if not results:
            return "No speech detected in the audio file."

        formatted_lines = []
        formatted_lines.append("<h3>🎤 Speaker Diarization Results</h3>")

        for result in results:
            speaker_id = result['speaker_id']
            start_time = result['start_time']
            end_time = result['end_time']
            text = result['text']
            similarity = result['similarity']

            color = SPEAKER_COLORS[speaker_id % len(SPEAKER_COLORS)]

            # Format timestamps as MM:SS
            start_min, start_sec = divmod(int(start_time), 60)
            end_min, end_sec = divmod(int(end_time), 60)
            timestamp = f"[{start_min:02d}:{start_sec:02d} - {end_min:02d}:{end_sec:02d}]"

            formatted_lines.append(
                f'<div style="margin-bottom: 10px; padding: 8px; border-left: 4px solid {color}; background-color: {color}20;">'
                f'<strong style="color: {color};">Speaker {speaker_id + 1}</strong> '
                f'<span style="color: #666; font-size: 0.9em;">{timestamp}</span><br>'
                f'<span style="color: #333;">{text}</span>'
                f'</div>'
            )

        return "".join(formatted_lines)


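# A single shared processor instance; note that the Whisper model is loaded
# once here, at import time.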
processor = AudioProcessor()


def process_audio(audio_file, sensitivity):
    """Process audio file with speaker detection"""
    if audio_file is None:
        return "Please upload an audio file.", ""

    # Update the speaker-change threshold from the UI slider
    processor.detector.threshold = sensitivity

    diarized_output, full_transcription = processor.process_audio_file(audio_file)
    return diarized_output, full_transcription


def create_interface():
    """Create Gradio interface"""

    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="Speaker Diarization & Transcription",
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        .speaker-output {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
        }
        """
    ) as demo:

        gr.Markdown(
            """
            # 🎙️ Speaker Diarization & Transcription

            Upload an audio file to automatically detect different speakers and transcribe their speech.
            The system will identify speaker changes and display each speaker's text in different colors.
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                audio_input = gr.Audio(
                    label="Upload Audio File",
                    type="filepath",
                    sources=["upload", "microphone"]
                )

                sensitivity_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.65,
                    step=0.05,
                    label="Speaker Change Sensitivity",
                    info="Lower values = more sensitive to speaker changes"
                )

                process_btn = gr.Button("🎯 Process Audio", variant="primary", size="lg")

                gr.Markdown(
                    """
                    ### Instructions:
                    1. Upload an audio file (WAV, MP3, etc.)
                    2. Adjust sensitivity if needed
                    3. Click "Process Audio"
                    4. View results with speaker colors

                    ### Tips:
                    - Works best with clear speech
                    - Supports multiple file formats
                    - Different speakers shown in different colors
                    - Processing may take a moment for longer files
                    """
                )

            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.TabItem("🎨 Speaker Diarization"):
                        diarized_output = gr.HTML(
                            label="Speaker Diarization Results",
                            elem_classes=["speaker-output"]
                        )

                    with gr.TabItem("📝 Full Transcription"):
                        full_transcription = gr.Textbox(
                            label="Complete Transcription",
                            lines=15,
                            max_lines=20,
                            show_copy_button=True
                        )

        process_btn.click(
            fn=process_audio,
            inputs=[audio_input, sensitivity_slider],
            outputs=[diarized_output, full_transcription],
            show_progress=True
        )

        # Also process automatically when a new file is uploaded or recorded
        audio_input.change(
            fn=process_audio,
            inputs=[audio_input, sensitivity_slider],
            outputs=[diarized_output, full_transcription],
            show_progress=True
        )

        gr.Markdown(
            """
            ---
            ### About
            This application uses:
            - **MFCC features** for speaker embedding extraction
            - **Cosine similarity** for speaker change detection
            - **OpenAI Whisper** for speech-to-text transcription
            - **Gradio** for the web interface

            **Note**: This is a simplified speaker diarization system. For production use,
            consider more advanced speaker embedding models like speechbrain or pyannote.audio.
            """
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
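
# With the settings above, the app listens on all interfaces and is reachable
# at http://localhost:7860 on the host machine.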