# SATEv1.5/speaker/speaker_identification.py
from typing import List, Union, Optional
import os
import numpy as np
import librosa
from transformers import pipeline
# Default sample rate for audio processing
DEFAULT_SAMPLE_RATE = 16000
# Singleton pattern to avoid loading the model multiple times
_PREDICTOR_INSTANCE = None
def get_predictor():
"""
Get or create the singleton predictor instance.
Returns:
Predictor: A shared instance of the Predictor class.
"""
global _PREDICTOR_INSTANCE
if _PREDICTOR_INSTANCE is None:
_PREDICTOR_INSTANCE = Predictor()
return _PREDICTOR_INSTANCE
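# A minimal usage sketch (illustrative only): repeated calls return the same cached
# Predictor, so the classification model is loaded at most once per process.
#
#     first = get_predictor()
#     second = get_predictor()
#     assert first is second  # both names refer to the one shared instance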
class Predictor:
def __init__(self, model_path: Optional[str] = None):
"""
Initialize the predictor with a pre-trained model.
Args:
model_path: Optional path to a local model. If None, uses the default HuggingFace model.
"""
        # Load the Hugging Face audio-classification pipeline.
        # Fall back to the default adult/child classifier when no local model path is given.
        model_name = model_path or "bookbot/wav2vec2-adult-child-cls"
        self.model = pipeline("audio-classification", model=model_name)
def preprocess(self, input_item: Union[str, np.ndarray]) -> np.ndarray:
"""
Preprocess an input item (either file path or numpy array).
Args:
input_item: Either a file path string or a numpy array of audio data.
Returns:
np.ndarray: Processed audio data as a numpy array.
Raises:
ValueError: If input type is unsupported.
"""
if isinstance(input_item, str):
# Load audio file to numpy array
audio, _ = librosa.load(input_item, sr=DEFAULT_SAMPLE_RATE)
return audio
elif isinstance(input_item, np.ndarray):
return input_item
else:
raise ValueError(f"Unsupported input type: {type(input_item)}")
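    # A small illustration of preprocess (sketch; "segment.wav" is a hypothetical file):
    # both call forms below yield a 1-D waveform array at DEFAULT_SAMPLE_RATE.
    #
    #     predictor.preprocess("segment.wav")     # loaded and resampled via librosa
    #     predictor.preprocess(np.zeros(16000))   # ndarrays are passed through unchanged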
def predict(self, input_list: List[Union[str, np.ndarray]]) -> List[int]:
"""
Predict speaker type (child=0, adult=1) for a list of audio inputs.
Args:
input_list: List of inputs, either file paths or numpy arrays.
Returns:
List[int]: List of predictions (0=child, 1=adult, -1=unknown).
"""
        # Preprocess all inputs into raw waveforms first
        processed = [self.preprocess(item) for item in input_list]
        # Batch inference: pass each waveform as a dict carrying its sampling rate,
        # so the pipeline can resample if it differs from the model's expected rate
        batch = [{"raw": audio, "sampling_rate": DEFAULT_SAMPLE_RATE} for audio in processed]
        preds = self.model(batch)
# Map label to 0 (child) or 1 (adult)
label_map = {
"child": 0,
"adult": 1
}
results = []
for pred in preds:
# pred can be a list of dicts (top-k), take the top prediction
if isinstance(pred, list):
label = pred[0]["label"]
else:
label = pred["label"]
results.append(label_map.get(label.lower(), -1)) # -1 for unknown label
return results
# Usage:
# predictor = Predictor("path/to/model")
# predictions = predictor.predict(list_of_inputs)
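# A more concrete sketch using the shared instance (the file names are hypothetical
# and the exact labels depend on the audio content):
#
#     predictor = get_predictor()
#     predictor.predict(["child_clip.wav", "adult_clip.wav"])  # e.g. [0, 1]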
def assign_speaker_for_audio_list(audio_list: List[Union[str, np.ndarray]]) -> List[str]:
"""
Assigns speaker IDs for a list of audio segments.
Args:
audio_list: List of audio inputs (either file paths or numpy arrays,
assumed to have sampling rate = 16000).
Returns:
List[str]: List of speaker IDs corresponding to each audio segment.
"Speaker_id_0" for child, "Speaker_id_1" for adult.
"""
if not audio_list:
return []
# Use singleton predictor to avoid reloading model
predictor = get_predictor()
# Get list of 0 (child) or 1 (adult)
numeric_labels = predictor.predict(audio_list)
# Map to Speaker_id_0 and Speaker_id_1, preserving order
    speaker_ids = [f"Speaker_id_{label}" if label in (0, 1) else "Unknown" for label in numeric_labels]
return speaker_ids
# you don't have to implement this function
def assign_speaker(session_id: str):
return
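

if __name__ == "__main__":
    # Minimal demo sketch, assuming 16 kHz mono clips exist at these hypothetical
    # paths; replace them with real files before running.
    example_clips = ["samples/child_utterance.wav", "samples/adult_utterance.wav"]
    speaker_ids = assign_speaker_for_audio_list(example_clips)
    for clip, speaker_id in zip(example_clips, speaker_ids):
        print(f"{clip}: {speaker_id}")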