|
""" |
|
Speaker Diarization Module for Multilingual Audio Intelligence System |
|
|
|
This module implements state-of-the-art speaker diarization using pyannote.audio. |
|
It segments audio to identify "who spoke when" with high accuracy and language-agnostic |
|
speaker separation capabilities as required by PS-6. |
|
|
|
Key Features: |
|
- SOTA speaker diarization using pyannote.audio |
|
- Language-agnostic voice characteristic analysis |
|
- Integrated Voice Activity Detection (VAD) |
|
- Automatic speaker count detection |
|
- CPU and GPU optimization support |
|
- Robust error handling and logging |
|
|
|
Model: pyannote/speaker-diarization-3.1 |
|
Dependencies: pyannote.audio, torch, transformers |
|
""" |
|
|
|
import os |
|
import logging |
|
import warnings |
|
import numpy as np |
|
import torch |
|
from typing import List, Tuple, Dict, Optional, Union |
|
import tempfile |
|
from dataclasses import dataclass |
|
from dotenv import load_dotenv |
|
|
|
|
|
load_dotenv() |
|
|
|
try: |
|
from pyannote.audio import Pipeline |
|
from pyannote.core import Annotation, Segment |
|
PYANNOTE_AVAILABLE = True |
|
except ImportError: |
|
PYANNOTE_AVAILABLE = False |
|
logging.warning("pyannote.audio not available. Install with: pip install pyannote.audio") |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
warnings.filterwarnings("ignore", category=UserWarning) |
|
warnings.filterwarnings("ignore", category=FutureWarning) |
|
|
|
|
|
@dataclass |
|
class SpeakerSegment: |
|
""" |
|
Data class representing a single speaker segment. |
|
|
|
Attributes: |
|
start_time (float): Segment start time in seconds |
|
end_time (float): Segment end time in seconds |
|
speaker_id (str): Unique speaker identifier (e.g., "SPEAKER_00") |
|
confidence (float): Confidence score of the diarization (if available) |
|
""" |
|
start_time: float |
|
end_time: float |
|
speaker_id: str |
|
confidence: float = 1.0 |
|
|
|
@property |
|
def duration(self) -> float: |
|
"""Duration of the segment in seconds.""" |
|
return self.end_time - self.start_time |
|
|
|
def to_dict(self) -> dict: |
|
"""Convert to dictionary for JSON serialization.""" |
|
return { |
|
'start_time': self.start_time, |
|
'end_time': self.end_time, |
|
'speaker_id': self.speaker_id, |
|
'duration': self.duration, |
|
'confidence': self.confidence |
|
} |
|
|
|
|
|
class SpeakerDiarizer: |
|
""" |
|
State-of-the-art speaker diarization using pyannote.audio. |
|
|
|
This class provides language-agnostic speaker diarization capabilities, |
|
focusing on acoustic voice characteristics rather than linguistic content. |
|
""" |
|
|
|
def __init__(self, |
|
model_name: str = "pyannote/speaker-diarization-3.1", |
|
hf_token: Optional[str] = None, |
|
device: Optional[str] = None, |
|
min_speakers: Optional[int] = None, |
|
max_speakers: Optional[int] = None): |
|
""" |
|
Initialize the Speaker Diarizer. |
|
|
|
Args: |
|
model_name (str): Hugging Face model name for diarization |
|
hf_token (str, optional): Hugging Face token for gated models |
|
device (str, optional): Device to run on ('cpu', 'cuda', 'auto') |
|
min_speakers (int, optional): Minimum number of speakers to detect |
|
max_speakers (int, optional): Maximum number of speakers to detect |
|
""" |
|
self.model_name = model_name |
|
self.hf_token = hf_token or os.getenv('HUGGINGFACE_TOKEN') |
|
self.min_speakers = min_speakers |
|
self.max_speakers = max_speakers |
|
|
|
|
|
if device == 'auto' or device is None: |
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
else: |
|
self.device = torch.device(device) |
|
|
|
logger.info(f"Initializing SpeakerDiarizer on {self.device}") |
|
|
|
|
|
self.pipeline = None |
|
self._load_pipeline() |
|
|
|
def _load_pipeline(self): |
|
"""Load the pyannote.audio diarization pipeline.""" |
|
if not PYANNOTE_AVAILABLE: |
|
raise ImportError( |
|
"pyannote.audio is required for speaker diarization. " |
|
"Install with: pip install pyannote.audio" |
|
) |
|
|
|
try: |
|
|
|
logger.info(f"Loading {self.model_name}...") |
|
|
|
if self.hf_token: |
|
self.pipeline = Pipeline.from_pretrained( |
|
self.model_name, |
|
use_auth_token=self.hf_token |
|
) |
|
else: |
|
|
|
try: |
|
self.pipeline = Pipeline.from_pretrained(self.model_name) |
|
except Exception as e: |
|
logger.error( |
|
f"Failed to load {self.model_name}. " |
|
"This model may be gated and require a Hugging Face token. " |
|
f"Set HUGGINGFACE_TOKEN environment variable. Error: {e}" |
|
) |
|
raise |
|
|
|
|
|
self.pipeline = self.pipeline.to(self.device) |
|
|
|
|
|
if self.min_speakers is not None or self.max_speakers is not None: |
|
self.pipeline.instantiate({ |
|
"clustering": { |
|
"min_cluster_size": self.min_speakers or 1, |
|
"max_num_speakers": self.max_speakers or 20 |
|
} |
|
}) |
|
|
|
logger.info(f"Successfully loaded {self.model_name} on {self.device}") |
|
|
|
except Exception as e: |
|
logger.error(f"Failed to load diarization pipeline: {e}") |
|
raise |
|
|
|
def diarize(self, |
|
audio_input: Union[str, np.ndarray], |
|
sample_rate: int = 16000) -> List[SpeakerSegment]: |
|
""" |
|
Perform speaker diarization on audio input. |
|
|
|
Args: |
|
audio_input: Audio file path or numpy array |
|
sample_rate: Sample rate if audio_input is numpy array |
|
|
|
Returns: |
|
List[SpeakerSegment]: List of speaker segments with timestamps |
|
|
|
Raises: |
|
ValueError: If input is invalid |
|
Exception: For diarization errors |
|
""" |
|
if self.pipeline is None: |
|
raise RuntimeError("Pipeline not loaded. Call _load_pipeline() first.") |
|
|
|
try: |
|
|
|
audio_file = self._prepare_audio_input(audio_input, sample_rate) |
|
|
|
logger.info("Starting speaker diarization...") |
|
start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None |
|
end_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None |
|
|
|
if start_time: |
|
start_time.record() |
|
|
|
|
|
diarization_result = self.pipeline(audio_file) |
|
|
|
if end_time and start_time: |
|
end_time.record() |
|
torch.cuda.synchronize() |
|
processing_time = start_time.elapsed_time(end_time) / 1000.0 |
|
logger.info(f"Diarization completed in {processing_time:.2f}s") |
|
|
|
|
|
segments = self._parse_diarization_result(diarization_result) |
|
|
|
|
|
num_speakers = len(set(seg.speaker_id for seg in segments)) |
|
total_speech_time = sum(seg.duration for seg in segments) |
|
|
|
logger.info(f"Detected {num_speakers} speakers, {len(segments)} segments, " |
|
f"{total_speech_time:.1f}s total speech") |
|
|
|
return segments |
|
|
|
except Exception as e: |
|
logger.error(f"Diarization failed: {str(e)}") |
|
raise |
|
|
|
finally: |
|
|
|
if isinstance(audio_input, np.ndarray): |
|
try: |
|
if hasattr(audio_file, 'name') and os.path.exists(audio_file.name): |
|
os.unlink(audio_file.name) |
|
except Exception: |
|
pass |
|
|
|
def _prepare_audio_input(self, |
|
audio_input: Union[str, np.ndarray], |
|
sample_rate: int) -> str: |
|
""" |
|
Prepare audio input for pyannote.audio pipeline. |
|
|
|
Args: |
|
audio_input: Audio file path or numpy array |
|
sample_rate: Sample rate for numpy array input |
|
|
|
Returns: |
|
str: Path to audio file ready for pyannote |
|
""" |
|
if isinstance(audio_input, str): |
|
|
|
if not os.path.exists(audio_input): |
|
raise FileNotFoundError(f"Audio file not found: {audio_input}") |
|
return audio_input |
|
|
|
elif isinstance(audio_input, np.ndarray): |
|
|
|
return self._save_array_to_tempfile(audio_input, sample_rate) |
|
|
|
else: |
|
raise ValueError(f"Unsupported audio input type: {type(audio_input)}") |
|
|
|
def _save_array_to_tempfile(self, audio_array: np.ndarray, sample_rate: int) -> str: |
|
""" |
|
Save numpy array to temporary WAV file for pyannote processing. |
|
|
|
Args: |
|
audio_array: Audio data as numpy array |
|
sample_rate: Sample rate of the audio |
|
|
|
Returns: |
|
str: Path to temporary WAV file |
|
""" |
|
try: |
|
import soundfile as sf |
|
|
|
|
|
temp_file = tempfile.NamedTemporaryFile( |
|
delete=False, |
|
suffix='.wav', |
|
prefix='diarization_' |
|
) |
|
temp_path = temp_file.name |
|
temp_file.close() |
|
|
|
|
|
if len(audio_array.shape) > 1: |
|
audio_array = audio_array.flatten() |
|
|
|
|
|
if np.max(np.abs(audio_array)) > 1.0: |
|
audio_array = audio_array / np.max(np.abs(audio_array)) |
|
|
|
|
|
sf.write(temp_path, audio_array, sample_rate) |
|
|
|
logger.debug(f"Saved audio array to temporary file: {temp_path}") |
|
return temp_path |
|
|
|
except ImportError: |
|
|
|
try: |
|
from scipy.io import wavfile |
|
|
|
temp_file = tempfile.NamedTemporaryFile( |
|
delete=False, |
|
suffix='.wav', |
|
prefix='diarization_' |
|
) |
|
temp_path = temp_file.name |
|
temp_file.close() |
|
|
|
|
|
if audio_array.dtype != np.int16: |
|
audio_array_int = (audio_array * 32767).astype(np.int16) |
|
else: |
|
audio_array_int = audio_array |
|
|
|
wavfile.write(temp_path, sample_rate, audio_array_int) |
|
|
|
logger.debug(f"Saved audio array using scipy: {temp_path}") |
|
return temp_path |
|
|
|
except ImportError: |
|
raise ImportError( |
|
"Neither soundfile nor scipy available for audio saving. " |
|
"Install with: pip install soundfile" |
|
) |
|
|
|
def _parse_diarization_result(self, diarization: Annotation) -> List[SpeakerSegment]: |
|
""" |
|
Parse pyannote diarization result into structured segments. |
|
|
|
Args: |
|
diarization: pyannote Annotation object |
|
|
|
Returns: |
|
List[SpeakerSegment]: Parsed speaker segments |
|
""" |
|
segments = [] |
|
|
|
for segment, _, speaker_label in diarization.itertracks(yield_label=True): |
|
|
|
speaker_segment = SpeakerSegment( |
|
start_time=float(segment.start), |
|
end_time=float(segment.end), |
|
speaker_id=str(speaker_label), |
|
confidence=1.0 |
|
) |
|
segments.append(speaker_segment) |
|
|
|
|
|
segments.sort(key=lambda x: x.start_time) |
|
|
|
return segments |
|
|
|
def get_speaker_statistics(self, segments: List[SpeakerSegment]) -> Dict[str, dict]: |
|
""" |
|
Generate speaker statistics from diarization results. |
|
|
|
Args: |
|
segments: List of speaker segments |
|
|
|
Returns: |
|
Dict: Speaker statistics including speaking time, turn counts, etc. |
|
""" |
|
stats = {} |
|
|
|
for segment in segments: |
|
speaker_id = segment.speaker_id |
|
|
|
if speaker_id not in stats: |
|
stats[speaker_id] = { |
|
'total_speaking_time': 0.0, |
|
'number_of_turns': 0, |
|
'average_turn_duration': 0.0, |
|
'longest_turn': 0.0, |
|
'shortest_turn': float('inf') |
|
} |
|
|
|
|
|
stats[speaker_id]['total_speaking_time'] += segment.duration |
|
stats[speaker_id]['number_of_turns'] += 1 |
|
stats[speaker_id]['longest_turn'] = max( |
|
stats[speaker_id]['longest_turn'], |
|
segment.duration |
|
) |
|
stats[speaker_id]['shortest_turn'] = min( |
|
stats[speaker_id]['shortest_turn'], |
|
segment.duration |
|
) |
|
|
|
|
|
for speaker_id, speaker_stats in stats.items(): |
|
if speaker_stats['number_of_turns'] > 0: |
|
speaker_stats['average_turn_duration'] = ( |
|
speaker_stats['total_speaking_time'] / |
|
speaker_stats['number_of_turns'] |
|
) |
|
|
|
|
|
if speaker_stats['shortest_turn'] == float('inf'): |
|
speaker_stats['shortest_turn'] = 0.0 |
|
|
|
return stats |
|
|
|
def merge_short_segments(self, |
|
segments: List[SpeakerSegment], |
|
min_duration: float = 1.0) -> List[SpeakerSegment]: |
|
""" |
|
Merge segments that are too short with adjacent segments from same speaker. |
|
|
|
Args: |
|
segments: List of speaker segments |
|
min_duration: Minimum duration for segments in seconds |
|
|
|
Returns: |
|
List[SpeakerSegment]: Processed segments with short ones merged |
|
""" |
|
if not segments: |
|
return segments |
|
|
|
merged_segments = [] |
|
current_segment = segments[0] |
|
|
|
for next_segment in segments[1:]: |
|
|
|
if (current_segment.duration < min_duration and |
|
current_segment.speaker_id == next_segment.speaker_id): |
|
|
|
|
|
current_segment.end_time = next_segment.end_time |
|
|
|
else: |
|
|
|
merged_segments.append(current_segment) |
|
current_segment = next_segment |
|
|
|
|
|
merged_segments.append(current_segment) |
|
|
|
logger.debug(f"Merged {len(segments)} segments into {len(merged_segments)}") |
|
|
|
return merged_segments |
|
|
|
def export_to_rttm(self, |
|
segments: List[SpeakerSegment], |
|
audio_filename: str = "audio") -> str: |
|
""" |
|
Export diarization results to RTTM format. |
|
|
|
RTTM (Rich Transcription Time Marked) is a standard format |
|
for speaker diarization results. |
|
|
|
Args: |
|
segments: List of speaker segments |
|
audio_filename: Name of the audio file for RTTM output |
|
|
|
Returns: |
|
str: RTTM formatted string |
|
""" |
|
rttm_lines = [] |
|
|
|
for segment in segments: |
|
|
|
rttm_line = ( |
|
f"SPEAKER {audio_filename} 1 " |
|
f"{segment.start_time:.3f} {segment.duration:.3f} " |
|
f"<NA> <NA> {segment.speaker_id} {segment.confidence:.3f}" |
|
) |
|
rttm_lines.append(rttm_line) |
|
|
|
return "\n".join(rttm_lines) |
|
|
|
def __del__(self): |
|
"""Cleanup resources when the object is destroyed.""" |
|
|
|
if hasattr(self, 'device') and self.device.type == 'cuda': |
|
try: |
|
torch.cuda.empty_cache() |
|
except Exception: |
|
pass |
|
|
|
|
|
|
|
def diarize_audio(audio_input: Union[str, np.ndarray], |
|
sample_rate: int = 16000, |
|
hf_token: Optional[str] = None, |
|
min_speakers: Optional[int] = None, |
|
max_speakers: Optional[int] = None, |
|
merge_short: bool = True, |
|
min_duration: float = 1.0) -> List[SpeakerSegment]: |
|
""" |
|
Convenience function to perform speaker diarization with default settings. |
|
|
|
Args: |
|
audio_input: Audio file path or numpy array |
|
sample_rate: Sample rate for numpy array input |
|
hf_token: Hugging Face token for gated models |
|
min_speakers: Minimum number of speakers to detect |
|
max_speakers: Maximum number of speakers to detect |
|
merge_short: Whether to merge short segments |
|
min_duration: Minimum duration for segments (if merge_short=True) |
|
|
|
Returns: |
|
List[SpeakerSegment]: Speaker diarization results |
|
|
|
Example: |
|
>>> # From file |
|
>>> segments = diarize_audio("meeting.wav") |
|
>>> |
|
>>> # From numpy array |
|
>>> import numpy as np |
|
>>> audio_data = np.random.randn(16000 * 60) # 1 minute of audio |
|
>>> segments = diarize_audio(audio_data, sample_rate=16000) |
|
>>> |
|
>>> # Print results |
|
>>> for seg in segments: |
|
>>> print(f"{seg.speaker_id}: {seg.start_time:.1f}s - {seg.end_time:.1f}s") |
|
""" |
|
|
|
diarizer = SpeakerDiarizer( |
|
hf_token=hf_token, |
|
min_speakers=min_speakers, |
|
max_speakers=max_speakers |
|
) |
|
|
|
|
|
segments = diarizer.diarize(audio_input, sample_rate) |
|
|
|
|
|
if merge_short and segments: |
|
segments = diarizer.merge_short_segments(segments, min_duration) |
|
|
|
return segments |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
import sys |
|
import argparse |
|
import json |
|
|
|
def main(): |
|
"""Command line interface for testing speaker diarization.""" |
|
parser = argparse.ArgumentParser(description="Speaker Diarization Tool") |
|
parser.add_argument("audio_file", help="Path to audio file") |
|
parser.add_argument("--token", help="Hugging Face token") |
|
parser.add_argument("--min-speakers", type=int, help="Minimum number of speakers") |
|
parser.add_argument("--max-speakers", type=int, help="Maximum number of speakers") |
|
parser.add_argument("--output-format", choices=["json", "rttm", "text"], |
|
default="text", help="Output format") |
|
parser.add_argument("--merge-short", action="store_true", |
|
help="Merge short segments") |
|
parser.add_argument("--min-duration", type=float, default=1.0, |
|
help="Minimum segment duration for merging") |
|
parser.add_argument("--verbose", "-v", action="store_true", |
|
help="Enable verbose logging") |
|
|
|
args = parser.parse_args() |
|
|
|
if args.verbose: |
|
logging.getLogger().setLevel(logging.DEBUG) |
|
|
|
try: |
|
|
|
print(f"Processing audio file: {args.audio_file}") |
|
|
|
segments = diarize_audio( |
|
audio_input=args.audio_file, |
|
hf_token=args.token, |
|
min_speakers=args.min_speakers, |
|
max_speakers=args.max_speakers, |
|
merge_short=args.merge_short, |
|
min_duration=args.min_duration |
|
) |
|
|
|
|
|
if args.output_format == "json": |
|
|
|
result = { |
|
"audio_file": args.audio_file, |
|
"num_speakers": len(set(seg.speaker_id for seg in segments)), |
|
"num_segments": len(segments), |
|
"total_speech_time": sum(seg.duration for seg in segments), |
|
"segments": [seg.to_dict() for seg in segments] |
|
} |
|
print(json.dumps(result, indent=2)) |
|
|
|
elif args.output_format == "rttm": |
|
|
|
diarizer = SpeakerDiarizer() |
|
rttm_content = diarizer.export_to_rttm(segments, args.audio_file) |
|
print(rttm_content) |
|
|
|
else: |
|
|
|
print(f"\n=== SPEAKER DIARIZATION RESULTS ===") |
|
print(f"Audio file: {args.audio_file}") |
|
print(f"Number of speakers: {len(set(seg.speaker_id for seg in segments))}") |
|
print(f"Number of segments: {len(segments)}") |
|
print(f"Total speech time: {sum(seg.duration for seg in segments):.1f}s") |
|
print("\n--- Segment Details ---") |
|
|
|
for i, segment in enumerate(segments, 1): |
|
print(f"#{i:2d} | {segment.speaker_id:10s} | " |
|
f"{segment.start_time:7.1f}s - {segment.end_time:7.1f}s | " |
|
f"{segment.duration:5.1f}s") |
|
|
|
|
|
diarizer = SpeakerDiarizer() |
|
stats = diarizer.get_speaker_statistics(segments) |
|
|
|
print("\n--- Speaker Statistics ---") |
|
for speaker_id, speaker_stats in stats.items(): |
|
print(f"{speaker_id}:") |
|
print(f" Speaking time: {speaker_stats['total_speaking_time']:.1f}s") |
|
print(f" Number of turns: {speaker_stats['number_of_turns']}") |
|
print(f" Average turn: {speaker_stats['average_turn_duration']:.1f}s") |
|
print(f" Longest turn: {speaker_stats['longest_turn']:.1f}s") |
|
print(f" Shortest turn: {speaker_stats['shortest_turn']:.1f}s") |
|
|
|
except Exception as e: |
|
print(f"Error: {e}", file=sys.stderr) |
|
sys.exit(1) |
|
|
|
|
|
if not PYANNOTE_AVAILABLE: |
|
print("Warning: pyannote.audio not available. Install with: pip install pyannote.audio") |
|
print("Running in demo mode...") |
|
|
|
|
|
dummy_segments = [ |
|
SpeakerSegment(0.0, 5.2, "SPEAKER_00", 0.95), |
|
SpeakerSegment(5.5, 8.3, "SPEAKER_01", 0.87), |
|
SpeakerSegment(8.8, 12.1, "SPEAKER_00", 0.92), |
|
SpeakerSegment(12.5, 15.7, "SPEAKER_01", 0.89), |
|
] |
|
|
|
print("\n=== DEMO OUTPUT (pyannote.audio not available) ===") |
|
for segment in dummy_segments: |
|
print(f"{segment.speaker_id}: {segment.start_time:.1f}s - {segment.end_time:.1f}s") |
|
else: |
|
main() |