voidful committed on
Commit ddf58eb · verified · 1 parent: 3fa62c9

Update processing_gemma3_omni.py

Files changed (1)
  1. processing_gemma3_omni.py +113 -85
processing_gemma3_omni.py CHANGED
@@ -16,116 +16,133 @@ from transformers.utils import TensorType, to_py_obj, logging
 # Constants
 DEFAULT_SAMPLING_RATE = 16000
 DEFAULT_N_FFT = 512
-DEFAULT_WIN_LENGTH = 400 # Matches Phi4M's 16kHz win_length for reference
-DEFAULT_HOP_LENGTH = 160 # Matches Phi4M's 16kHz hop_length for reference
+DEFAULT_WIN_LENGTH = 400
+DEFAULT_HOP_LENGTH = 160
 DEFAULT_N_MELS = 80
-DEFAULT_COMPRESSION_RATE = 4 # Generic default
-DEFAULT_QFORMER_RATE = 2 # Generic default
-DEFAULT_FEAT_STRIDE = 4 # Generic default
+DEFAULT_COMPRESSION_RATE = 4 # Used for default in __init__
+DEFAULT_QFORMER_RATE = 2 # Used for default in __init__ (as audio_downsample_rate)
+DEFAULT_FEAT_STRIDE = 4 # Used for default in __init__
 IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
 AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
 DEFAULT_MAX_LENGTH = 16384
-# LOG_MEL_CLIP_EPSILON = 1e-5 # Original B's constant, A clips at 1.0
 
 logger = logging.get_logger(__name__)
 
-
-# This create_mel_filterbank function is from your original Snippet B.
-# It will be used by the Gemma3AudioFeatureExtractor.
 def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
                           fmax: Optional[float] = None) -> np.ndarray:
     """Create Mel filterbank for audio processing."""
     fmax = fmax or sampling_rate / 2.0
 
-    def hz_to_mel(f: float) -> float: # Slaney scale from Snippet B
+    def hz_to_mel(f: float) -> float: # Slaney scale from Snippet B
        return 1127.0 * math.log(1 + f / 700.0)
 
    if fmin >= fmax:
        raise ValueError(f"fmin ({fmin}) must be smaller than fmax ({fmax}).")
 
    mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
-    freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1) # Inverse of Slaney hz_to_mel
+    freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1) # Inverse of Slaney hz_to_mel
 
    freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
    bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)
-    bins = np.clip(bins, 0, n_fft // 2) # Max index for rfft output is n_fft//2
+    bins = np.clip(bins, 0, n_fft // 2)
 
    filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
    for m_idx in range(n_mels):
        left, center, right = bins[m_idx], bins[m_idx + 1], bins[m_idx + 2]
 
-        if center > left: # Rising slope
+        if center > left:
            filterbank[m_idx, left:center + 1] = (np.arange(left, center + 1) - left) / (center - left)
-        if right > center: # Falling slope
+        if right > center:
            filterbank[m_idx, center:right + 1] = (right - np.arange(center, right + 1)) / (right - center)
 
-        # Ensure the peak at 'center' is 1.0 if it's a valid point.
-        # This logic is from original Snippet B. Phi4M's speechlib_mel might normalize differently.
-        if left <= center <= right: # Check if center is within the bounds of the filter
-            if filterbank.shape[1] > center: # Check if center index is within filterbank columns
+        if left <= center <= right:
+            if filterbank.shape[1] > center:
                if (center > left and filterbank[m_idx, center] < 1.0 and center < right) or \
-                        (left == center and center < right) or \
-                        (right == center and left < center): # Ensure it's a triangular filter with a slope
+                        (left == center and center < right) or \
+                        (right == center and left < center):
                    filterbank[m_idx, center] = 1.0
-                elif left == center and right == center: # Handles the case of a filter with zero width if bins are identical
+                elif left == center and right == center :
                    filterbank[m_idx, center] = 1.0
-
    return filterbank
 
 
 # --- Start of Refactored Audio Feature Extractor (to match Phi4M - Snippet A) ---
-class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
+class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor): # MODIFIED CLASS NAME AND __INIT__
    model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]
 
-    def __init__(self, audio_compression_rate, audio_downsample_rate, audio_feat_stride, **kwargs):
-        feature_size = 80 # From Phi4M
-        sampling_rate = 16000 # From Phi4M (target sampling rate)
-        padding_value = 0.0 # From Phi4M
-        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
+    def __init__(self,
+                 audio_compression_rate: int = DEFAULT_COMPRESSION_RATE, # ADDED DEFAULT
+                 audio_downsample_rate: int = DEFAULT_QFORMER_RATE, # ADDED DEFAULT (maps to qformer_rate)
+                 audio_feat_stride: int = DEFAULT_FEAT_STRIDE, # ADDED DEFAULT
+                 feature_size: int = DEFAULT_N_MELS, # Added default based on constants
+                 sampling_rate: int = DEFAULT_SAMPLING_RATE, # Added default based on constants
+                 padding_value: float = 0.0, # Added default
+                 eightk_method: str = "fillzero", # Added default for this custom param
+                 **kwargs):
+
+        # If feature_size, sampling_rate, padding_value are in kwargs, they will override defaults.
+        # The super().__init__ expects feature_size, sampling_rate, padding_value.
+        # We ensure they are passed, either from defaults or kwargs.
+        _feature_size = kwargs.pop("feature_size", feature_size)
+        _sampling_rate = kwargs.pop("sampling_rate", sampling_rate)
+        _padding_value = kwargs.pop("padding_value", padding_value)
+
+        super().__init__(feature_size=_feature_size, sampling_rate=_sampling_rate, padding_value=_padding_value,
+                         **kwargs)
 
        self.compression_rate = audio_compression_rate
-        self.qformer_compression_rate = audio_downsample_rate # In Phi4M, audio_downsample_rate is qformer_compression_rate
+        self.qformer_compression_rate = audio_downsample_rate
        self.feat_stride = audio_feat_stride
 
-        self._eightk_method = kwargs.get("eightk_method", "fillzero") # 'fillzero' or 'resample'
+        self._eightk_method = eightk_method # Use the argument, which has a default
+
+        # Ensure _sampling_rate is used for mel filterbank if it was overridden by kwargs for superclass
+        # However, Phi4M logic hardcodes 16000Hz for its mel parameters.
+        # self.sampling_rate from super() will be the target sampling rate.
+        if self.sampling_rate != 16000:
+            logger.warning(
+                f"The feature extractor's target sampling rate is {self.sampling_rate}, "
+                "but Phi4M-consistent Mel parameters are based on 16000 Hz. "
+                "This might lead to inconsistencies if the input audio is not resampled to 16000 Hz by this extractor."
+            )
 
-        # Using the provided create_mel_filterbank (Slaney scale)
-        # Parameters for Mel filterbank match Phi4M's speechlib_mel call
        self._mel = create_mel_filterbank(
-            sampling_rate=16000, # Target sampling rate
-            n_fft=512, # n_fft for 16kHz audio in Phi4M
-            n_mels=80, # feature_size
-            fmin=0.0, # Phi4M's fmin is None, typically defaults to 0
-            fmax=7690.0 # Specific fmax from Phi4M
+            sampling_rate=16000, # Phi4M Mel params are for 16kHz.
+            n_fft=512,
+            n_mels=_feature_size, # Use the effective feature_size (should be 80)
+            fmin=0.0,
+            fmax=7690.0
        ).T
-        self._hamming400 = np.hamming(400) # for 16k audio, from Phi4M
-        self._hamming200 = np.hamming(200) # for 8k audio, from Phi4M
+        self._hamming400 = np.hamming(400)
+        self._hamming200 = np.hamming(200)
 
    def __call__(
            self,
-            audios: List[Union[AudioInput, Tuple[np.ndarray, int]]], # More specific type hint
+            audios: List[Union[AudioInput, Tuple[np.ndarray, int]]],
            return_tensors: Optional[Union[str, TensorType]] = None,
+            # sampling_rate: Optional[int] = None, # This was in original B, but Phi4M gets sr from AudioInput
    ):
        returned_input_audio_embeds = []
        returned_audio_embed_sizes = []
-        audio_frames_list = [] # Stores num_mel_frames * feat_stride for each audio item
+        audio_frames_list = []
 
        for audio_input_item in audios:
            if not isinstance(audio_input_item, tuple) or len(audio_input_item) != 2:
                raise ValueError(
                    "Each item in 'audios' must be a tuple (waveform: np.ndarray, sample_rate: int)."
                )
-            audio_data, sample_rate = audio_input_item
+            audio_data, sample_rate = audio_input_item # sample_rate is from the input audio
 
-            if isinstance(audio_data, list): # Convert list to ndarray
+            if isinstance(audio_data, list):
                audio_data = np.array(audio_data, dtype=np.float32)
            if not isinstance(audio_data, np.ndarray):
                raise TypeError(f"Waveform data must be a numpy array, got {type(audio_data)}")
 
-            audio_embeds_np = self._extract_features(audio_data, sample_rate) # log_fbank
+            # _extract_features will handle resampling to self.sampling_rate (16000 Hz)
+            audio_embeds_np = self._extract_features(audio_data, sample_rate)
 
            num_mel_frames = audio_embeds_np.shape[0]
-            current_audio_frames = num_mel_frames * self.feat_stride # Phi4M logic
+            current_audio_frames = num_mel_frames * self.feat_stride
 
            audio_embed_size = self._compute_audio_embed_size(current_audio_frames)
 
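Two quick local checks on the hunk above (a minimal sketch, not part of the commit: the silent test waveform and the shape assertion are assumptions). Note that hz_to_mel as written, 1127·ln(1 + f/700), is the HTK-style mel formula rather than Slaney's, despite the inherited comment, and that self._mel stores the transposed filterbank so power frames can be multiplied into it directly.

import numpy as np

# Filterbank comes back as (n_mels, n_fft // 2 + 1) before the .T above.
mel = create_mel_filterbank(sampling_rate=16000, n_fft=512, n_mels=80, fmin=0.0, fmax=7690.0)
assert mel.shape == (80, 257)

# With the new defaults the extractor is constructible without arguments.
fe = Gemma3AudioFeatureExtractor()
wav = np.zeros(16000, dtype=np.float32)  # assumed input: 1 s of silence at 16 kHz
batch = fe([(wav, 16000)], return_tensors="pt")
# Expected keys: input_audio_embeds, audio_embed_sizes, audio_attention_mask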
@@ -145,12 +162,12 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
        max_audio_frames = tensor_audio_frames_list.max().item()
 
        returned_audio_attention_mask = None
-        if max_audio_frames > 0: # Create mask only if there are frames
+        if max_audio_frames > 0:
            if len(audios) > 1:
                returned_audio_attention_mask = torch.arange(0, max_audio_frames,
                                                             device=tensor_audio_frames_list.device).unsqueeze(
                    0) < tensor_audio_frames_list.unsqueeze(1)
-            elif len(audios) == 1: # For batch size 1
+            elif len(audios) == 1:
                returned_audio_attention_mask = torch.ones(1, max_audio_frames, dtype=torch.bool,
                                                           device=tensor_audio_frames_list.device)
 
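For reference, the batched branch above is the standard broadcast length-mask. A standalone sketch with hypothetical frame counts:

import torch

frames = torch.tensor([300, 180])  # hypothetical per-item audio frame counts
mask = torch.arange(frames.max().item()).unsqueeze(0) < frames.unsqueeze(1)
# mask.shape == (2, 300); row 1 is True only for its first 180 positions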
@@ -164,50 +181,59 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
        return BatchFeature(data=data, tensor_type=return_tensors)
 
    def _extract_spectrogram(self, wav: np.ndarray, fs: int) -> np.ndarray:
+        # This method expects fs to be the original sampling rate of wav.
+        # It will resample to self.sampling_rate (16000Hz) or 8000Hz as needed.
        if wav.ndim > 1:
            wav = np.squeeze(wav)
-        if len(wav.shape) == 2: # stereo to mono
-            wav = wav.mean(axis=1).astype(np.float32) # Ensure float32 after mean
-
-        wav = wav.astype(np.float32) # Ensure wav is float32
-
-        # Phi4M Resampling logic
-        if fs > self.sampling_rate: # self.sampling_rate is 16000
-            wav = scipy.signal.resample_poly(wav, self.sampling_rate, fs)
-            fs = self.sampling_rate
-        elif 8000 < fs < self.sampling_rate:
-            wav = scipy.signal.resample_poly(wav, 8000, fs) # Resample to 8000 first
-            fs = 8000
-        elif fs < 8000 and fs > 0:
-            logger.warning(f"Sample rate {fs} is less than 8000Hz. Resampling to 8000Hz.")
-            wav = scipy.signal.resample_poly(wav, 8000, fs)
-            fs = 8000
-        elif fs <= 0:
-            raise RuntimeError(f"Unsupported sample rate {fs}")
-
-        if fs == 8000:
+        if len(wav.shape) == 2:
+            wav = wav.mean(axis=1).astype(np.float32)
+
+        wav = wav.astype(np.float32)
+
+        current_fs = fs
+        if current_fs > self.sampling_rate: # self.sampling_rate is 16000
+            wav = scipy.signal.resample_poly(wav, self.sampling_rate, current_fs)
+            current_fs = self.sampling_rate
+        elif 8000 < current_fs < self.sampling_rate:
+            wav = scipy.signal.resample_poly(wav, 8000, current_fs)
+            current_fs = 8000
+        elif current_fs < 8000 and current_fs > 0:
+            logger.warning(f"Sample rate {current_fs} is less than 8000Hz. Resampling to 8000Hz.")
+            wav = scipy.signal.resample_poly(wav, 8000, current_fs)
+            current_fs = 8000
+        elif current_fs <= 0:
+            raise RuntimeError(f"Unsupported sample rate {current_fs}")
+
+        # After this block, current_fs is either 16000Hz or 8000Hz, or an error was raised.
+        # Or it's the original fs if it was already 16000 or 8000.
+
+        if current_fs == 8000:
            if self._eightk_method == "resample":
-                wav = scipy.signal.resample_poly(wav, self.sampling_rate, 8000) # Resample 8k to 16k
-                fs = self.sampling_rate
-            # If "fillzero", parameters for 8k will be used, and spectrum padded later.
-        elif fs != self.sampling_rate: # Should be 16000 if not 8000 and _eightk_method != "resample"
+                wav = scipy.signal.resample_poly(wav, self.sampling_rate, 8000)
+                current_fs = self.sampling_rate
+        elif current_fs != self.sampling_rate:
+            # This case should ideally not be hit if logic above is correct and self.sampling_rate is 16000
            raise RuntimeError(
-                f"Audio sample rate {fs} not supported after initial processing. Expected {self.sampling_rate} or 8000.")
+                f"Audio sample rate {current_fs} not supported. Expected {self.sampling_rate} or 8000 for 8k methods.")
 
        preemphasis_coeff = 0.97
 
-        if fs == 8000:
+        # current_fs is now the rate for STFT parameters (either 16000 or 8000 if fillzero)
+        if current_fs == 8000: # This implies _eightk_method == "fillzero"
            n_fft, win_length, hop_length, fft_window = 256, 200, 80, self._hamming200
-        elif fs == 16000:
+        elif current_fs == 16000: # This is the standard path
            n_fft, win_length, hop_length, fft_window = 512, 400, 160, self._hamming400
        else:
-            raise RuntimeError(f"Inconsistent fs {fs} for parameter selection.")
+            raise RuntimeError(f"Inconsistent fs {current_fs} for parameter selection. Should be 16000 or 8000.")
 
        if len(wav) < win_length:
            wav = np.pad(wav, (0, win_length - len(wav)), 'constant', constant_values=(0.0,))
 
        num_frames = (wav.shape[0] - win_length) // hop_length + 1
        if num_frames <= 0:
+            # For n_fft=512 (16k), output bins = 257. For n_fft=256 (8k), output bins = 129
+            # If fillzero for 8k, it will be padded to 257 later.
+            # So, the number of freq bins depends on n_fft here.
            return np.zeros((0, n_fft // 2 + 1), dtype=np.float32)
 
        y_frames = np.array(
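The resampling chain above funnels every input to 16000 Hz or 8000 Hz with scipy's polyphase resampler. A minimal sketch of the call it relies on (the 44.1 kHz input is an assumption):

import numpy as np
import scipy.signal

wav_44k = np.zeros(44100, dtype=np.float32)  # assumed: 1 s of audio at 44.1 kHz
wav_16k = scipy.signal.resample_poly(wav_44k, 16000, 44100)
assert len(wav_16k) == 16000  # output length scales by up / down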
@@ -216,19 +242,16 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
        )
 
        _y_frames_rolled = np.roll(y_frames, 1, axis=1)
-        _y_frames_rolled[:, 0] = _y_frames_rolled[:, 1] # Phi4M specific handling
+        _y_frames_rolled[:, 0] = _y_frames_rolled[:, 1]
        y_frames_preemphasized = (y_frames - preemphasis_coeff * _y_frames_rolled) * 32768.0
 
        S = np.fft.rfft(fft_window * y_frames_preemphasized, n=n_fft, axis=1).astype(np.complex64)
 
-        if fs == 8000 and self._eightk_method == "fillzero":
-            # Pad spectrum to match 16kHz feature dimension (n_fft=512 -> 257 bins)
-            # Current S has (256 // 2) + 1 = 129 bins
+        if current_fs == 8000 and self._eightk_method == "fillzero":
+            # S.shape[1] is 129 for n_fft=256. Target is 257 for n_fft=512 equivalence.
            target_bins = (512 // 2) + 1
-            pad_width = target_bins - S.shape[1]
-            # Phi4M: S = np.concatenate((S[:, 0:-1], padarray), axis=1) # Nyquist bin gets set to zero
-            # This means take all but last bin from 8k spectrum, then pad.
-            S_core = S[:, :-1]
+            S_core = S[:, :-1] # Drop 8kHz Nyquist bin (1 bin)
+            # Pad to target_bins. Number of columns to add: target_bins - S_core.shape[1]
            padarray = np.zeros((S_core.shape[0], target_bins - S_core.shape[1]), dtype=S.dtype)
            S = np.concatenate((S_core, padarray), axis=1)
 
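The fillzero bookkeeping in this hunk works out as follows, using only the constants visible in the diff:

bins_8k = 256 // 2 + 1      # 129 rfft bins from the 8 kHz STFT (n_fft=256)
target_bins = 512 // 2 + 1  # 257 bins, the 16 kHz layout (n_fft=512)
core = bins_8k - 1          # 128 bins kept after dropping the 8 kHz Nyquist bin
pad = target_bins - core    # 129 zero columns appended, so 128 + 129 == 257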
@@ -238,15 +261,15 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
    def _extract_features(self, wav: np.ndarray, fs: int) -> np.ndarray:
        spec = self._extract_spectrogram(wav, fs)
        if spec.shape[0] == 0:
+            # self.feature_size is n_mels (e.g. 80)
            return np.zeros((0, self.feature_size), dtype=np.float32)
 
        spec_power = spec ** 2
-        fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None) # Clip at 1.0 before log (Phi4M)
+        fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
        log_fbank = np.log(fbank_power).astype(np.float32)
        return log_fbank
 
    def _compute_audio_embed_size(self, audio_frames: int) -> int:
-        # Phi4M's logic for compressed size
        integer = audio_frames // self.compression_rate
        remainder = audio_frames % self.compression_rate
        result = integer if remainder == 0 else integer + 1
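Clipping the mel power at 1.0 before the log floors the features at log(1.0) = 0, which is consistent with dropping the commented-out LOG_MEL_CLIP_EPSILON constant from the header. A small numeric check (the sample values are made up):

import numpy as np

fbank_power = np.array([[0.25, 1.0, 4.0]], dtype=np.float32)
log_fbank = np.log(np.clip(fbank_power, 1.0, None))
# -> [[0.0, 0.0, 1.3863]]: everything below 1.0 maps to 0.0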
@@ -257,6 +280,11 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
        return result
 
 
+# The rest of your script (Gemma3ImagesKwargs, Gemma3ProcessorKwargs, Gemma3OmniProcessor) follows...
+# Make sure this Gemma3AudioFeatureExtractor class replaces the old one or
+# is correctly registered/named if your AutoProcessor setup relies on a specific name.
+
+
 # --- End of Refactored Audio Feature Extractor ---
 
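The visible part of _compute_audio_embed_size above is a ceiling division by compression_rate (the remaining lines are elided by the diff). A worked example with the defaults, assuming a hypothetical 250-frame mel spectrogram:

audio_frames = 250 * 4  # num_mel_frames * feat_stride with DEFAULT_FEAT_STRIDE = 4
compression_rate = 4    # DEFAULT_COMPRESSION_RATE
result = audio_frames // compression_rate + (1 if audio_frames % compression_rate else 0)
assert result == 250    # exact division, so no ceiling bump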