Update processing_gemma3_omni.py
processing_gemma3_omni.py · CHANGED · +113 -85

@@ -16,116 +16,133 @@ from transformers.utils import TensorType, to_py_obj, logging
# Constants
DEFAULT_SAMPLING_RATE = 16000
DEFAULT_N_FFT = 512
DEFAULT_WIN_LENGTH = 400
DEFAULT_HOP_LENGTH = 160
DEFAULT_N_MELS = 80
DEFAULT_COMPRESSION_RATE = 4   # default for __init__
DEFAULT_QFORMER_RATE = 2       # default for __init__ (as audio_downsample_rate)
DEFAULT_FEAT_STRIDE = 4        # default for __init__
IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
DEFAULT_MAX_LENGTH = 16384

logger = logging.get_logger(__name__)

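(The token patterns match the numbered placeholders used in prompts; for example, re.findall(AUDIO_TOKEN_PATTERN, "hi <|audio_1|>") returns ['<|audio_1|>'].)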
def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
                          fmax: Optional[float] = None) -> np.ndarray:
    """Create a Mel filterbank for audio processing."""
    fmax = fmax or sampling_rate / 2.0

    def hz_to_mel(f: float) -> float:  # Slaney-style scale, as in the original Snippet B
        return 1127.0 * math.log(1 + f / 700.0)

    if fmin >= fmax:
        raise ValueError(f"fmin ({fmin}) must be smaller than fmax ({fmax}).")

    mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
    freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1)  # inverse of hz_to_mel

    freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
    bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)
    bins = np.clip(bins, 0, n_fft // 2)

    filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
    for m_idx in range(n_mels):
        left, center, right = bins[m_idx], bins[m_idx + 1], bins[m_idx + 2]

        if center > left:
            filterbank[m_idx, left:center + 1] = (np.arange(left, center + 1) - left) / (center - left)
        if right > center:
            filterbank[m_idx, center:right + 1] = (right - np.arange(center, right + 1)) / (right - center)

        # Guarantee the peak of each triangular filter is exactly 1.0, including the
        # degenerate cases where bin edges coincide.
        if left <= center <= right and filterbank.shape[1] > center:
            if (center > left and filterbank[m_idx, center] < 1.0 and center < right) or \
                    (left == center and center < right) or \
                    (right == center and left < center):
                filterbank[m_idx, center] = 1.0
            elif left == center and right == center:
                filterbank[m_idx, center] = 1.0
    return filterbank
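As a quick sanity check (a minimal standalone sketch, not part of the file): with the Phi4M-style parameters used below (16 kHz, n_fft=512, 80 mels), the filterbank has one row per mel band and n_fft // 2 + 1 = 257 frequency columns, and it is applied with a plain matrix product:

    import numpy as np

    mel_fb = create_mel_filterbank(sampling_rate=16000, n_fft=512, n_mels=80, fmax=7690.0)
    assert mel_fb.shape == (80, 257)                          # (n_mels, n_fft // 2 + 1)

    dummy_power = np.random.rand(10, 257).astype(np.float32)  # 10 frames of |FFT|^2
    mel_energies = dummy_power.dot(mel_fb.T)                  # -> (10, 80)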


# --- Start of Refactored Audio Feature Extractor (to match Phi4M - Snippet A) ---
class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):  # MODIFIED CLASS NAME AND __INIT__
    model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]

    def __init__(self,
                 audio_compression_rate: int = DEFAULT_COMPRESSION_RATE,
                 audio_downsample_rate: int = DEFAULT_QFORMER_RATE,   # maps to the qformer rate
                 audio_feat_stride: int = DEFAULT_FEAT_STRIDE,
                 feature_size: int = DEFAULT_N_MELS,
                 sampling_rate: int = DEFAULT_SAMPLING_RATE,
                 padding_value: float = 0.0,
                 eightk_method: str = "fillzero",
                 **kwargs):
        # super().__init__ expects feature_size, sampling_rate, and padding_value; if any
        # of them arrive via kwargs they override the named defaults above.
        _feature_size = kwargs.pop("feature_size", feature_size)
        _sampling_rate = kwargs.pop("sampling_rate", sampling_rate)
        _padding_value = kwargs.pop("padding_value", padding_value)

        super().__init__(feature_size=_feature_size, sampling_rate=_sampling_rate,
                         padding_value=_padding_value, **kwargs)

        self.compression_rate = audio_compression_rate
        self.qformer_compression_rate = audio_downsample_rate
        self.feat_stride = audio_feat_stride

        self._eightk_method = eightk_method

        # Phi4M hardcodes 16000 Hz for its Mel parameters; self.sampling_rate (set by
        # super()) is the target sampling rate, so warn if the two disagree.
        if self.sampling_rate != 16000:
            logger.warning(
                f"The feature extractor's target sampling rate is {self.sampling_rate}, "
                "but Phi4M-consistent Mel parameters are based on 16000 Hz. "
                "This might lead to inconsistencies if the input audio is not resampled to 16000 Hz by this extractor."
            )

        self._mel = create_mel_filterbank(
            sampling_rate=16000,   # Phi4M Mel params are for 16 kHz
            n_fft=512,
            n_mels=_feature_size,  # effective feature_size (should be 80)
            fmin=0.0,
            fmax=7690.0
        ).T
        self._hamming400 = np.hamming(400)
        self._hamming200 = np.hamming(200)
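For illustration, a minimal construction sketch (assuming the surrounding imports are in place and SequenceFeatureExtractor is the standard transformers base class): the defaults reproduce the Phi4M configuration, and feature_size / sampling_rate / padding_value passed via kwargs take precedence over the named defaults.

    extractor = Gemma3AudioFeatureExtractor()        # 80 mels, 16 kHz target, stride 4
    assert extractor.compression_rate == 4
    assert extractor.qformer_compression_rate == 2
    assert extractor._mel.shape == (257, 80)         # transposed filterbank

    resampling_extractor = Gemma3AudioFeatureExtractor(eightk_method="resample")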

    def __call__(
            self,
            audios: List[Union[AudioInput, Tuple[np.ndarray, int]]],
            return_tensors: Optional[Union[str, TensorType]] = None,
            # sampling_rate: Optional[int] = None,  # was in original B; Phi4M gets the rate from each AudioInput
    ):
        returned_input_audio_embeds = []
        returned_audio_embed_sizes = []
        audio_frames_list = []

        for audio_input_item in audios:
            if not isinstance(audio_input_item, tuple) or len(audio_input_item) != 2:
                raise ValueError(
                    "Each item in 'audios' must be a tuple (waveform: np.ndarray, sample_rate: int)."
                )
            audio_data, sample_rate = audio_input_item  # sample_rate comes from the input audio

            if isinstance(audio_data, list):
                audio_data = np.array(audio_data, dtype=np.float32)
            if not isinstance(audio_data, np.ndarray):
                raise TypeError(f"Waveform data must be a numpy array, got {type(audio_data)}")

            # _extract_features handles resampling to self.sampling_rate (16000 Hz).
            audio_embeds_np = self._extract_features(audio_data, sample_rate)

            num_mel_frames = audio_embeds_np.shape[0]
            current_audio_frames = num_mel_frames * self.feat_stride

            audio_embed_size = self._compute_audio_embed_size(current_audio_frames)
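            # Worked example of the bookkeeping above (illustrative numbers): a 1 s clip
            # at 16 kHz yields (16000 - 400) // 160 + 1 = 98 mel frames, so with the
            # default feat_stride of 4, current_audio_frames = 98 * 4 = 392.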
@@ -145,12 +162,12 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
        max_audio_frames = tensor_audio_frames_list.max().item()

        returned_audio_attention_mask = None
        if max_audio_frames > 0:
            if len(audios) > 1:
                returned_audio_attention_mask = torch.arange(0, max_audio_frames,
                                                             device=tensor_audio_frames_list.device).unsqueeze(
                    0) < tensor_audio_frames_list.unsqueeze(1)
            elif len(audios) == 1:
                returned_audio_attention_mask = torch.ones(1, max_audio_frames, dtype=torch.bool,
                                                           device=tensor_audio_frames_list.device)
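The mask above is an ordinary broadcast comparison; a minimal standalone sketch with made-up frame counts (two clips, 392 and 160 frames):

    import torch

    frames = torch.tensor([392, 160])        # per-clip audio frame counts
    mask = torch.arange(0, frames.max().item()).unsqueeze(0) < frames.unsqueeze(1)
    assert mask.shape == (2, 392)
    assert mask[1, :160].all() and not mask[1, 160:].any()   # clip 2 masked after 160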
@@ -164,50 +181,59 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
        return BatchFeature(data=data, tensor_type=return_tensors)

    def _extract_spectrogram(self, wav: np.ndarray, fs: int) -> np.ndarray:
        # This method expects fs to be the original sampling rate of wav.
        # It resamples to self.sampling_rate (16000 Hz) or 8000 Hz as needed.
        if wav.ndim > 1:
            wav = np.squeeze(wav)
        if len(wav.shape) == 2:
            wav = wav.mean(axis=1).astype(np.float32)

        wav = wav.astype(np.float32)

        current_fs = fs
        if current_fs > self.sampling_rate:  # self.sampling_rate is 16000
            wav = scipy.signal.resample_poly(wav, self.sampling_rate, current_fs)
            current_fs = self.sampling_rate
        elif 8000 < current_fs < self.sampling_rate:
            wav = scipy.signal.resample_poly(wav, 8000, current_fs)
            current_fs = 8000
        elif 0 < current_fs < 8000:
            logger.warning(f"Sample rate {current_fs} is less than 8000 Hz. Resampling to 8000 Hz.")
            wav = scipy.signal.resample_poly(wav, 8000, current_fs)
            current_fs = 8000
        elif current_fs <= 0:
            raise RuntimeError(f"Unsupported sample rate {current_fs}")

        # After this block current_fs is 16000 or 8000 (or it was already one of those).
        if current_fs == 8000:
            if self._eightk_method == "resample":
                wav = scipy.signal.resample_poly(wav, self.sampling_rate, 8000)
                current_fs = self.sampling_rate
        elif current_fs != self.sampling_rate:
            # Should not be reachable if the logic above is correct and self.sampling_rate is 16000.
            raise RuntimeError(
                f"Audio sample rate {current_fs} not supported. Expected {self.sampling_rate} or 8000 for 8k methods.")

        preemphasis_coeff = 0.97

        # current_fs now selects the STFT parameters (16000, or 8000 with fillzero).
        if current_fs == 8000:  # implies _eightk_method == "fillzero"
            n_fft, win_length, hop_length, fft_window = 256, 200, 80, self._hamming200
        elif current_fs == 16000:  # the standard path
            n_fft, win_length, hop_length, fft_window = 512, 400, 160, self._hamming400
        else:
            raise RuntimeError(f"Inconsistent fs {current_fs} for parameter selection. Should be 16000 or 8000.")

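        # For reference: at 16 kHz these parameters give a 25 ms window (400 samples)
        # with a 10 ms hop (160 samples); the 8 kHz fillzero path halves all three
        # (n_fft=256, 200/80 samples), so the frame timing is identical.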
        if len(wav) < win_length:
            wav = np.pad(wav, (0, win_length - len(wav)), 'constant', constant_values=(0.0,))

        num_frames = (wav.shape[0] - win_length) // hop_length + 1
        if num_frames <= 0:
            # Output bins depend on n_fft here: 257 for n_fft=512 (16 kHz), 129 for
            # n_fft=256 (8 kHz); the fillzero path pads 8 kHz frames to 257 bins later.
            return np.zeros((0, n_fft // 2 + 1), dtype=np.float32)

        y_frames = np.array(
@@ -216,19 +242,16 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
        )

        _y_frames_rolled = np.roll(y_frames, 1, axis=1)
        _y_frames_rolled[:, 0] = _y_frames_rolled[:, 1]
        y_frames_preemphasized = (y_frames - preemphasis_coeff * _y_frames_rolled) * 32768.0

        S = np.fft.rfft(fft_window * y_frames_preemphasized, n=n_fft, axis=1).astype(np.complex64)

        if current_fs == 8000 and self._eightk_method == "fillzero":
            # S has (256 // 2) + 1 = 129 bins; the target is 257, for n_fft=512 equivalence.
            target_bins = (512 // 2) + 1
            S_core = S[:, :-1]  # drop the 8 kHz Nyquist bin
            padarray = np.zeros((S_core.shape[0], target_bins - S_core.shape[1]), dtype=S.dtype)
            S = np.concatenate((S_core, padarray), axis=1)

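        # Bin arithmetic for the fillzero branch: n_fft=256 gives 129 bins; dropping
        # the 8 kHz Nyquist bin leaves 128, and 257 - 128 = 129 zero columns are
        # appended so S matches the 257-bin layout of the 16 kHz path.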
@@ -238,15 +261,15 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
    def _extract_features(self, wav: np.ndarray, fs: int) -> np.ndarray:
        spec = self._extract_spectrogram(wav, fs)
        if spec.shape[0] == 0:
            return np.zeros((0, self.feature_size), dtype=np.float32)  # feature_size is n_mels (e.g. 80)

        spec_power = spec ** 2
        fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
        log_fbank = np.log(fbank_power).astype(np.float32)
        return log_fbank

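The clip at 1.0 before the log both avoids log(0) and floors the log-mel features at exactly 0.0 (the original Snippet B used a LOG_MEL_CLIP_EPSILON = 1e-5 constant instead, per a comment removed in this commit). A minimal numeric check:

    import numpy as np

    fbank_power = np.clip(np.array([[0.0, 0.5, 4.0]]), 1.0, None)   # -> [[1.0, 1.0, 4.0]]
    log_fbank = np.log(fbank_power)
    assert (log_fbank >= 0.0).all()                  # silence maps to 0.0, not -inf
    assert np.isclose(log_fbank[0, 2], np.log(4.0))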
    def _compute_audio_embed_size(self, audio_frames: int) -> int:
        integer = audio_frames // self.compression_rate
        remainder = audio_frames % self.compression_rate
        result = integer if remainder == 0 else integer + 1

@@ -257,6 +280,11 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
        return result
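Worked example of the embed-size computation (a sketch under one assumption: the diff collapses new lines 276-279, which, per the Phi4M logic this class mirrors and the qformer_compression_rate field set in __init__, apply the same ceil-division a second time with qformer_compression_rate):

    def ceil_div(n: int, d: int) -> int:
        return n // d + (1 if n % d else 0)

    # 1 s of 16 kHz audio: 98 mel frames * feat_stride 4 = 392 audio frames
    after_compression = ceil_div(392, 4)               # 98 (compression_rate = 4)
    audio_embed_size = ceil_div(after_compression, 2)  # 49 (qformer rate = 2, assumed)
    assert audio_embed_size == 49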


# The rest of the script (Gemma3ImagesKwargs, Gemma3ProcessorKwargs, Gemma3OmniProcessor) follows.
# Make sure this Gemma3AudioFeatureExtractor class replaces the old one, or is correctly
# registered/named if your AutoProcessor setup relies on a specific class name.


# --- End of Refactored Audio Feature Extractor ---