voidful committed on
Commit d0ae9f9 · verified · 1 Parent(s): 6cb34fb

Update processing_gemma3_omni.py

Files changed (1)
  1. processing_gemma3_omni.py +115 -111
processing_gemma3_omni.py CHANGED
@@ -1,5 +1,5 @@
  import re
- from typing import List, Optional, Union, Dict, Any
+ from typing import List, Optional, Union, Dict, Any, Tuple
 
  import math
  import numpy as np
@@ -15,17 +15,16 @@ AudioInput = Union[np.ndarray, List[float], Tuple[np.ndarray, int]]
 
  from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
  from transformers.feature_extraction_utils import BatchFeature
- from transformers.processing_utils import ProcessorMixin, ProcessingKwargs
+ from transformers.processing_utils import ProcessorMixin, ProcessingKwargs
  from transformers.utils import TensorType, to_py_obj, logging
  # For AutoImageProcessor, AutoTokenizer if needed for default loading
  from transformers import AutoImageProcessor, AutoTokenizer
 
-
  # Constants (as defined before)
  DEFAULT_SAMPLING_RATE = 16000
  DEFAULT_N_FFT = 512
- DEFAULT_WIN_LENGTH = 400 # Will be n_fft if None in __init__
- DEFAULT_HOP_LENGTH = 160 # Will be win_length // 4 if None in __init__
+ DEFAULT_WIN_LENGTH = 400 # Will be n_fft if None in __init__
+ DEFAULT_HOP_LENGTH = 160 # Will be win_length // 4 if None in __init__
  DEFAULT_N_MELS = 80
  DEFAULT_COMPRESSION_RATE = 4
  DEFAULT_QFORMER_RATE = 2
@@ -48,15 +47,15 @@ def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: flo
  if fmin >= fmax:
  raise ValueError(f"fmin ({fmin}) must be smaller than fmax ({fmax}).")
 
- def hz_to_mel(f: float) -> float: # Using HTK formula (as in librosa default)
+ def hz_to_mel(f: float) -> float: # Using HTK formula (as in librosa default)
  return 2595.0 * math.log10(1 + f / 700.0)
 
  def mel_to_hz(mel: float) -> float:
- return 700.0 * (10**(mel / 2595.0) - 1)
+ return 700.0 * (10 ** (mel / 2595.0) - 1)
 
  mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
  freq_points = mel_to_hz(mel_points)
-
+
  freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
  bins = np.floor((n_fft / 2.0) * freq_points / (sampling_rate / 2.0)).astype(int)
  bins = np.clip(bins, 0, n_fft // 2)
@@ -64,43 +63,44 @@ def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: flo
  filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
  for m in range(n_mels):
  left, center, right = bins[m], bins[m + 1], bins[m + 2]
-
+
  # Simplified triangle creation logic (more robust versions exist in libraries like librosa)
  if center > left:
- filterbank[m, left:center+1] = (np.arange(left, center + 1) - left) / (center - left)
+ filterbank[m, left:center + 1] = (np.arange(left, center + 1) - left) / (center - left)
  if right > center:
- filterbank[m, center:right+1] = (right - np.arange(center, right + 1)) / (right - center)
+ filterbank[m, center:right + 1] = (right - np.arange(center, right + 1)) / (right - center)
  # Ensure peak is 1 if multiple points coincide at center (can happen with narrow filters/low resolution)
- if left <= center <= right and filterbank[m,center] < 1.0 and (center > left or center < right) : #check if it's a valid point for a peak
+ if left <= center <= right and filterbank[m, center] < 1.0 and (
+ center > left or center < right): # check if it's a valid point for a peak
  # if filterbank[m,center] is not properly set to 1 by slopes (e.g. left==center or right==center)
- filterbank[m,center] = 1.0
+ filterbank[m, center] = 1.0
- if left == center and right > center : # only falling slope
+ if left == center and right > center: # only falling slope
- # Ensure it doesn't double-dip if already set
+ # Ensure it doesn't double-dip if already set
- pass
+ pass
- elif right == center and left < center: # only rising slope
+ elif right == center and left < center: # only rising slope
- pass
+ pass
-
 
  return filterbank
 
+
  # Gemma3AudioFeatureExtractor class (assuming it's correctly defined from previous response)
  # ... (Gemma3AudioFeatureExtractor class from the previous corrected response) ...
  class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
- model_input_names = ["audio_values", "audio_attention_mask"]
+ model_input_names = ["audio_values", "audio_attention_mask"]
 
  def __init__(
  self,
  compression_rate: int = DEFAULT_COMPRESSION_RATE,
  qformer_rate: int = DEFAULT_QFORMER_RATE,
  feat_stride: int = DEFAULT_FEAT_STRIDE,
- sampling_rate: int = DEFAULT_SAMPLING_RATE,
+ sampling_rate: int = DEFAULT_SAMPLING_RATE,
  n_fft: int = DEFAULT_N_FFT,
- win_length: Optional[int] = None,
- hop_length: Optional[int] = None,
+ win_length: Optional[int] = None,
+ hop_length: Optional[int] = None,
  n_mels: int = DEFAULT_N_MELS,
  f_min: float = 0.0,
  f_max: Optional[float] = None,
- padding_value: float = 0.0,
+ padding_value: float = 0.0,
  **kwargs
  ):
  super().__init__(feature_size=n_mels, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
@@ -128,16 +128,16 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
  def __call__(
  self,
  audios: Union[AudioInput, List[AudioInput]],
- sampling_rate: Optional[int] = None,
+ sampling_rate: Optional[int] = None,
  return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
  ) -> BatchFeature:
-
+
  if not isinstance(audios, list):
  audios = [audios]
 
  processed_mel_spectrograms: List[torch.Tensor] = []
  actual_mel_lengths: List[int] = []
- downstream_sizes_for_token_calc: List[torch.Tensor] = []
+ downstream_sizes_for_token_calc: List[torch.Tensor] = []
  downstream_frames_scaled_for_token_calc: List[int] = []
 
  for audio_input_item in audios:
@@ -161,11 +161,11 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
  )
 
  processed_wav = self._preprocess_audio(current_wav_array, source_sr)
- mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav)
+ mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav)
-
+
  feature_tensor = torch.from_numpy(mel_spectrogram)
  processed_mel_spectrograms.append(feature_tensor)
- actual_mel_lengths.append(feature_tensor.shape[0])
+ actual_mel_lengths.append(feature_tensor.shape[0])
 
  downstream_sizes_for_token_calc.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
  downstream_frames_scaled_for_token_calc.append(feature_tensor.shape[0] * self.feat_stride)
@@ -173,16 +173,17 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
  audio_values = pad_sequence(processed_mel_spectrograms, batch_first=True, padding_value=self.padding_value)
  max_mel_len = audio_values.shape[1]
  lengths_tensor = torch.tensor(actual_mel_lengths, dtype=torch.long)
- audio_attention_mask = torch.arange(max_mel_len).unsqueeze(0).expand(len(audios), -1) < lengths_tensor.unsqueeze(1)
+ audio_attention_mask = torch.arange(max_mel_len).unsqueeze(0).expand(len(audios),
+ -1) < lengths_tensor.unsqueeze(1)
-
+
  output_data = {
  "audio_values": audio_values,
  "audio_attention_mask": audio_attention_mask
  }
-
+
  if downstream_sizes_for_token_calc:
- output_data["audio_token_calc_sizes"] = torch.stack(downstream_sizes_for_token_calc)
+ output_data["audio_token_calc_sizes"] = torch.stack(downstream_sizes_for_token_calc)
-
+
  return BatchFeature(data=output_data, tensor_type=return_tensors)
 
  def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
@@ -190,22 +191,22 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
  if np.issubdtype(wav.dtype, np.integer):
  max_val = np.iinfo(wav.dtype).max
  wav = wav.astype(np.float32) / max_val
- else:
+ else:
  wav = wav.astype(np.float32)
-
+
  if wav.ndim > 1:
  wav = wav.mean(axis=0)
-
+
  if source_sr != self.sampling_rate:
  gcd = math.gcd(self.sampling_rate, source_sr)
  up_factor = self.sampling_rate // gcd
  down_factor = source_sr // gcd
- if up_factor != down_factor:
+ if up_factor != down_factor:
- logger.info(f"Resampling audio from {source_sr} Hz to {self.sampling_rate} Hz.")
+ logger.info(f"Resampling audio from {source_sr} Hz to {self.sampling_rate} Hz.")
- wav = scipy.signal.resample_poly(wav, up=up_factor, down=down_factor)
+ wav = scipy.signal.resample_poly(wav, up=up_factor, down=down_factor)
-
+
  norm_factor = np.abs(wav).max()
- if norm_factor > 1e-9:
+ if norm_factor > 1e-9:
  wav = wav / norm_factor
  return wav
 
@@ -216,7 +217,8 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
 
  num_frames = 1 + (len(wav) - self.win_length) // self.hop_length
  if num_frames <= 0:
- logger.warning(f"Audio of length {len(wav)} is too short to produce frames with win_length {self.win_length} and hop_length {self.hop_length}. Returning empty mel spectrogram.")
+ logger.warning(
+ f"Audio of length {len(wav)} is too short to produce frames with win_length {self.win_length} and hop_length {self.hop_length}. Returning empty mel spectrogram.")
  return np.zeros((0, self.n_mels), dtype=np.float32)
 
  frames = np.lib.stride_tricks.as_strided(
@@ -225,21 +227,22 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
  strides=(wav.strides[0] * self.hop_length, wav.strides[0]),
  writeable=False
  )
-
+
  windowed_frames = frames * self.window
  stft_matrix = np.fft.rfft(windowed_frames, n=self.n_fft, axis=-1)
- powers = np.abs(stft_matrix)**2
+ powers = np.abs(stft_matrix) ** 2
  mel_spectrogram = np.dot(powers, self.mel_filterbank)
  mel_spectrogram = np.clip(mel_spectrogram, LOG_MEL_CLIP_EPSILON, None)
  log_mel_spectrogram = np.log(mel_spectrogram)
-
+
  return log_mel_spectrogram.astype(np.float32)
 
  def _calculate_embed_length(self, frame_count: int) -> int:
  compressed = math.ceil(frame_count / self.compression_rate)
  return math.ceil(compressed / self.qformer_rate)
 
+
- class Gemma3DummyProcessorKwargs(ProcessingKwargs, total=False): # Dummy for testing structure
+ class Gemma3DummyProcessorKwargs(ProcessingKwargs, total=False): # Dummy for testing structure
  images_kwargs: Dict[str, Any]
  audio_kwargs: Dict[str, Any]
  text_kwargs: Dict[str, Any]
@@ -249,24 +252,25 @@ class Gemma3DummyProcessorKwargs(ProcessingKwargs, total=False): # Dummy for tes
  "audio_kwargs": {}
  }
 
+
  class Gemma3OmniProcessor(ProcessorMixin):
  attributes = ["image_processor", "audio_processor", "tokenizer"]
  # Define class attributes for ProcessorMixin to find/use them
  image_processor_class = "AutoImageProcessor" # Or the specific class string if not auto
- audio_processor_class = Gemma3AudioFeatureExtractor # Correctly points to your custom class
+ audio_processor_class = Gemma3AudioFeatureExtractor # Correctly points to your custom class
- tokenizer_class = "AutoTokenizer" # Or the specific class string
+ tokenizer_class = "AutoTokenizer" # Or the specific class string
 
  # valid_kwargs was in user's code, its role depends on ProcessorMixin internal usage
- valid_kwargs = ["chat_template", "image_seq_length"]
+ valid_kwargs = ["chat_template", "image_seq_length"]
 
  def __init__(
  self,
- tokenizer,
+ tokenizer,
  audio_processor: Optional[Union[Gemma3AudioFeatureExtractor, Dict]] = None,
- image_processor = None,
+ image_processor=None,
  chat_template=None,
  image_seq_length: int = 256,
- audio_prompt_compression_rate: int = 8,
+ audio_prompt_compression_rate: int = 8,
  audio_prompt_qformer_rate: int = 1,
  audio_prompt_feat_stride: int = 1,
  audio_placeholder_token: str = "<|audio_placeholder|>",
@@ -279,48 +283,50 @@ class Gemma3OmniProcessor(ProcessorMixin):
  audio_processor = Gemma3AudioFeatureExtractor()
  elif isinstance(audio_processor, Dict):
  audio_processor = Gemma3AudioFeatureExtractor(**audio_processor)
- elif not isinstance(audio_processor, Gemma3AudioFeatureExtractor): # Check type if instance is passed
+ elif not isinstance(audio_processor, Gemma3AudioFeatureExtractor): # Check type if instance is passed
- raise TypeError(f"audio_processor must be an instance of Gemma3AudioFeatureExtractor or a config dict, got {type(audio_processor)}")
+ raise TypeError(
+ f"audio_processor must be an instance of Gemma3AudioFeatureExtractor or a config dict, got {type(audio_processor)}")
 
  # Handle image_processor similarly if it can be None or a dict
  if image_processor is None and self.image_processor_class:
- # This is a basic way; from_pretrained usually handles complex loading
+ # This is a basic way; from_pretrained usually handles complex loading
  if isinstance(self.image_processor_class, str) and self.image_processor_class == "AutoImageProcessor":
- logger.info(f"Attempting to load a default {self.image_processor_class}. This might require a default model name or fail.")
+ logger.info(
+ f"Attempting to load a default {self.image_processor_class}. This might require a default model name or fail.")
  # image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32") # Example default
  # else if self.image_processor_class is an actual class, instantiate it.
  elif isinstance(image_processor, Dict):
  # image_processor = AutoImageProcessor.from_config(config_class(**image_processor)) # Example
- pass # Actual instantiation from dict would be more complex
+ pass # Actual instantiation from dict would be more complex
 
  # Ensure tokenizer is an instantiated object
- if isinstance(tokenizer, str): # If tokenizer is a string (model name/path)
+ if isinstance(tokenizer, str): # If tokenizer is a string (model name/path)
  logger.info(f"Loading tokenizer from {tokenizer}")
  # tokenizer = AutoTokenizer.from_pretrained(tokenizer) # This is how it's usually done
  elif tokenizer is None:
- raise ValueError("A tokenizer instance or identifier must be provided.")
+ raise ValueError("A tokenizer instance or identifier must be provided.")
-
 
  super().__init__(
  image_processor=image_processor,
  audio_processor=audio_processor,
  tokenizer=tokenizer,
  chat_template=chat_template,
- **kwargs # Pass other kwargs to super
+ **kwargs # Pass other kwargs to super
  )
-
+
  self.image_seq_length = image_seq_length
- self.image_token_id = getattr(self.tokenizer, "image_token_id", self.tokenizer.unk_token_id if hasattr(self.tokenizer, "unk_token_id") else None)
+ self.image_token_id = getattr(self.tokenizer, "image_token_id",
+ self.tokenizer.unk_token_id if hasattr(self.tokenizer, "unk_token_id") else None)
- self.boi_token = getattr(self.tokenizer, "boi_token", "<|image|>")
+ self.boi_token = getattr(self.tokenizer, "boi_token", "<|image|>")
  self.image_token = getattr(self.tokenizer, "image_token", "<|image|>")
- self.eoi_token = getattr(self.tokenizer, "eoi_token", "")
+ self.eoi_token = getattr(self.tokenizer, "eoi_token", "")
 
  self.audio_placeholder_token = audio_placeholder_token
  self.audio_soft_token_str = audio_soft_token_str
-
+
  self.audio_soft_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_soft_token_str)
- if self.audio_soft_token_id == self.tokenizer.unk_token_id: # Check if UNK
+ if self.audio_soft_token_id == self.tokenizer.unk_token_id: # Check if UNK
- logger.warning(
+ logger.warning(
  f"The audio soft token string '{self.audio_soft_token_str}' maps to UNK token (ID: {self.audio_soft_token_id}). "
  "Ensure it is added to the tokenizer's vocabulary as a special token."
  )
@@ -331,7 +337,6 @@ class Gemma3OmniProcessor(ProcessorMixin):
  self.audio_prompt_qformer_rate = audio_prompt_qformer_rate
  self.audio_prompt_feat_stride = audio_prompt_feat_stride
 
-
  def _merge_kwargs(self, KwargsClassWithDefaults, tokenizer_init_kwargs, **kwargs_passed_to_call):
  final_kwargs = {}
  # Initialize with _defaults from the Kwargs class
@@ -342,24 +347,24 @@ class Gemma3OmniProcessor(ProcessorMixin):
 
  # Override with tokenizer's init_kwargs if they exist for a given key
  for modality_key, modality_dict in final_kwargs.items():
- for key in list(modality_dict.keys()):
+ for key in list(modality_dict.keys()):
  if key in tokenizer_init_kwargs:
  modality_dict[key] = tokenizer_init_kwargs[key]
-
+
  # Override with kwargs passed directly to __call__
  for modality_key_from_call, modality_dict_from_call in kwargs_passed_to_call.items():
  if modality_key_from_call in final_kwargs and isinstance(modality_dict_from_call, dict):
  final_kwargs[modality_key_from_call].update(modality_dict_from_call)
  # If a new modality_kwargs (e.g., "video_kwargs") is passed, add it
  elif modality_key_from_call not in final_kwargs and isinstance(modality_dict_from_call, dict):
- final_kwargs[modality_key_from_call] = modality_dict_from_call.copy()
+ final_kwargs[modality_key_from_call] = modality_dict_from_call.copy()
 
  # Specific handling for text_kwargs
  if "text_kwargs" not in final_kwargs:
- final_kwargs["text_kwargs"] = {} # Ensure it exists
+ final_kwargs["text_kwargs"] = {} # Ensure it exists
  final_kwargs["text_kwargs"]["truncation"] = final_kwargs["text_kwargs"].get("truncation", False)
  final_kwargs["text_kwargs"]["max_length"] = final_kwargs["text_kwargs"].get("max_length", DEFAULT_MAX_LENGTH)
-
+
  return final_kwargs
 
  def _compute_audio_prompt_token_count(self, actual_mel_frames_count: int) -> int:
@@ -371,13 +376,13 @@ class Gemma3OmniProcessor(ProcessorMixin):
  def __call__(
  self,
  text: Union[str, List[str]] = None,
- images: Optional[Any] = None,
+ images: Optional[Any] = None,
  audios: Optional[Union[AudioInput, List[AudioInput]]] = None,
- sampling_rate: Optional[int] = None,
+ sampling_rate: Optional[int] = None,
- return_tensors: Optional[Union[str, TensorType]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
- **kwargs: Any
+ **kwargs: Any
  ) -> BatchFeature:
-
+
  if text is None and images is None and audios is None:
  raise ValueError("Provide at least one of `text`, `images`, or `audios`.")
 
@@ -385,27 +390,27 @@ class Gemma3OmniProcessor(ProcessorMixin):
  # Priority: 1. Explicit return_tensors, 2. from text_kwargs in **kwargs, 3. Default (PT)
  final_rt = return_tensors
  merged_call_kwargs = self._merge_kwargs(
- Gemma3DummyProcessorKwargs, # Using dummy for _defaults structure
+ Gemma3DummyProcessorKwargs, # Using dummy for _defaults structure
  self.tokenizer.init_kwargs if hasattr(self.tokenizer, 'init_kwargs') else {},
- **kwargs
+ **kwargs
  )
-
+
- if final_rt is None: # If not passed directly to __call__
+ if final_rt is None: # If not passed directly to __call__
  final_rt = merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", TensorType.PYTORCH)
- else: # If passed directly, remove from text_kwargs to avoid conflict
+ else: # If passed directly, remove from text_kwargs to avoid conflict
  merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", None)
 
-
  if text is None:
  num_samples = 0
  if images is not None:
- _images_list = images if isinstance(images, list) and (not images or not isinstance(images[0], (int, float))) else [images]
+ _images_list = images if isinstance(images, list) and (
+ not images or not isinstance(images[0], (int, float))) else [images]
  num_samples = len(_images_list)
  elif audios is not None:
  _audios_list = audios if isinstance(audios, list) else [audios]
  num_samples = len(_audios_list)
  text = [""] * num_samples if num_samples > 0 else [""]
-
+
  if isinstance(text, str):
  text = [text]
  if not (isinstance(text, list) and all(isinstance(t, str) for t in text)):
@@ -419,17 +424,16 @@ class Gemma3OmniProcessor(ProcessorMixin):
  # text = self._handle_image_text_replacement(text, images, image_features_dict)
  pass
 
-
  audio_features_dict = {}
  if audios is not None and self.audio_processor is not None:
  logger.info("Processing audio...")
  audio_call_kwargs = merged_call_kwargs.get("audio_kwargs", {})
- if sampling_rate:
+ if sampling_rate:
- audio_call_kwargs["sampling_rate"] = sampling_rate
+ audio_call_kwargs["sampling_rate"] = sampling_rate
-
+
  # audio_processor.__call__ returns BatchFeature, we need its .data attribute
  audio_features_batch_feature = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
- audio_features_dict = audio_features_batch_feature.data # Get the dict
+ audio_features_dict = audio_features_batch_feature.data # Get the dict
 
  new_text_with_audio = []
  # audio_attention_mask shape is (B, Max_T_mel)
@@ -438,42 +442,42 @@ class Gemma3OmniProcessor(ProcessorMixin):
  for i, prompt in enumerate(text):
  num_soft_tokens = self._compute_audio_prompt_token_count(audio_sample_mel_lengths[i])
  audio_token_sequence_str = self.audio_soft_token_str * num_soft_tokens
-
+
  if self.audio_placeholder_token in prompt:
  prompt = prompt.replace(self.audio_placeholder_token, audio_token_sequence_str, 1)
- else:
+ else:
- prompt += audio_token_sequence_str
+ prompt += audio_token_sequence_str
  new_text_with_audio.append(prompt)
  text = new_text_with_audio
-
+
  logger.info("Tokenizing text...")
  text_call_kwargs = merged_call_kwargs.get("text_kwargs", {})
  text_features_dict = self.tokenizer(text, return_tensors=None, **text_call_kwargs)
 
  input_ids_list = text_features_dict["input_ids"]
  if not isinstance(input_ids_list, list) or not (input_ids_list and isinstance(input_ids_list[0], list)):
- if isinstance(input_ids_list, (torch.Tensor, np.ndarray)):
+ if isinstance(input_ids_list, (torch.Tensor, np.ndarray)):
- input_ids_list = to_py_obj(input_ids_list) # Convert tensor/np.array to list of lists
+ input_ids_list = to_py_obj(input_ids_list) # Convert tensor/np.array to list of lists
- elif isinstance(input_ids_list, list) and (not input_ids_list or isinstance(input_ids_list[0], int)):
+ elif isinstance(input_ids_list, list) and (not input_ids_list or isinstance(input_ids_list[0], int)):
- input_ids_list = [input_ids_list]
+ input_ids_list = [input_ids_list]
 
  token_type_ids_list = []
  for ids_sample in input_ids_list:
- types = [0] * len(ids_sample)
+ types = [0] * len(ids_sample)
  for j, token_id in enumerate(ids_sample):
  if self.image_token_id is not None and token_id == self.image_token_id:
- types[j] = 1
+ types[j] = 1
- elif token_id == self.audio_soft_token_id:
+ elif token_id == self.audio_soft_token_id:
- types[j] = 2
+ types[j] = 2
  token_type_ids_list.append(types)
  text_features_dict["token_type_ids"] = token_type_ids_list
-
+
  combined_features = {**text_features_dict}
- if image_features_dict:
+ if image_features_dict:
  combined_features.update(image_features_dict)
- if audio_features_dict:
+ if audio_features_dict:
  combined_features.update(audio_features_dict)
-
+
  return BatchFeature(data=combined_features, tensor_type=final_rt)
 
  def batch_decode(self, *args, **kwargs):
@@ -489,6 +493,6 @@ class Gemma3OmniProcessor(ProcessorMixin):
  input_names.update(self.image_processor.model_input_names)
  if self.audio_processor is not None:
  # From Gemma3AudioFeatureExtractor's output_data keys
- input_names.update(["audio_values", "audio_attention_mask"])
+ input_names.update(["audio_values", "audio_attention_mask"])
  # "audio_token_calc_sizes" is internal to processor, not model.
  return list(input_names)
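
For context, a minimal usage sketch of the two classes touched by this commit (not part of the diff). It assumes the module is importable as processing_gemma3_omni, that a Gemma-3 tokenizer checkpoint such as google/gemma-3-4b-it is available, and that the audio soft token has already been added to that tokenizer as a special token; the checkpoint name and everything outside the diff are assumptions.

import numpy as np
from transformers import AutoTokenizer
from processing_gemma3_omni import Gemma3AudioFeatureExtractor, Gemma3OmniProcessor

# Feature extractor with the defaults defined in this file (16 kHz, 80 log-mel bins).
audio_processor = Gemma3AudioFeatureExtractor()

# Assumed checkpoint; any Gemma-3 tokenizer with the audio soft token registered
# as a special token should behave the same way here.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")

processor = Gemma3OmniProcessor(tokenizer=tokenizer, audio_processor=audio_processor)

# Two seconds of synthetic 16 kHz audio; a (waveform, sampling_rate) tuple is one
# of the accepted AudioInput forms, and the placeholder token marks where the
# audio soft tokens are expanded in the prompt.
wav = np.random.randn(32000).astype(np.float32)
batch = processor(
    text="Transcribe: <|audio_placeholder|>",
    audios=[(wav, 16000)],
)
print(batch["input_ids"].shape)             # (1, sequence_length)
print(batch["audio_values"].shape)          # (1, n_mel_frames, 80)
print(batch["audio_attention_mask"].shape)  # (1, n_mel_frames)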