|
from typing import List, Union, Optional |
|
import os |
|
|
|
import numpy as np |
|
import librosa |
|
from transformers import pipeline |
|
|
|
|
|
# Sampling rate (Hz) used when loading audio files and handed to the classifier.
DEFAULT_SAMPLE_RATE: int = 16_000

# Shared Predictor instance; stays None until get_predictor() first builds it.
_PREDICTOR_INSTANCE: Optional["Predictor"] = None
|
|
|
def get_predictor():
    """
    Return the module-wide Predictor, constructing it on the first call.

    Returns:
        Predictor: The single shared Predictor instance.
    """
    global _PREDICTOR_INSTANCE

    # Fast path: an instance was already built by an earlier call.
    if _PREDICTOR_INSTANCE is not None:
        return _PREDICTOR_INSTANCE

    # First use: build the predictor once and cache it for later callers.
    _PREDICTOR_INSTANCE = Predictor()
    return _PREDICTOR_INSTANCE
|
class Predictor:
    """Adult/child speaker classifier backed by a HuggingFace audio-classification pipeline."""

    # Model loaded when the caller does not supply ``model_path``.
    DEFAULT_MODEL = "bookbot/wav2vec2-adult-child-cls"

    # Lower-cased model label -> numeric class id.
    LABEL_MAP = {
        "child": 0,
        "adult": 1,
    }

    # Class id reported for any label that is missing from LABEL_MAP.
    UNKNOWN_LABEL = -1

    def __init__(self, model_path: Optional[str] = None):
        """
        Initialize the predictor with a pre-trained model.

        Args:
            model_path: Optional path to a local model. If None (or empty),
                the default HuggingFace model is used.
        """
        # model_path used to be accepted but silently ignored; honour it now
        # and fall back to the default model only when it is not given.
        self.model = pipeline(
            "audio-classification",
            model=model_path or self.DEFAULT_MODEL,
        )

    def preprocess(
        self, input_item: Union[str, os.PathLike, np.ndarray]
    ) -> np.ndarray:
        """
        Preprocess an input item (either a file path or a numpy array).

        Args:
            input_item: A file path (``str`` or ``os.PathLike`` such as
                ``pathlib.Path``) or a numpy array of audio data. Arrays are
                returned unchanged and are assumed to already be sampled at
                DEFAULT_SAMPLE_RATE.

        Returns:
            np.ndarray: Audio data as a numpy array.

        Raises:
            ValueError: If the input type is unsupported.
        """
        if isinstance(input_item, (str, os.PathLike)):
            # Load from disk, resampling to the rate the classifier is fed with.
            audio, _ = librosa.load(os.fspath(input_item), sr=DEFAULT_SAMPLE_RATE)
            return audio
        if isinstance(input_item, np.ndarray):
            return input_item
        raise ValueError(f"Unsupported input type: {type(input_item)}")

    def predict(self, input_list: List[Union[str, np.ndarray]]) -> List[int]:
        """
        Predict speaker type (child=0, adult=1) for a list of audio inputs.

        Args:
            input_list: List of inputs, either file paths or numpy arrays.

        Returns:
            List[int]: One prediction per input (0=child, 1=adult, -1=unknown).
            An empty input yields an empty list without invoking the model.
        """
        processed = [self.preprocess(item) for item in input_list]
        if not processed:
            # Nothing to classify; avoid calling the pipeline with an empty batch.
            return []

        # NOTE(review): confirm the pipeline actually consumes the
        # ``sampling_rate`` keyword; raw arrays are otherwise presumably taken
        # to be at the model's native rate.
        preds = self.model(processed, sampling_rate=DEFAULT_SAMPLE_RATE)
        return [self._to_class_id(pred) for pred in preds]

    @classmethod
    def _to_class_id(cls, pred) -> int:
        """Map one pipeline result (a dict or a list of dicts) to a class id."""
        if isinstance(pred, list):
            if not pred:
                # No candidates returned for this item.
                return cls.UNKNOWN_LABEL
            # Use the first candidate, which the pipeline is expected to rank highest.
            pred = pred[0]
        return cls.LABEL_MAP.get(pred["label"].lower(), cls.UNKNOWN_LABEL)
|
|
|
|
|
|
|
|
|
|
|
def assign_speaker_for_audio_list(audio_list: List[Union[str, np.ndarray]]) -> List[str]:
    """
    Label each audio segment with a speaker ID.

    Args:
        audio_list: Audio inputs, each either a file path or a numpy array
            (arrays are assumed to have sampling rate = 16000).

    Returns:
        List[str]: One speaker ID per audio segment, in input order:
            "Speaker_id_0" for child, "Speaker_id_1" for adult, and
            "Unknown" for any other prediction. Empty input gives [].
    """
    # Empty input: return early so the shared predictor is never built.
    if not audio_list:
        return []

    # Classify with the shared predictor, then turn class ids into speaker IDs.
    return [
        "Unknown" if class_id not in (0, 1) else f"Speaker_id_{class_id}"
        for class_id in get_predictor().predict(audio_list)
    ]
|
|
|
|
|
|
|
def assign_speaker(session_id: str):
    """
    Placeholder for session-level speaker assignment.

    Args:
        session_id: Identifier of the session; currently unused.

    Returns:
        None. No work is performed yet.
    """
    return None