TomRoma commited on
Commit
73b6e10
·
1 Parent(s): fe970e3

Enhance speaker identification functionality and add comprehensive tests for audio inputs, updated requirments.txt

Browse files
requirements.txt CHANGED
@@ -20,4 +20,6 @@ scipy>=1.7.0
20
  matplotlib>=3.3.0
21
  seaborn>=0.11.0
22
 
23
- # install ffmpeg
 
 
 
20
  matplotlib>=3.3.0
21
  seaborn>=0.11.0
22
 
23
+ # install ffmpeg
24
+ librosa>=0.8.0
25
+ transformers>=4.0.0
speaker/speaker_identification.py CHANGED
@@ -1,16 +1,118 @@
1
- from typing import List
 
2
 
3
- # Assigns speaker IDs for a list of audio segments.
4
-
5
- # Args:
6
- # audio_list (List): List of audio (list of file path or list of nparray, assume sampling rate = 16000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # Returns:
9
- # List[str]: List of speaker IDs corresponding to each audio segment
10
- def assign_speaker_for_audio_list(audio_list: List) -> List[str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- speaker_ids = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
 
 
14
  return speaker_ids
15
 
16
 
 
1
+ from typing import List, Union, Optional
2
+ import os
3
 
4
+ import numpy as np
5
+ import librosa
6
+ from transformers import pipeline
7
+
8
+ # Default sample rate for audio processing
9
+ DEFAULT_SAMPLE_RATE = 16000
10
+
11
+ # Singleton pattern to avoid loading the model multiple times
12
+ _PREDICTOR_INSTANCE = None
13
+
14
+ def get_predictor():
15
+ """
16
+ Get or create the singleton predictor instance.
17
+ Returns:
18
+ Predictor: A shared instance of the Predictor class.
19
+ """
20
+ global _PREDICTOR_INSTANCE
21
+ if _PREDICTOR_INSTANCE is None:
22
+ _PREDICTOR_INSTANCE = Predictor()
23
+ return _PREDICTOR_INSTANCE
24
+ class Predictor:
25
+ def __init__(self, model_path: Optional[str] = None):
26
+ """
27
+ Initialize the predictor with a pre-trained model.
28
+
29
+ Args:
30
+ model_path: Optional path to a local model. If None, uses the default HuggingFace model.
31
+ """
32
+ # Load Hugging Face audio-classification pipeline
33
+ self.model = pipeline("audio-classification", model="bookbot/wav2vec2-adult-child-cls")
34
+
35
+ def preprocess(self, input_item: Union[str, np.ndarray]) -> np.ndarray:
36
+ """
37
+ Preprocess an input item (either file path or numpy array).
38
+
39
+ Args:
40
+ input_item: Either a file path string or a numpy array of audio data.
41
+
42
+ Returns:
43
+ np.ndarray: Processed audio data as a numpy array.
44
+
45
+ Raises:
46
+ ValueError: If input type is unsupported.
47
+ """
48
+ if isinstance(input_item, str):
49
+ # Load audio file to numpy array
50
+ audio, _ = librosa.load(input_item, sr=DEFAULT_SAMPLE_RATE)
51
+ return audio
52
+ elif isinstance(input_item, np.ndarray):
53
+ return input_item
54
+ else:
55
+ raise ValueError(f"Unsupported input type: {type(input_item)}")
56
+
57
+ def predict(self, input_list: List[Union[str, np.ndarray]]) -> List[int]:
58
+ """
59
+ Predict speaker type (child=0, adult=1) for a list of audio inputs.
60
 
61
+ Args:
62
+ input_list: List of inputs, either file paths or numpy arrays.
63
+
64
+ Returns:
65
+ List[int]: List of predictions (0=child, 1=adult, -1=unknown).
66
+ """
67
+ # Preprocess all inputs first
68
+ processed = [self.preprocess(item) for item in input_list]
69
+
70
+ # Batch inference
71
+ preds = self.model(processed, sampling_rate=DEFAULT_SAMPLE_RATE)
72
+
73
+ # Map label to 0 (child) or 1 (adult)
74
+ label_map = {
75
+ "child": 0,
76
+ "adult": 1
77
+ }
78
+
79
+ results = []
80
+ for pred in preds:
81
+ # pred can be a list of dicts (top-k), take the top prediction
82
+ if isinstance(pred, list):
83
+ label = pred[0]["label"]
84
+ else:
85
+ label = pred["label"]
86
+ results.append(label_map.get(label.lower(), -1)) # -1 for unknown label
87
+ return results
88
+
89
+ # Usage:
90
+ # predictor = Predictor("path/to/model")
91
+ # predictions = predictor.predict(list_of_inputs)
92
 
93
+ def assign_speaker_for_audio_list(audio_list: List[Union[str, np.ndarray]]) -> List[str]:
94
+ """
95
+ Assigns speaker IDs for a list of audio segments.
96
+
97
+ Args:
98
+ audio_list: List of audio inputs (either file paths or numpy arrays,
99
+ assumed to have sampling rate = 16000).
100
+
101
+ Returns:
102
+ List[str]: List of speaker IDs corresponding to each audio segment.
103
+ "Speaker_id_0" for child, "Speaker_id_1" for adult.
104
+ """
105
+ if not audio_list:
106
+ return []
107
+
108
+ # Use singleton predictor to avoid reloading model
109
+ predictor = get_predictor()
110
+
111
+ # Get list of 0 (child) or 1 (adult)
112
+ numeric_labels = predictor.predict(audio_list)
113
 
114
+ # Map to Speaker_id_0 and Speaker_id_1, preserving order
115
+ speaker_ids = [f"Speaker_id_{label}" if label in (0,1) else "Unknown" for label in numeric_labels]
116
  return speaker_ids
117
 
118
 
test_eval_speaker_identification.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import librosa
4
+ from speaker.speaker_identification import assign_speaker_for_audio_list
5
+
6
+ # Define constants
7
+ TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Test_data_for_clas_Idef')
8
+ AUDIO_FILES_DIR = os.path.join(TEST_DATA_DIR, 'enni_audio_files')
9
+ NUMPY_FILES_DIR = os.path.join(TEST_DATA_DIR, 'enni_testset_numpy_minimal')
10
+ FILEPATHS_DIR = os.path.join(TEST_DATA_DIR, 'enni_testset_filepaths_minimal')
11
+
12
+ def generate_fake_audio_test_set(num_samples=10, length=16000, random_seed=42):
13
+ """
14
+ Generate a synthetic test set of fake audio signals (numpy arrays).
15
+ Args:
16
+ num_samples (int): Number of audio samples.
17
+ length (int): Length of each audio sample (e.g., 1 second at 16kHz).
18
+ random_seed (int): Seed for reproducibility.
19
+ Returns:
20
+ List[np.ndarray]: List of fake audio signals.
21
+ """
22
+ np.random.seed(random_seed)
23
+ return [np.random.randn(length) for _ in range(num_samples)]
24
+
25
+ def test_file_paths():
26
+ """Test with all real audio files from the dataset"""
27
+ # Get file paths using the constant
28
+ audio_dir = AUDIO_FILES_DIR
29
+
30
+ # Get all child and adult files
31
+ child_files = [
32
+ os.path.join(audio_dir, file)
33
+ for file in os.listdir(audio_dir)
34
+ if file.startswith('child_') and file.endswith('.wav')
35
+ ] # Use all child files
36
+
37
+ adult_files = [
38
+ os.path.join(audio_dir, file)
39
+ for file in os.listdir(audio_dir)
40
+ if file.startswith('adult_') and file.endswith('.wav')
41
+ ] # Use all adult files
42
+
43
+ # Create list with known order
44
+ audio_list = child_files + adult_files
45
+
46
+ # Get speaker IDs
47
+ speaker_ids = assign_speaker_for_audio_list(audio_list)
48
+
49
+ # Print results
50
+ print("\n--- Testing with file paths ---")
51
+ print(f"Testing {len(audio_list)} audio files: {len(child_files)} child files and {len(adult_files)} adult files")
52
+
53
+ # Count correct predictions
54
+ correct = 0
55
+ for i, (file, speaker_id) in enumerate(zip(audio_list, speaker_ids)):
56
+ expected = "Speaker_id_0" if "child_" in file else "Speaker_id_1"
57
+ is_correct = speaker_id == expected
58
+ correct += 1 if is_correct else 0
59
+
60
+ # Print only the first 5 examples to avoid cluttering the output
61
+ if i < 5:
62
+ print(f"{i+1}. {os.path.basename(file)}: {speaker_id} (Expected: {expected}) {'✓' if is_correct else '✗'}")
63
+
64
+ # Print accuracy
65
+ accuracy = correct / len(audio_list) * 100 if audio_list else 0
66
+ print(f"Accuracy: {correct}/{len(audio_list)} ({accuracy:.2f}%)")
67
+
68
+ def test_numpy_arrays():
69
+ """Test with NumPy arrays by loading all audio files"""
70
+ # Get file paths using the constant
71
+ audio_dir = AUDIO_FILES_DIR
72
+
73
+ # Load all child and adult files as arrays
74
+ child_files = [
75
+ os.path.join(audio_dir, file)
76
+ for file in os.listdir(audio_dir)
77
+ if file.startswith('child_') and file.endswith('.wav')
78
+ ]
79
+
80
+ adult_files = [
81
+ os.path.join(audio_dir, file)
82
+ for file in os.listdir(audio_dir)
83
+ if file.startswith('adult_') and file.endswith('.wav')
84
+ ]
85
+
86
+ # Load as arrays
87
+ child_arrays = [librosa.load(f, sr=16000)[0] for f in child_files]
88
+ adult_arrays = [librosa.load(f, sr=16000)[0] for f in adult_files]
89
+
90
+ # Create list with known order
91
+ audio_list = child_arrays + adult_arrays
92
+ filenames = [os.path.basename(f) for f in child_files + adult_files]
93
+
94
+ # Get speaker IDs
95
+ speaker_ids = assign_speaker_for_audio_list(audio_list)
96
+
97
+ # Print results
98
+ print("\n--- Testing with NumPy arrays ---")
99
+ print(f"Testing {len(audio_list)} audio arrays: {len(child_arrays)} child arrays and {len(adult_arrays)} adult arrays")
100
+
101
+ # Count correct predictions
102
+ correct = 0
103
+ for i, (filename, speaker_id) in enumerate(zip(filenames, speaker_ids)):
104
+ expected = "Speaker_id_0" if "child_" in filename else "Speaker_id_1"
105
+ is_correct = speaker_id == expected
106
+ correct += 1 if is_correct else 0
107
+
108
+ # Print only the first 5 examples to avoid cluttering the output
109
+ if i < 5:
110
+ print(f"{i+1}. {filename} (as array): {speaker_id} (Expected: {expected}) {'✓' if is_correct else '✗'}")
111
+
112
+ # Print accuracy
113
+ accuracy = correct / len(audio_list) * 100 if audio_list else 0
114
+ print(f"Accuracy: {correct}/{len(audio_list)} ({accuracy:.2f}%)")
115
+
116
+ if __name__ == "__main__":
117
+ # Test with synthetic data
118
+ print("--- Testing with synthetic data ---")
119
+ audio_list = generate_fake_audio_test_set(num_samples=5)
120
+ speaker_ids = assign_speaker_for_audio_list(audio_list)
121
+ print(f"Synthetic data predictions: {speaker_ids}")
122
+
123
+ # Test with real files
124
+ try:
125
+ test_file_paths()
126
+ except Exception as e:
127
+ print(f"Error testing file paths: {e}")
128
+
129
+ # Test with NumPy arrays
130
+ try:
131
+ test_numpy_arrays()
132
+ except Exception as e:
133
+ print(f"Error testing NumPy arrays: {e}")