voidful committed on
Commit 315e5b5 · verified · 1 Parent(s): 52ca1d3

Update processing_gemma3_omni.py

Files changed (1)
  1. processing_gemma3_omni.py +57 -113
processing_gemma3_omni.py CHANGED
@@ -6,10 +6,10 @@ import numpy as np
 import scipy.signal
 import torch
 from torch.nn.utils.rnn import pad_sequence
-from transformers.audio_utils import AudioInput  # type: ignore
+from transformers.audio_utils import AudioInput  # type: ignore
 from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
 from transformers.feature_extraction_utils import BatchFeature
-from transformers.image_utils import make_nested_list_of_images  # If image processing is used
+from transformers.image_utils import make_nested_list_of_images  # If image processing is used
 from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs
 from transformers.utils import TensorType, to_py_obj, logging
 
@@ -19,12 +19,12 @@ DEFAULT_N_FFT = 512
 DEFAULT_WIN_LENGTH = 400
 DEFAULT_HOP_LENGTH = 160
 DEFAULT_N_MELS = 80
-DEFAULT_COMPRESSION_RATE = 4  # For _calculate_embed_length
-DEFAULT_QFORMER_RATE = 2  # For _calculate_embed_length
-DEFAULT_FEAT_STRIDE = 4  # For _calculate_embed_length / 'frames'
-IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"  # Not used in this file directly
-AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"  # Not used in this file directly
-DEFAULT_MAX_LENGTH = 16384  # For tokenizer default
+DEFAULT_COMPRESSION_RATE = 4
+DEFAULT_QFORMER_RATE = 2
+DEFAULT_FEAT_STRIDE = 4
+IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
+AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
+DEFAULT_MAX_LENGTH = 16384
 LOG_MEL_CLIP_EPSILON = 1e-5
 
 logger = logging.get_logger(__name__)
@@ -35,39 +35,35 @@ def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: flo
     """Create Mel filterbank for audio processing."""
     fmax = fmax or sampling_rate / 2.0
 
-    def hz_to_mel(f: float) -> float:  # User's formula
+    def hz_to_mel(f: float) -> float:
         return 1127.0 * math.log(1 + f / 700.0)
 
     if fmin >= fmax:
         raise ValueError(f"fmin ({fmin}) must be smaller than fmax ({fmax}).")
 
     mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
-    freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1)  # Inverse of user's hz_to_mel
+    freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1)  # Inverse of user's hz_to_mel
 
     freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
-    bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(
-        int)  # (n_fft+1) or n_fft/2 ? Librosa uses n_fft//2 * hz / sr_nyquist
-    bins = np.clip(bins, 0, n_fft // 2)  # Max index for rfft output is n_fft//2
+    bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)
+    bins = np.clip(bins, 0, n_fft // 2)  # Max index for rfft output is n_fft//2
 
     filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
     for m_idx in range(n_mels):
        left, center, right = bins[m_idx], bins[m_idx + 1], bins[m_idx + 2]
 
-        if center > left:  # Rising slope
+        if center > left:  # Rising slope
            filterbank[m_idx, left:center + 1] = (np.arange(left, center + 1) - left) / (center - left)
-        if right > center:  # Falling slope
-            # Need to ensure the peak is 1 if center was part of rising slope
-            # If left==center, this part creates the full triangle (rising is skipped)
+        if right > center:  # Falling slope
            filterbank[m_idx, center:right + 1] = (right - np.arange(center, right + 1)) / (right - center)
 
        # Ensure the peak at 'center' is 1.0 if it's a valid point.
-        # This handles cases where left=center or center=right if the slopes don't perfectly set it.
        if left <= center <= right:
-            if filterbank.shape[1] > center:  # Check bounds for center index
+            if filterbank.shape[1] > center:
                if (center > left and filterbank[m_idx, center] < 1.0) or \
-                        (center < right and filterbank[m_idx, center] < 1.0) or \
-                        (left == center and center < right) or \
-                        (right == center and left < center):
+                        (center < right and filterbank[m_idx, center] < 1.0) or \
+                        (left == center and center < right) or \
+                        (right == center and left < center):
                    filterbank[m_idx, center] = 1.0
    return filterbank
 
@@ -92,14 +88,14 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
     ):
         _win_length = win_length if win_length is not None else n_fft
         _hop_length = hop_length if hop_length is not None else _win_length // 4
-
+
         kwargs.pop("feature_size", None)
         kwargs.pop("sampling_rate", None)
         kwargs.pop("padding_value", None)
-
+
         super().__init__(
-            feature_size=n_mels,  # This is num_mel_bins
-            sampling_rate=sampling_rate,  # This is the target sampling rate for featurization
+            feature_size=n_mels,
+            sampling_rate=sampling_rate,
             padding_value=padding_value,
             **kwargs
         )
@@ -129,7 +125,7 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
     def __call__(
             self,
            audios: Union[AudioInput, List[AudioInput]],
-            sampling_rate: Optional[int] = None,  # SR of input raw audio arrays
+            sampling_rate: Optional[int] = None,
            return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
     ) -> BatchFeature:
 
@@ -138,19 +134,17 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
 
         processed_mels: List[torch.Tensor] = []
         actual_mel_lengths: List[int] = []
-
-        # These lists are from your original code; their values might be used by Gemma3OmniProcessor later.
         sizes_for_downstream_calc: List[torch.Tensor] = []
         frames_scaled_for_downstream_calc: List[int] = []
 
         for audio_item in audios:
            current_wav_array: np.ndarray
-            source_sr: int  # Original sampling rate of the current_wav_array
+            source_sr: int
 
            if isinstance(audio_item, tuple) and len(audio_item) == 2 and isinstance(audio_item[1], int):
                current_wav_array, source_sr = audio_item
                current_wav_array = np.asarray(current_wav_array, dtype=np.float32)
-            elif isinstance(audio_item, (np.ndarray, list)):  # Raw waveform as array/list
+            elif isinstance(audio_item, (np.ndarray, list)):
                current_wav_array = np.asarray(audio_item, dtype=np.float32)
                if sampling_rate is None:
                    raise ValueError(
@@ -159,44 +153,27 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
                    )
                source_sr = sampling_rate
            else:
-                # If you expect to load from paths/bytes, you'd use transformers.audio_utils.load_audio here
                raise TypeError(
                    f"Unsupported audio_item type: {type(audio_item)}. Expected np.ndarray, list of floats, "
                    "or Tuple[np.ndarray, int (sampling_rate)]."
                )
 
-            logger.debug(
-                f"Gemma3AudioFeatureExtractor: Processing audio item with original shape {current_wav_array.shape}, source_sr {source_sr}")
-
-            # 1. Preprocess: convert to mono, resample to self.sampling_rate, normalize
            processed_wav_for_mel = self._preprocess_audio(current_wav_array, source_sr)
-
-            # 2. Compute Log-Mel Spectrogram: results in (NumFrames, self.n_mels)
            mel_spectrogram_np = self._compute_log_mel_spectrogram(processed_wav_for_mel)
-            logger.debug(f"Gemma3AudioFeatureExtractor: Computed mel_spectrogram shape: {mel_spectrogram_np.shape}")
 
            if not (mel_spectrogram_np.ndim == 2 and mel_spectrogram_np.shape[1] == self.n_mels):
-                # This check is important if _compute_log_mel_spectrogram could return variable shapes
-                logger.error(
-                    f"Mel spectrogram computation resulted in unexpected shape {mel_spectrogram_np.shape}. Expected (NumFrames, {self.n_mels})")
-                # Fallback to a zero-feature tensor of correct feature dimension but zero time, or handle error
-                # This indicates a problem in _compute_log_mel_spectrogram or very unusual input
-                # For now, let it proceed, but this would be an issue.
-                # If num_frames was 0, shape would be (0, n_mels), which is valid.
-
-            feature_tensor = torch.from_numpy(mel_spectrogram_np)  # Already float32
+                # This could indicate an issue in _compute_log_mel_spectrogram or very unusual input.
+                # Depending on downstream requirements, this might need more robust error handling or a clear fallback.
+                pass  # Allowing to proceed, but output shape might be unexpected.
+
+            feature_tensor = torch.from_numpy(mel_spectrogram_np)
            processed_mels.append(feature_tensor)
-            actual_mel_lengths.append(feature_tensor.shape[0])  # Number of time frames
+            actual_mel_lengths.append(feature_tensor.shape[0])
 
-            # Original logic for 'sizes' and 'frames' (kept for compatibility with your processor)
            sizes_for_downstream_calc.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
            frames_scaled_for_downstream_calc.append(feature_tensor.shape[0] * self.feat_stride)
 
-        # Pad the list of 2D Mel spectrograms to form a 3D batch
-        # Output shape: (Batch, MaxNumFrames, NumMels)
        audio_values_batched = pad_sequence(processed_mels, batch_first=True, padding_value=self.padding_value)
-
-        # Create attention mask for the padded batch
        max_t_mel_in_batch = audio_values_batched.shape[1]
 
        attention_mask_batched = torch.zeros(len(audios), max_t_mel_in_batch, dtype=torch.bool)
@@ -204,15 +181,13 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
            attention_mask_batched[i, :length] = True
 
        output_data = {
-            "audio_values": audio_values_batched,  # Expected by model as (B, T, F)
-            "audio_attention_mask": attention_mask_batched  # Mask for "audio_values"
+            "audio_values": audio_values_batched,
+            "audio_attention_mask": attention_mask_batched
        }
 
-        if sizes_for_downstream_calc:  # If these are used by the OmniProcessor
+        if sizes_for_downstream_calc:
            output_data["audio_values_sizes"] = torch.stack(sizes_for_downstream_calc)
 
-        logger.info(
-            f"Gemma3AudioFeatureExtractor: Final 'audio_values' batch shape: {output_data['audio_values'].shape}")
        return BatchFeature(data=output_data, tensor_type=return_tensors)
 
    def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
@@ -229,15 +204,14 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
            wav = wav.mean(axis=0)
 
        if source_sr != self.sampling_rate:
-            # logger.info(f"Resampling audio from {source_sr} Hz to {self.sampling_rate} Hz.")
            common_divisor = math.gcd(self.sampling_rate, source_sr)
            up_factor = self.sampling_rate // common_divisor
            down_factor = source_sr // common_divisor
-            if up_factor != down_factor:
+            if up_factor != down_factor:  # Avoid resampling if factors are identical
                wav = scipy.signal.resample_poly(wav, up=up_factor, down=down_factor)
 
        max_abs_val = np.abs(wav).max()
-        if max_abs_val > 1e-7:  # Avoid division by zero/small numbers for silent/near-silent audio
+        if max_abs_val > 1e-7:
            wav = wav / max_abs_val
        return wav
 
@@ -249,11 +223,10 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
        if len(wav) >= self.win_length:
            num_frames = 1 + (len(wav) - self.win_length) // self.hop_length
        else:
-            num_frames = 0  # Should be caught by the padding above, but defensive.
+            num_frames = 0
 
        if num_frames <= 0:
-            # logger.warning(...)
-            return np.zeros((0, self.n_mels), dtype=np.float32)  # Return shape (0, N_Mels)
+            return np.zeros((0, self.n_mels), dtype=np.float32)  # Return shape (0, N_Mels)
 
        frames_view = np.lib.stride_tricks.as_strided(
            wav,
@@ -261,7 +234,7 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
            strides=(wav.strides[0] * self.hop_length, wav.strides[0]),
            writeable=False
        )
-        frames_data = frames_view.copy()  # Ensure it's a copy before in-place modification
+        frames_data = frames_view.copy()
        frames_data *= self.window
 
        spectrum = np.fft.rfft(frames_data, n=self.n_fft, axis=-1).astype(np.complex64)
@@ -301,7 +274,7 @@ class Gemma3OmniProcessor(ProcessorMixin):
    valid_kwargs = ["chat_template", "image_seq_length"]
 
    image_processor_class = "AutoImageProcessor"
-    audio_processor_class = "AutoFeatureExtractor"  # CRITICAL: Must be string name of your custom class
+    audio_processor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"
 
    def __init__(
  def __init__(
@@ -313,9 +286,6 @@ class Gemma3OmniProcessor(ProcessorMixin):
313
  image_seq_length: int = 256,
314
  **kwargs
315
  ):
316
- # ProcessorMixin.__init__ handles instantiation of audio_processor, image_processor, tokenizer
317
- # if they are None when passed to it, using the *_class attributes defined above.
318
- # If actual instances are passed (e.g., from from_pretrained), they will be used.
319
  super().__init__(
320
  image_processor=image_processor,
321
  audio_processor=audio_processor,
@@ -324,21 +294,16 @@ class Gemma3OmniProcessor(ProcessorMixin):
            **kwargs
        )
 
-        # These attributes depend on self.tokenizer being properly initialized by super()
        self.image_seq_length = image_seq_length
        if self.tokenizer is not None:
-            # Use getattr for robustness, providing defaults if attributes are missing
            self.image_token_id = getattr(self.tokenizer, "image_token_id",
                                          self.tokenizer.unk_token_id if hasattr(self.tokenizer,
                                                                                 "unk_token_id") else None)
-            self.boi_token = getattr(self.tokenizer, "boi_token", "<image>")  # More common default
+            self.boi_token = getattr(self.tokenizer, "boi_token", "<image>")
            self.image_token = getattr(self.tokenizer, "image_token", "<image>")
-            self.eoi_token = getattr(self.tokenizer, "eoi_token", "")  # Default to empty if not present
+            self.eoi_token = getattr(self.tokenizer, "eoi_token", "")
 
-            # User's original attributes for audio tokens
            self.audio_token_str_from_user_code = "<audio_soft_token>"
-            # self.expected_audio_token_id = 262143 # User's reference, keep commented for minimal change
-
            self.audio_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_token_str_from_user_code)
            if hasattr(self.tokenizer, "unk_token_id") and self.audio_token_id == self.tokenizer.unk_token_id:
                logger.warning(
@@ -347,7 +312,6 @@ class Gemma3OmniProcessor(ProcessorMixin):
                )
            self.full_image_sequence = f"\n\n{self.boi_token}{''.join([self.image_token] * image_seq_length)}{self.eoi_token}\n\n"
        else:
-            # This state (tokenizer is None after super init) should ideally not occur if from_pretrained works.
            logger.error(
                "Gemma3OmniProcessor initialized, but self.tokenizer is None. Token-dependent attributes will use placeholders or defaults.")
            self.image_token_id = None
@@ -355,17 +319,15 @@ class Gemma3OmniProcessor(ProcessorMixin):
            self.image_token = "<image>"
            self.eoi_token = ""
            self.audio_token_str_from_user_code = "<audio_soft_token>"
-            self.audio_token_id = -1  # Placeholder
+            self.audio_token_id = -1
            self.full_image_sequence = ""
 
-        # These are parameters for this processor's logic for number of audio tokens in prompt
        self.prompt_audio_compression_rate = kwargs.pop("audio_prompt_compression_rate", 8)
        self.prompt_audio_qformer_rate = kwargs.pop("audio_prompt_qformer_rate", 1)
        self.prompt_audio_feat_stride = kwargs.pop("audio_prompt_feat_stride", 1)
        self.audio_placeholder_token = kwargs.pop("audio_placeholder_token", "<|audio_placeholder|>")
 
    def _merge_kwargs(self, KwargsClassWithDefaults, tokenizer_init_kwargs, **kwargs_from_call):
-        # This method merges default kwargs, tokenizer init kwargs, and call-specific kwargs
        final_kwargs = {}
        _defaults = getattr(KwargsClassWithDefaults, "_defaults", {})
        if not isinstance(_defaults, dict): _defaults = {}
@@ -374,20 +336,20 @@ class Gemma3OmniProcessor(ProcessorMixin):
            final_kwargs[modality_key] = default_modality_kwargs.copy()
 
        for modality_key_in_call, modality_kwargs_in_call in kwargs_from_call.items():
-            if modality_key_in_call in final_kwargs:  # e.g. "text_kwargs"
+            if modality_key_in_call in final_kwargs:
                if isinstance(modality_kwargs_in_call, dict):
                    final_kwargs[modality_key_in_call].update(modality_kwargs_in_call)
-            elif isinstance(modality_kwargs_in_call, dict):  # New modality not in _defaults (e.g. "video_kwargs")
+            elif isinstance(modality_kwargs_in_call, dict):
                final_kwargs[modality_key_in_call] = modality_kwargs_in_call.copy()
 
-        if self.tokenizer:  # Ensure tokenizer is available for its init_kwargs
+        if self.tokenizer:
            for modality_key in final_kwargs:
                modality_dict = final_kwargs[modality_key]
                if isinstance(modality_dict, dict):
                    for key_in_mod_dict in list(modality_dict.keys()):
-                        if key_in_mod_dict in tokenizer_init_kwargs:  # tokenizer_init_kwargs from self.tokenizer.init_kwargs
+                        if key_in_mod_dict in tokenizer_init_kwargs:
                            value = (
-                                getattr(self.tokenizer, key_in_mod_dict)  # Check actual tokenizer attribute first
+                                getattr(self.tokenizer, key_in_mod_dict)
                                if hasattr(self.tokenizer, key_in_mod_dict)
                                else tokenizer_init_kwargs[key_in_mod_dict]
                            )
@@ -395,7 +357,6 @@ class Gemma3OmniProcessor(ProcessorMixin):
 
        if "text_kwargs" not in final_kwargs:
            final_kwargs["text_kwargs"] = {}
-        # Ensure these text_kwargs have defaults if not set otherwise
        final_kwargs["text_kwargs"]["truncation"] = final_kwargs["text_kwargs"].get("truncation", False)
        final_kwargs["text_kwargs"]["max_length"] = final_kwargs["text_kwargs"].get("max_length", DEFAULT_MAX_LENGTH)
 
@@ -418,23 +379,19 @@ class Gemma3OmniProcessor(ProcessorMixin):
        if text is None and images is None and audios is None:
            raise ValueError("Provide at least one of `text`, `images`, or `audios`.")
 
-        # Determine final return_tensors strategy (explicit __call__ arg > from text_kwargs > default)
        final_rt = return_tensors
-        # _merge_kwargs uses Gemma3ProcessorKwargs to structure the **kwargs from __call__
        merged_call_kwargs = self._merge_kwargs(
-            Gemma3ProcessorKwargs,  # Class defining _defaults structure
+            Gemma3ProcessorKwargs,
            self.tokenizer.init_kwargs if hasattr(self.tokenizer, 'init_kwargs') else {},
            **kwargs
        )
 
-        if final_rt is None:  # If not passed directly to __call__
-            # Get from merged_call_kwargs (which would have picked it up from kwargs['text_kwargs'])
-            # and remove it to prevent passing twice to tokenizer
+        if final_rt is None:
            final_rt = merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", TensorType.PYTORCH)
-        else:  # If passed directly, ensure it's removed from text_kwargs to avoid conflict
+        else:
            merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", None)
 
-        if text is None:  # Default text if only other modalities are provided
+        if text is None:
            num_samples = 0
            if images is not None:
                _images_list = images if isinstance(images, list) and (
@@ -443,13 +400,12 @@ class Gemma3OmniProcessor(ProcessorMixin):
            elif audios is not None:
                _audios_list = audios if isinstance(audios, list) else [audios]
                num_samples = len(_audios_list)
-            text = [""] * num_samples if num_samples > 0 else [""]  # Create empty strings or one if no samples
+            text = [""] * num_samples if num_samples > 0 else [""]
 
        if isinstance(text, str): text = [text]
        if not (isinstance(text, list) and all(isinstance(t, str) for t in text)):
            raise ValueError("Input `text` must be a string or a list of strings.")
 
-        # --- Image Processing (User's structure) ---
        image_features_dict = {}
        if images is not None:
            if self.image_processor is None: raise ValueError("Images provided but self.image_processor is None.")
@@ -459,7 +415,6 @@ class Gemma3OmniProcessor(ProcessorMixin):
            image_features_dict = _img_proc_output.data if isinstance(_img_proc_output,
                                                                      BatchFeature) else _img_proc_output
 
-            # Adjust text based on images (user's original logic)
            if len(text) == 0 and len(batched_images) > 0: text = [" ".join([self.boi_token] * len(img_batch)) for
                                                                   img_batch in batched_images]
            if len(batched_images) != len(text): raise ValueError(
@@ -496,18 +451,14 @@ class Gemma3OmniProcessor(ProcessorMixin):
            text = temp_text_img
            text = [p.replace(self.boi_token, self.full_image_sequence) for p in text]
 
-        # --- Audio Processing ---
        audio_features_dict = {}
        if audios is not None:
            if self.audio_processor is None: raise ValueError("Audios provided but self.audio_processor is None.")
            audio_call_kwargs = merged_call_kwargs.get("audio_kwargs", {})
-            if sampling_rate is not None: audio_call_kwargs[
-                "sampling_rate"] = sampling_rate  # Pass SR to feature extractor
+            if sampling_rate is not None: audio_call_kwargs["sampling_rate"] = sampling_rate
 
            _audio_proc_output = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
            audio_features_dict = _audio_proc_output.data
-            logger.info(
-                f"Gemma3OmniProcessor: 'audio_values' shape from Feature Extractor: {audio_features_dict['audio_values'].shape}")
 
            new_text_with_audio, actual_mel_frames_per_sample = [], to_py_obj(
                audio_features_dict["audio_attention_mask"].sum(axis=-1))
@@ -516,9 +467,8 @@ class Gemma3OmniProcessor(ProcessorMixin):
 
            for i, prompt in enumerate(text):
                num_soft_tokens = self._compute_audio_embed_size(actual_mel_frames_per_sample[i])
-                audio_token_sequence_str = self.audio_token_str_from_user_code * num_soft_tokens  # e.g. "<audio_soft_token>" * N
+                audio_token_sequence_str = self.audio_token_str_from_user_code * num_soft_tokens
 
-                # User's original boa_token for replacement was " ", which is risky. Using defined placeholder.
                if self.audio_placeholder_token in prompt:
                    prompt = prompt.replace(self.audio_placeholder_token, audio_token_sequence_str, 1)
                else:
@@ -526,10 +476,9 @@ class Gemma3OmniProcessor(ProcessorMixin):
                new_text_with_audio.append(prompt)
            text = new_text_with_audio
 
-        # --- Text Tokenization ---
        text_tokenizer_kwargs = merged_call_kwargs.get("text_kwargs", {})
        text_features_dict = self.tokenizer(text=text, return_tensors=None,
-                                            **text_tokenizer_kwargs)  # Get lists/np.arrays
+                                            **text_tokenizer_kwargs)
 
        input_ids_list_of_lists = text_features_dict["input_ids"]
        if not isinstance(input_ids_list_of_lists, list) or not (
@@ -551,11 +500,6 @@ class Gemma3OmniProcessor(ProcessorMixin):
                token_type_ids_list.append(types)
            text_features_dict["token_type_ids"] = token_type_ids_list
 
-        # Ensure text_features_dict also has 'attention_mask' if tokenizer applied padding
-        # If tokenizer was called with padding=True/strategy, it would add 'attention_mask'
-        # If called with padding=False (default), 'attention_mask' might be missing or all 1s.
-        # BatchFeature will handle final tensor conversion and padding based on final_rt.
-
        final_batch_data = {**text_features_dict}
        if image_features_dict: final_batch_data.update(image_features_dict)
        if audio_features_dict: final_batch_data.update(audio_features_dict)
@@ -581,7 +525,7 @@ class Gemma3OmniProcessor(ProcessorMixin):
                hasattr(self.audio_processor, 'model_input_names'):
            input_names.update(self.audio_processor.model_input_names)
        elif hasattr(self,
-                     'audio_processor') and self.audio_processor is not None:  # Fallback if model_input_names not on custom audio_processor
+                     'audio_processor') and self.audio_processor is not None:
            input_names.update(["audio_values", "audio_attention_mask"])
 
        return list(input_names)
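
For quick reference, below is a minimal usage sketch of the audio path touched by this commit; it is not part of the change itself. It assumes processing_gemma3_omni.py is importable from the repository root, that Gemma3AudioFeatureExtractor can be constructed with its defaults (n_mels=80, win_length=400, hop_length=160 per the constants above), and uses an illustrative 16 kHz input.

# Hypothetical usage sketch; the import path and default construction are assumptions.
import numpy as np
from processing_gemma3_omni import Gemma3AudioFeatureExtractor

extractor = Gemma3AudioFeatureExtractor()

# One second of audio passed as an (np.ndarray, sampling_rate) tuple, one of the
# input forms accepted by __call__ above.
wav = np.random.randn(16000).astype(np.float32)
batch = extractor(audios=[(wav, 16000)], return_tensors="pt")

# With win_length=400 and hop_length=160, the framing formula in the diff gives
# 1 + (16000 - 400) // 160 = 98 frames (the exact count can differ if the signal
# is padded before framing).
print(batch["audio_values"].shape)          # (batch, num_mel_frames, n_mels)
print(batch["audio_attention_mask"].shape)  # (batch, num_mel_frames)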
 