voidful commited on
Commit
701891b
·
verified ·
1 Parent(s): ddf58eb

Update processing_gemma3_omni.py

Browse files
Files changed (1) hide show
  1. processing_gemma3_omni.py +56 -43
processing_gemma3_omni.py CHANGED
@@ -28,42 +28,61 @@ DEFAULT_MAX_LENGTH = 16384
28
 
29
  logger = logging.get_logger(__name__)
30
 
31
- def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
32
- fmax: Optional[float] = None) -> np.ndarray:
33
- """Create Mel filterbank for audio processing."""
34
- fmax = fmax or sampling_rate / 2.0
35
-
36
- def hz_to_mel(f: float) -> float: # Slaney scale from Snippet B
37
- return 1127.0 * math.log(1 + f / 700.0)
38
-
39
- if fmin >= fmax:
40
- raise ValueError(f"fmin ({fmin}) must be smaller than fmax ({fmax}).")
41
-
42
- mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
43
- freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1) # Inverse of Slaney hz_to_mel
44
-
45
- freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
46
- bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)
47
- bins = np.clip(bins, 0, n_fft // 2)
48
-
49
- filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
50
- for m_idx in range(n_mels):
51
- left, center, right = bins[m_idx], bins[m_idx + 1], bins[m_idx + 2]
52
-
53
- if center > left:
54
- filterbank[m_idx, left:center + 1] = (np.arange(left, center + 1) - left) / (center - left)
55
- if right > center:
56
- filterbank[m_idx, center:right + 1] = (right - np.arange(center, right + 1)) / (right - center)
57
-
58
- if left <= center <= right:
59
- if filterbank.shape[1] > center:
60
- if (center > left and filterbank[m_idx, center] < 1.0 and center < right) or \
61
- (left == center and center < right) or \
62
- (right == center and left < center):
63
- filterbank[m_idx, center] = 1.0
64
- elif left == center and right == center :
65
- filterbank[m_idx, center] = 1.0
66
- return filterbank
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
 
69
  # --- Start of Refactored Audio Feature Extractor (to match Phi4M - Snippet A) ---
@@ -106,13 +125,7 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor): # MODIFIED CLASS N
106
  "This might lead to inconsistencies if the input audio is not resampled to 16000 Hz by this extractor."
107
  )
108
 
109
- self._mel = create_mel_filterbank(
110
- sampling_rate=16000, # Phi4M Mel params are for 16kHz.
111
- n_fft=512,
112
- n_mels=_feature_size, # Use the effective feature_size (should be 80)
113
- fmin=0.0,
114
- fmax=7690.0
115
- ).T
116
  self._hamming400 = np.hamming(400)
117
  self._hamming200 = np.hamming(200)
118
 
 
28
 
29
  logger = logging.get_logger(__name__)
30
 
31
+ def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
32
+ """Create a Mel filter-bank the same as SpeechLib FbankFC.
33
+ Args:
34
+ sample_rate (int): Sample rate in Hz. number > 0 [scalar]
35
+ n_fft (int): FFT size. int > 0 [scalar]
36
+ n_mel (int): Mel filter size. int > 0 [scalar]
37
+ fmin (float): lowest frequency (in Hz). If None use 0.0.
38
+ float >= 0 [scalar]
39
+ fmax: highest frequency (in Hz). If None use sample_rate / 2.
40
+ float >= 0 [scalar]
41
+ Returns
42
+ out (numpy.ndarray): Mel transform matrix
43
+ [shape=(n_mels, 1 + n_fft/2)]
44
+ """
45
+
46
+ bank_width = int(n_fft // 2 + 1)
47
+ if fmax is None:
48
+ fmax = sample_rate / 2
49
+ if fmin is None:
50
+ fmin = 0
51
+ assert fmin >= 0, "fmin cannot be negtive"
52
+ assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]"
53
+
54
+ def mel(f):
55
+ return 1127.0 * np.log(1.0 + f / 700.0)
56
+
57
+ def bin2mel(fft_bin):
58
+ return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
59
+
60
+ def f2bin(f):
61
+ return int((f * n_fft / sample_rate) + 0.5)
62
+
63
+ # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
64
+ klo = f2bin(fmin) + 1
65
+ khi = f2bin(fmax)
66
+
67
+ khi = max(khi, klo)
68
+
69
+ # Spec 2: SpeechLib uses trianges in Mel space
70
+ mlo = mel(fmin)
71
+ mhi = mel(fmax)
72
+ m_centers = np.linspace(mlo, mhi, n_mels + 2)
73
+ ms = (mhi - mlo) / (n_mels + 1)
74
+
75
+ matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
76
+ for m in range(0, n_mels):
77
+ left = m_centers[m]
78
+ center = m_centers[m + 1]
79
+ right = m_centers[m + 2]
80
+ for fft_bin in range(klo, khi):
81
+ mbin = bin2mel(fft_bin)
82
+ if left < mbin < right:
83
+ matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
84
+
85
+ return matrix
86
 
87
 
88
  # --- Start of Refactored Audio Feature Extractor (to match Phi4M - Snippet A) ---
 
125
  "This might lead to inconsistencies if the input audio is not resampled to 16000 Hz by this extractor."
126
  )
127
 
128
+ self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T
 
 
 
 
 
 
129
  self._hamming400 = np.hamming(400)
130
  self._hamming200 = np.hamming(200)
131