voidful committed on
Commit 3fa62c9 · verified · 1 Parent(s): 315e5b5

Update processing_gemma3_omni.py

Files changed (1)
  1. processing_gemma3_omni.py +347 -250
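For readers comparing the old and new versions below: the commit swaps the generic Mel-spectrogram extractor for a Phi4M-style pipeline, and the audio embed size is now computed from frames already scaled by feat_stride, via two rounds of ceiling division (written in the new code as floor division plus a remainder check). A minimal sketch of that arithmetic, using illustrative frame counts and rates rather than values fixed by this commit:

import math

def embed_size(audio_frames: int, compression_rate: int, qformer_rate: int) -> int:
    # Equivalent to the floor-division-plus-remainder form in the new
    # _compute_audio_embed_size: ceil(frames / compression), then ceil(result / qformer).
    first = math.ceil(audio_frames / compression_rate)
    return math.ceil(first / qformer_rate)

# Illustrative only: 998 scaled frames, compression rate 8, Q-Former rate 1
# -> ceil(998 / 8) = 125 -> ceil(125 / 1) = 125 audio soft tokens.
assert embed_size(998, 8, 1) == 125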
processing_gemma3_omni.py CHANGED
@@ -1,253 +1,263 @@
1
  import re
2
- from typing import List, Optional, Union, Dict, Any
3
 
4
  import math
5
  import numpy as np
6
  import scipy.signal
7
  import torch
8
  from torch.nn.utils.rnn import pad_sequence
9
- from transformers.audio_utils import AudioInput # type: ignore
10
  from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
11
  from transformers.feature_extraction_utils import BatchFeature
12
- from transformers.image_utils import make_nested_list_of_images # If image processing is used
13
  from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs
14
  from transformers.utils import TensorType, to_py_obj, logging
15
 
16
  # Constants
17
  DEFAULT_SAMPLING_RATE = 16000
18
  DEFAULT_N_FFT = 512
19
- DEFAULT_WIN_LENGTH = 400
20
- DEFAULT_HOP_LENGTH = 160
21
  DEFAULT_N_MELS = 80
22
- DEFAULT_COMPRESSION_RATE = 4
23
- DEFAULT_QFORMER_RATE = 2
24
- DEFAULT_FEAT_STRIDE = 4
25
  IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
26
  AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
27
  DEFAULT_MAX_LENGTH = 16384
28
- LOG_MEL_CLIP_EPSILON = 1e-5
29
 
30
  logger = logging.get_logger(__name__)
31
 
32
 
 
 
33
  def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
34
  fmax: Optional[float] = None) -> np.ndarray:
35
  """Create Mel filterbank for audio processing."""
36
  fmax = fmax or sampling_rate / 2.0
37
 
38
- def hz_to_mel(f: float) -> float:
39
  return 1127.0 * math.log(1 + f / 700.0)
40
 
41
  if fmin >= fmax:
42
  raise ValueError(f"fmin ({fmin}) must be smaller than fmax ({fmax}).")
43
 
44
  mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
45
- freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1) # Inverse of user's hz_to_mel
46
 
47
  freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
48
  bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)
49
- bins = np.clip(bins, 0, n_fft // 2) # Max index for rfft output is n_fft//2
50
 
51
  filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
52
  for m_idx in range(n_mels):
53
  left, center, right = bins[m_idx], bins[m_idx + 1], bins[m_idx + 2]
54
 
55
- if center > left: # Rising slope
56
  filterbank[m_idx, left:center + 1] = (np.arange(left, center + 1) - left) / (center - left)
57
- if right > center: # Falling slope
58
  filterbank[m_idx, center:right + 1] = (right - np.arange(center, right + 1)) / (right - center)
59
 
60
  # Ensure the peak at 'center' is 1.0 if it's a valid point.
61
- if left <= center <= right:
62
- if filterbank.shape[1] > center:
63
- if (center > left and filterbank[m_idx, center] < 1.0) or \
64
- (center < right and filterbank[m_idx, center] < 1.0) or \
65
- (left == center and center < right) or \
66
- (right == center and left < center):
67
  filterbank[m_idx, center] = 1.0
68
  return filterbank
69
 
70
 
 
71
  class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
72
- model_input_names = ["audio_values", "audio_attention_mask"]
73
 
74
- def __init__(
75
  self,
76
- compression_rate: int = DEFAULT_COMPRESSION_RATE,
77
- qformer_rate: int = DEFAULT_QFORMER_RATE,
78
- feat_stride: int = DEFAULT_FEAT_STRIDE,
79
- sampling_rate: int = DEFAULT_SAMPLING_RATE,
80
- n_fft: int = DEFAULT_N_FFT,
81
- win_length: Optional[int] = None,
82
- hop_length: Optional[int] = None,
83
- n_mels: int = DEFAULT_N_MELS,
84
- f_min: float = 0.0,
85
- f_max: Optional[float] = None,
86
- padding_value: float = 0.0,
87
- **kwargs
88
  ):
89
- _win_length = win_length if win_length is not None else n_fft
90
- _hop_length = hop_length if hop_length is not None else _win_length // 4
91
 
92
- kwargs.pop("feature_size", None)
93
- kwargs.pop("sampling_rate", None)
94
- kwargs.pop("padding_value", None)
 
95
 
96
- super().__init__(
97
- feature_size=n_mels,
98
- sampling_rate=sampling_rate,
99
- padding_value=padding_value,
100
- **kwargs
101
- )
102
 
103
- self.compression_rate = compression_rate
104
- self.qformer_rate = qformer_rate
105
- self.feat_stride = feat_stride
106
- # self.sampling_rate is set by super() to the target rate
107
-
108
- self.n_fft = n_fft
109
- self.win_length = _win_length
110
- self.hop_length = _hop_length
111
- self.n_mels = n_mels
112
- self.f_min = f_min
113
- self.f_max = f_max if f_max is not None else self.sampling_rate / 2.0
114
-
115
- if self.win_length > self.n_fft:
116
- logger.warning(
117
- f"win_length ({self.win_length}) is greater than n_fft ({self.n_fft}). "
118
- "Window will be applied, then data zero-padded/truncated to n_fft by np.fft.rfft."
119
- )
120
- self.window = np.hamming(self.win_length).astype(np.float32)
121
- self.mel_filterbank = create_mel_filterbank(
122
- self.sampling_rate, self.n_fft, self.n_mels, fmin=self.f_min, fmax=self.f_max
123
- ).T
124
 
125
- def __call__(
126
- self,
127
- audios: Union[AudioInput, List[AudioInput]],
128
- sampling_rate: Optional[int] = None,
129
- return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
130
- ) -> BatchFeature:
131
 
132
- if not isinstance(audios, list):
133
- audios = [audios]
134
-
135
- processed_mels: List[torch.Tensor] = []
136
- actual_mel_lengths: List[int] = []
137
- sizes_for_downstream_calc: List[torch.Tensor] = []
138
- frames_scaled_for_downstream_calc: List[int] = []
139
-
140
- for audio_item in audios:
141
- current_wav_array: np.ndarray
142
- source_sr: int
143
-
144
- if isinstance(audio_item, tuple) and len(audio_item) == 2 and isinstance(audio_item[1], int):
145
- current_wav_array, source_sr = audio_item
146
- current_wav_array = np.asarray(current_wav_array, dtype=np.float32)
147
- elif isinstance(audio_item, (np.ndarray, list)):
148
- current_wav_array = np.asarray(audio_item, dtype=np.float32)
149
- if sampling_rate is None:
150
- raise ValueError(
151
- "sampling_rate argument must be provided to __call__ if 'audios' items "
152
- "are raw numpy arrays or lists (without embedded sampling rate info)."
153
- )
154
- source_sr = sampling_rate
155
- else:
156
- raise TypeError(
157
- f"Unsupported audio_item type: {type(audio_item)}. Expected np.ndarray, list of floats, "
158
- "or Tuple[np.ndarray, int (sampling_rate)]."
159
- )
160
 
161
- processed_wav_for_mel = self._preprocess_audio(current_wav_array, source_sr)
162
- mel_spectrogram_np = self._compute_log_mel_spectrogram(processed_wav_for_mel)
163
 
164
- if not (mel_spectrogram_np.ndim == 2 and mel_spectrogram_np.shape[1] == self.n_mels):
165
- # This could indicate an issue in _compute_log_mel_spectrogram or very unusual input.
166
- # Depending on downstream requirements, this might need more robust error handling or a clear fallback.
167
- pass # Allowing to proceed, but output shape might be unexpected.
168
 
169
- feature_tensor = torch.from_numpy(mel_spectrogram_np)
170
- processed_mels.append(feature_tensor)
171
- actual_mel_lengths.append(feature_tensor.shape[0])
172
 
173
- sizes_for_downstream_calc.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
174
- frames_scaled_for_downstream_calc.append(feature_tensor.shape[0] * self.feat_stride)
175
 
176
- audio_values_batched = pad_sequence(processed_mels, batch_first=True, padding_value=self.padding_value)
177
- max_t_mel_in_batch = audio_values_batched.shape[1]
 
178
 
179
- attention_mask_batched = torch.zeros(len(audios), max_t_mel_in_batch, dtype=torch.bool)
180
- for i, length in enumerate(actual_mel_lengths):
181
- attention_mask_batched[i, :length] = True
 
182
 
183
- output_data = {
184
- "audio_values": audio_values_batched,
185
- "audio_attention_mask": attention_mask_batched
186
- }
187
 
188
- if sizes_for_downstream_calc:
189
- output_data["audio_values_sizes"] = torch.stack(sizes_for_downstream_calc)
190
 
191
- return BatchFeature(data=output_data, tensor_type=return_tensors)
192
 
193
- def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
194
- if wav.dtype not in [np.float32, np.float64]:
195
- if np.issubdtype(wav.dtype, np.integer):
196
- max_val = np.iinfo(wav.dtype).max if wav.size > 0 else 1.0
197
- wav = wav.astype(np.float32) / max_val
198
- else:
199
- wav = wav.astype(np.float32)
200
- elif wav.dtype == np.float64:
201
- wav = wav.astype(np.float32)
202
 
203
- if wav.ndim > 1:
204
- wav = wav.mean(axis=0)
205
-
206
- if source_sr != self.sampling_rate:
207
- common_divisor = math.gcd(self.sampling_rate, source_sr)
208
- up_factor = self.sampling_rate // common_divisor
209
- down_factor = source_sr // common_divisor
210
- if up_factor != down_factor: # Avoid resampling if factors are identical
211
- wav = scipy.signal.resample_poly(wav, up=up_factor, down=down_factor)
212
-
213
- max_abs_val = np.abs(wav).max()
214
- if max_abs_val > 1e-7:
215
- wav = wav / max_abs_val
216
- return wav
217
-
218
- def _compute_log_mel_spectrogram(self, wav: np.ndarray) -> np.ndarray:
219
- if len(wav) < self.win_length:
220
- padding = self.win_length - len(wav)
221
- wav = np.pad(wav, (0, padding), mode='constant', constant_values=0.0)
222
-
223
- if len(wav) >= self.win_length:
224
- num_frames = 1 + (len(wav) - self.win_length) // self.hop_length
225
- else:
226
- num_frames = 0
227
 
228
- if num_frames <= 0:
229
- return np.zeros((0, self.n_mels), dtype=np.float32) # Return shape (0, N_Mels)
 
 
230
 
231
- frames_view = np.lib.stride_tricks.as_strided(
232
- wav,
233
- shape=(num_frames, self.win_length),
234
- strides=(wav.strides[0] * self.hop_length, wav.strides[0]),
235
- writeable=False
236
- )
237
- frames_data = frames_view.copy()
238
- frames_data *= self.window
239
 
240
- spectrum = np.fft.rfft(frames_data, n=self.n_fft, axis=-1).astype(np.complex64)
241
- power = np.abs(spectrum) ** 2
242
- mel_spectrogram = np.dot(power, self.mel_filterbank)
243
- mel_spectrogram = np.clip(mel_spectrogram, LOG_MEL_CLIP_EPSILON, None)
244
- log_mel_spectrogram = np.log(mel_spectrogram)
245
 
246
- return log_mel_spectrogram.astype(np.float32)
247
 
248
- def _calculate_embed_length(self, frame_count: int) -> int:
249
- compressed = math.ceil(frame_count / self.compression_rate)
250
- return math.ceil(compressed / self.qformer_rate)
251
 
252
 
253
  class Gemma3ImagesKwargs(ImagesKwargs):
@@ -280,7 +290,7 @@ class Gemma3OmniProcessor(ProcessorMixin):
280
  def __init__(
281
  self,
282
  image_processor=None,
283
- audio_processor=None,
284
  tokenizer=None,
285
  chat_template=None,
286
  image_seq_length: int = 256,
@@ -303,7 +313,8 @@ class Gemma3OmniProcessor(ProcessorMixin):
303
  self.image_token = getattr(self.tokenizer, "image_token", "<image>")
304
  self.eoi_token = getattr(self.tokenizer, "eoi_token", "")
305
 
306
- self.audio_token_str_from_user_code = "<audio_soft_token>"
 
307
  self.audio_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_token_str_from_user_code)
308
  if hasattr(self.tokenizer, "unk_token_id") and self.audio_token_id == self.tokenizer.unk_token_id:
309
  logger.warning(
@@ -319,12 +330,14 @@ class Gemma3OmniProcessor(ProcessorMixin):
319
  self.image_token = "<image>"
320
  self.eoi_token = ""
321
  self.audio_token_str_from_user_code = "<audio_soft_token>"
322
- self.audio_token_id = -1
323
  self.full_image_sequence = ""
324
 
325
- self.prompt_audio_compression_rate = kwargs.pop("audio_prompt_compression_rate", 8)
326
- self.prompt_audio_qformer_rate = kwargs.pop("audio_prompt_qformer_rate", 1)
327
- self.prompt_audio_feat_stride = kwargs.pop("audio_prompt_feat_stride", 1)
 
 
328
  self.audio_placeholder_token = kwargs.pop("audio_placeholder_token", "<|audio_placeholder|>")
329
 
330
  def _merge_kwargs(self, KwargsClassWithDefaults, tokenizer_init_kwargs, **kwargs_from_call):
@@ -339,14 +352,14 @@ class Gemma3OmniProcessor(ProcessorMixin):
339
  if modality_key_in_call in final_kwargs:
340
  if isinstance(modality_kwargs_in_call, dict):
341
  final_kwargs[modality_key_in_call].update(modality_kwargs_in_call)
342
- elif isinstance(modality_kwargs_in_call, dict):
343
  final_kwargs[modality_key_in_call] = modality_kwargs_in_call.copy()
344
 
345
- if self.tokenizer:
346
  for modality_key in final_kwargs:
347
  modality_dict = final_kwargs[modality_key]
348
- if isinstance(modality_dict, dict):
349
- for key_in_mod_dict in list(modality_dict.keys()):
350
  if key_in_mod_dict in tokenizer_init_kwargs:
351
  value = (
352
  getattr(self.tokenizer, key_in_mod_dict)
@@ -355,148 +368,218 @@ class Gemma3OmniProcessor(ProcessorMixin):
355
  )
356
  modality_dict[key_in_mod_dict] = value
357
 
358
- if "text_kwargs" not in final_kwargs:
359
- final_kwargs["text_kwargs"] = {}
360
  final_kwargs["text_kwargs"]["truncation"] = final_kwargs["text_kwargs"].get("truncation", False)
361
  final_kwargs["text_kwargs"]["max_length"] = final_kwargs["text_kwargs"].get("max_length", DEFAULT_MAX_LENGTH)
362
 
363
  return final_kwargs
364
 
365
  def _compute_audio_embed_size(self, audio_mel_frames: int) -> int:
366
- scaled_frames = audio_mel_frames * self.prompt_audio_feat_stride
367
- result = math.ceil(scaled_frames / self.prompt_audio_compression_rate)
368
- return math.ceil(result / self.prompt_audio_qformer_rate)
369
 
370
  def __call__(
371
  self,
372
  text: Union[str, List[str]] = None,
373
  images: Optional[Any] = None,
374
  audios: Optional[Union[AudioInput, List[AudioInput]]] = None,
375
- sampling_rate: Optional[int] = None,
376
  return_tensors: Optional[Union[str, TensorType]] = None,
377
  **kwargs: Any
378
  ) -> BatchFeature:
379
  if text is None and images is None and audios is None:
380
  raise ValueError("Provide at least one of `text`, `images`, or `audios`.")
381
 
382
- final_rt = return_tensors
 
383
  merged_call_kwargs = self._merge_kwargs(
384
- Gemma3ProcessorKwargs,
385
- self.tokenizer.init_kwargs if hasattr(self.tokenizer, 'init_kwargs') else {},
386
- **kwargs
387
  )
388
 
389
- if final_rt is None:
 
390
  final_rt = merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", TensorType.PYTORCH)
391
- else:
392
  merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", None)
393
 
394
- if text is None:
395
  num_samples = 0
396
  if images is not None:
397
  _images_list = images if isinstance(images, list) and (
398
  not images or not isinstance(images[0], (int, float))) else [images]
399
  num_samples = len(_images_list)
400
  elif audios is not None:
401
- _audios_list = audios if isinstance(audios, list) else [audios]
 
 
402
  num_samples = len(_audios_list)
403
- text = [""] * num_samples if num_samples > 0 else [""]
404
 
405
- if isinstance(text, str): text = [text]
406
  if not (isinstance(text, list) and all(isinstance(t, str) for t in text)):
407
  raise ValueError("Input `text` must be a string or a list of strings.")
408
 
409
  image_features_dict = {}
410
  if images is not None:
411
  if self.image_processor is None: raise ValueError("Images provided but self.image_processor is None.")
412
- batched_images = make_nested_list_of_images(images)
413
  _img_proc_output = self.image_processor(batched_images, return_tensors=None,
414
- **merged_call_kwargs.get("images_kwargs", {}))
415
  image_features_dict = _img_proc_output.data if isinstance(_img_proc_output,
416
  BatchFeature) else _img_proc_output
417
 
418
- if len(text) == 0 and len(batched_images) > 0: text = [" ".join([self.boi_token] * len(img_batch)) for
419
- img_batch in batched_images]
420
- if len(batched_images) != len(text): raise ValueError(
421
- f"Inconsistent batch: {len(batched_images)} images, {len(text)} texts")
 
 
 
422
 
423
  num_crops_popped = image_features_dict.pop("num_crops", None)
424
  if num_crops_popped is not None:
425
  num_crops_all = to_py_obj(num_crops_popped)
426
  temp_text_img, current_crop_idx_offset = [], 0
427
  for batch_idx, (prompt, current_imgs_in_batch) in enumerate(zip(text, batched_images)):
428
- crops_for_this_batch_sample = []
429
- if num_crops_all:
430
- for _ in current_imgs_in_batch:
431
  if current_crop_idx_offset < len(num_crops_all):
432
- crops_for_this_batch_sample.append(
433
- num_crops_all[current_crop_idx_offset]); current_crop_idx_offset += 1
 
 
434
  else:
435
- crops_for_this_batch_sample.append(0)
436
- image_indexes = [m.start() for m in re.finditer(re.escape(self.boi_token), prompt)]
 
437
  processed_prompt = prompt
438
- iter_count = min(len(crops_for_this_batch_sample), len(image_indexes))
439
- for i_crop_idx in range(iter_count - 1, -1, -1):
440
- num_additional_crops = crops_for_this_batch_sample[i_crop_idx]
441
- original_token_idx = image_indexes[i_crop_idx]
442
- if num_additional_crops > 0:
443
- replacement_text = (
444
- f"Here is the original image {self.boi_token} and here are some crops to help you see better " + " ".join(
445
- [self.boi_token] * num_additional_crops))
446
- processed_prompt = processed_prompt[
447
- :original_token_idx] + replacement_text + processed_prompt[
448
- original_token_idx + len(
449
- self.boi_token):]
450
  temp_text_img.append(processed_prompt)
451
  text = temp_text_img
452
- text = [p.replace(self.boi_token, self.full_image_sequence) for p in text]
 
 
453
 
454
  audio_features_dict = {}
455
  if audios is not None:
456
  if self.audio_processor is None: raise ValueError("Audios provided but self.audio_processor is None.")
 
457
  audio_call_kwargs = merged_call_kwargs.get("audio_kwargs", {})
 
458
  if sampling_rate is not None: audio_call_kwargs["sampling_rate"] = sampling_rate
459
 
 
 
460
  _audio_proc_output = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
461
  audio_features_dict = _audio_proc_output.data
462
 
463
- new_text_with_audio, actual_mel_frames_per_sample = [], to_py_obj(
464
- audio_features_dict["audio_attention_mask"].sum(axis=-1))
465
- if len(actual_mel_frames_per_sample) != len(text): raise ValueError(
466
- f"Inconsistent batch for audio/text: {len(actual_mel_frames_per_sample)} audio, {len(text)} text.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
 
468
  for i, prompt in enumerate(text):
469
- num_soft_tokens = self._compute_audio_embed_size(actual_mel_frames_per_sample[i])
470
  audio_token_sequence_str = self.audio_token_str_from_user_code * num_soft_tokens
471
 
472
  if self.audio_placeholder_token in prompt:
473
- prompt = prompt.replace(self.audio_placeholder_token, audio_token_sequence_str, 1)
 
474
  else:
475
- prompt += audio_token_sequence_str
476
  new_text_with_audio.append(prompt)
477
  text = new_text_with_audio
478
 
479
  text_tokenizer_kwargs = merged_call_kwargs.get("text_kwargs", {})
480
  text_features_dict = self.tokenizer(text=text, return_tensors=None,
481
- **text_tokenizer_kwargs)
482
 
 
483
  input_ids_list_of_lists = text_features_dict["input_ids"]
 
484
  if not isinstance(input_ids_list_of_lists, list) or not (
485
  input_ids_list_of_lists and isinstance(input_ids_list_of_lists[0], list)):
486
  if isinstance(input_ids_list_of_lists, (torch.Tensor, np.ndarray)):
487
- input_ids_list_of_lists = to_py_obj(input_ids_list_of_lists)
488
  elif isinstance(input_ids_list_of_lists, list) and (
489
  not input_ids_list_of_lists or isinstance(input_ids_list_of_lists[0], int)):
490
- input_ids_list_of_lists = [input_ids_list_of_lists]
491
 
492
  token_type_ids_list = []
493
  for ids_sample in input_ids_list_of_lists:
494
- types = [0] * len(ids_sample)
495
  for j, token_id_val in enumerate(ids_sample):
496
  if self.image_token_id is not None and token_id_val == self.image_token_id:
497
- types[j] = 1
498
- elif self.audio_token_id != -1 and token_id_val == self.audio_token_id:
499
- types[j] = 2
500
  token_type_ids_list.append(types)
501
  text_features_dict["token_type_ids"] = token_type_ids_list
502
 
@@ -504,6 +587,7 @@ class Gemma3OmniProcessor(ProcessorMixin):
504
  if image_features_dict: final_batch_data.update(image_features_dict)
505
  if audio_features_dict: final_batch_data.update(audio_features_dict)
506
 
 
507
  return BatchFeature(data=final_batch_data, tensor_type=final_rt)
508
 
509
  def batch_decode(self, *args, **kwargs):
@@ -516,16 +600,29 @@ class Gemma3OmniProcessor(ProcessorMixin):
516
  def model_input_names(self) -> List[str]:
517
  input_names = set()
518
  if hasattr(self, 'tokenizer') and self.tokenizer is not None:
519
- input_names.update(self.tokenizer.model_input_names + ["token_type_ids"])
520
 
521
  if hasattr(self, 'image_processor') and self.image_processor is not None:
522
- input_names.update(self.image_processor.model_input_names)
523
-
524
- if hasattr(self, 'audio_processor') and self.audio_processor is not None and \
525
- hasattr(self.audio_processor, 'model_input_names'):
526
- input_names.update(self.audio_processor.model_input_names)
527
- elif hasattr(self,
528
- 'audio_processor') and self.audio_processor is not None:
529
- input_names.update(["audio_values", "audio_attention_mask"])
 
 
 
 
 
 
 
530
 
531
  return list(input_names)
 
1
  import re
2
+ from typing import List, Optional, Union, Dict, Any, Tuple # Added Tuple
3
 
4
  import math
5
  import numpy as np
6
  import scipy.signal
7
  import torch
8
  from torch.nn.utils.rnn import pad_sequence
9
+ from transformers.audio_utils import AudioInput # type: ignore
10
  from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
11
  from transformers.feature_extraction_utils import BatchFeature
12
+ from transformers.image_utils import make_nested_list_of_images # If image processing is used
13
  from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs
14
  from transformers.utils import TensorType, to_py_obj, logging
15
 
16
  # Constants
17
  DEFAULT_SAMPLING_RATE = 16000
18
  DEFAULT_N_FFT = 512
19
+ DEFAULT_WIN_LENGTH = 400 # Matches Phi4M's 16kHz win_length for reference
20
+ DEFAULT_HOP_LENGTH = 160 # Matches Phi4M's 16kHz hop_length for reference
21
  DEFAULT_N_MELS = 80
22
+ DEFAULT_COMPRESSION_RATE = 4 # Generic default
23
+ DEFAULT_QFORMER_RATE = 2 # Generic default
24
+ DEFAULT_FEAT_STRIDE = 4 # Generic default
25
  IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
26
  AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
27
  DEFAULT_MAX_LENGTH = 16384
28
+ # LOG_MEL_CLIP_EPSILON = 1e-5 # Original B's constant, A clips at 1.0
29
 
30
  logger = logging.get_logger(__name__)
31
 
32
 
33
+ # This create_mel_filterbank function is from your original Snippet B.
34
+ # It will be used by the Gemma3AudioFeatureExtractor.
35
  def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
36
  fmax: Optional[float] = None) -> np.ndarray:
37
  """Create Mel filterbank for audio processing."""
38
  fmax = fmax or sampling_rate / 2.0
39
 
40
+ def hz_to_mel(f: float) -> float: # Slaney scale from Snippet B
41
  return 1127.0 * math.log(1 + f / 700.0)
42
 
43
  if fmin >= fmax:
44
  raise ValueError(f"fmin ({fmin}) must be smaller than fmax ({fmax}).")
45
 
46
  mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
47
+ freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1) # Inverse of Slaney hz_to_mel
48
 
49
  freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
50
  bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)
51
+ bins = np.clip(bins, 0, n_fft // 2) # Max index for rfft output is n_fft//2
52
 
53
  filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
54
  for m_idx in range(n_mels):
55
  left, center, right = bins[m_idx], bins[m_idx + 1], bins[m_idx + 2]
56
 
57
+ if center > left: # Rising slope
58
  filterbank[m_idx, left:center + 1] = (np.arange(left, center + 1) - left) / (center - left)
59
+ if right > center: # Falling slope
60
  filterbank[m_idx, center:right + 1] = (right - np.arange(center, right + 1)) / (right - center)
61
 
62
  # Ensure the peak at 'center' is 1.0 if it's a valid point.
63
+ # This logic is from original Snippet B. Phi4M's speechlib_mel might normalize differently.
64
+ if left <= center <= right: # Check if center is within the bounds of the filter
65
+ if filterbank.shape[1] > center: # Check if center index is within filterbank columns
66
+ if (center > left and filterbank[m_idx, center] < 1.0 and center < right) or \
67
+ (left == center and center < right) or \
68
+ (right == center and left < center): # Ensure it's a triangular filter with a slope
69
  filterbank[m_idx, center] = 1.0
70
+ elif left == center and right == center: # Handles the case of a filter with zero width if bins are identical
71
+ filterbank[m_idx, center] = 1.0
72
+
73
  return filterbank
74
 
75
 
76
+ # --- Start of Refactored Audio Feature Extractor (to match Phi4M - Snippet A) ---
77
  class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
78
+ model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]
79
+
80
+ def __init__(self, audio_compression_rate, audio_downsample_rate, audio_feat_stride, **kwargs):
81
+ feature_size = 80 # From Phi4M
82
+ sampling_rate = 16000 # From Phi4M (target sampling rate)
83
+ padding_value = 0.0 # From Phi4M
84
+ super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
85
+
86
+ self.compression_rate = audio_compression_rate
87
+ self.qformer_compression_rate = audio_downsample_rate # In Phi4M, audio_downsample_rate is qformer_compression_rate
88
+ self.feat_stride = audio_feat_stride
89
+
90
+ self._eightk_method = kwargs.get("eightk_method", "fillzero") # 'fillzero' or 'resample'
91
+
92
+ # Using the provided create_mel_filterbank (Slaney scale)
93
+ # Parameters for Mel filterbank match Phi4M's speechlib_mel call
94
+ self._mel = create_mel_filterbank(
95
+ sampling_rate=16000, # Target sampling rate
96
+ n_fft=512, # n_fft for 16kHz audio in Phi4M
97
+ n_mels=80, # feature_size
98
+ fmin=0.0, # Phi4M's fmin is None, typically defaults to 0
99
+ fmax=7690.0 # Specific fmax from Phi4M
100
+ ).T
101
+ self._hamming400 = np.hamming(400) # for 16k audio, from Phi4M
102
+ self._hamming200 = np.hamming(200) # for 8k audio, from Phi4M
103
 
104
+ def __call__(
105
  self,
106
+ audios: List[Union[AudioInput, Tuple[np.ndarray, int]]], # More specific type hint
107
+ return_tensors: Optional[Union[str, TensorType]] = None,
108
  ):
109
+ returned_input_audio_embeds = []
110
+ returned_audio_embed_sizes = []
111
+ audio_frames_list = [] # Stores num_mel_frames * feat_stride for each audio item
112
+
113
+ for audio_input_item in audios:
114
+ if not isinstance(audio_input_item, tuple) or len(audio_input_item) != 2:
115
+ raise ValueError(
116
+ "Each item in 'audios' must be a tuple (waveform: np.ndarray, sample_rate: int)."
117
+ )
118
+ audio_data, sample_rate = audio_input_item
119
 
120
+ if isinstance(audio_data, list): # Convert list to ndarray
121
+ audio_data = np.array(audio_data, dtype=np.float32)
122
+ if not isinstance(audio_data, np.ndarray):
123
+ raise TypeError(f"Waveform data must be a numpy array, got {type(audio_data)}")
124
 
125
+ audio_embeds_np = self._extract_features(audio_data, sample_rate) # log_fbank
126
 
127
+ num_mel_frames = audio_embeds_np.shape[0]
128
+ current_audio_frames = num_mel_frames * self.feat_stride # Phi4M logic
129
 
130
+ audio_embed_size = self._compute_audio_embed_size(current_audio_frames)
131
 
132
+ returned_input_audio_embeds.append(torch.from_numpy(audio_embeds_np))
133
+ returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long())
134
+ audio_frames_list.append(current_audio_frames)
135
 
136
+ padded_input_audio_embeds = pad_sequence(
137
+ returned_input_audio_embeds, batch_first=True, padding_value=self.padding_value
138
+ )
139
+ stacked_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)
140
+
141
+ tensor_audio_frames_list = torch.tensor(audio_frames_list, dtype=torch.long)
142
+
143
+ max_audio_frames = 0
144
+ if len(audios) > 0 and tensor_audio_frames_list.numel() > 0:
145
+ max_audio_frames = tensor_audio_frames_list.max().item()
146
+
147
+ returned_audio_attention_mask = None
148
+ if max_audio_frames > 0: # Create mask only if there are frames
149
+ if len(audios) > 1:
150
+ returned_audio_attention_mask = torch.arange(0, max_audio_frames,
151
+ device=tensor_audio_frames_list.device).unsqueeze(
152
+ 0) < tensor_audio_frames_list.unsqueeze(1)
153
+ elif len(audios) == 1: # For batch size 1
154
+ returned_audio_attention_mask = torch.ones(1, max_audio_frames, dtype=torch.bool,
155
+ device=tensor_audio_frames_list.device)
156
+
157
+ data = {
158
+ "input_audio_embeds": padded_input_audio_embeds,
159
+ "audio_embed_sizes": stacked_audio_embed_sizes,
160
+ }
161
+ if returned_audio_attention_mask is not None:
162
+ data["audio_attention_mask"] = returned_audio_attention_mask
163
 
164
+ return BatchFeature(data=data, tensor_type=return_tensors)
165
 
166
+ def _extract_spectrogram(self, wav: np.ndarray, fs: int) -> np.ndarray:
167
+ if wav.ndim > 1:
168
+ wav = np.squeeze(wav)
169
+ if len(wav.shape) == 2: # stereo to mono
170
+ wav = wav.mean(axis=1).astype(np.float32) # Ensure float32 after mean
171
+
172
+ wav = wav.astype(np.float32) # Ensure wav is float32
173
+
174
+ # Phi4M Resampling logic
175
+ if fs > self.sampling_rate: # self.sampling_rate is 16000
176
+ wav = scipy.signal.resample_poly(wav, self.sampling_rate, fs)
177
+ fs = self.sampling_rate
178
+ elif 8000 < fs < self.sampling_rate:
179
+ wav = scipy.signal.resample_poly(wav, 8000, fs) # Resample to 8000 first
180
+ fs = 8000
181
+ elif fs < 8000 and fs > 0:
182
+ logger.warning(f"Sample rate {fs} is less than 8000Hz. Resampling to 8000Hz.")
183
+ wav = scipy.signal.resample_poly(wav, 8000, fs)
184
+ fs = 8000
185
+ elif fs <= 0:
186
+ raise RuntimeError(f"Unsupported sample rate {fs}")
187
+
188
+ if fs == 8000:
189
+ if self._eightk_method == "resample":
190
+ wav = scipy.signal.resample_poly(wav, self.sampling_rate, 8000) # Resample 8k to 16k
191
+ fs = self.sampling_rate
192
+ # If "fillzero", parameters for 8k will be used, and spectrum padded later.
193
+ elif fs != self.sampling_rate: # Should be 16000 if not 8000 and _eightk_method != "resample"
194
+ raise RuntimeError(
195
+ f"Audio sample rate {fs} not supported after initial processing. Expected {self.sampling_rate} or 8000.")
196
+
197
+ preemphasis_coeff = 0.97
198
+
199
+ if fs == 8000:
200
+ n_fft, win_length, hop_length, fft_window = 256, 200, 80, self._hamming200
201
+ elif fs == 16000:
202
+ n_fft, win_length, hop_length, fft_window = 512, 400, 160, self._hamming400
203
+ else:
204
+ raise RuntimeError(f"Inconsistent fs {fs} for parameter selection.")
205
 
206
+ if len(wav) < win_length:
207
+ wav = np.pad(wav, (0, win_length - len(wav)), 'constant', constant_values=(0.0,))
208
 
209
+ num_frames = (wav.shape[0] - win_length) // hop_length + 1
210
+ if num_frames <= 0:
211
+ return np.zeros((0, n_fft // 2 + 1), dtype=np.float32)
212
 
213
+ y_frames = np.array(
214
+ [wav[i * hop_length: i * hop_length + win_length] for i in range(num_frames)],
215
+ dtype=np.float32,
216
+ )
217
 
218
+ _y_frames_rolled = np.roll(y_frames, 1, axis=1)
219
+ _y_frames_rolled[:, 0] = _y_frames_rolled[:, 1] # Phi4M specific handling
220
+ y_frames_preemphasized = (y_frames - preemphasis_coeff * _y_frames_rolled) * 32768.0
 
221
 
222
+ S = np.fft.rfft(fft_window * y_frames_preemphasized, n=n_fft, axis=1).astype(np.complex64)
 
223
 
224
+ if fs == 8000 and self._eightk_method == "fillzero":
225
+ # Pad spectrum to match 16kHz feature dimension (n_fft=512 -> 257 bins)
226
+ # Current S has (256 // 2) + 1 = 129 bins
227
+ target_bins = (512 // 2) + 1
228
+ pad_width = target_bins - S.shape[1]
229
+ # Phi4M: S = np.concatenate((S[:, 0:-1], padarray), axis=1) # Nyquist bin gets set to zero
230
+ # This means take all but last bin from 8k spectrum, then pad.
231
+ S_core = S[:, :-1]
232
+ padarray = np.zeros((S_core.shape[0], target_bins - S_core.shape[1]), dtype=S.dtype)
233
+ S = np.concatenate((S_core, padarray), axis=1)
234
 
235
+ spec = np.abs(S).astype(np.float32)
236
+ spec = np.abs(S).astype(np.float32)
+ return spec
237
 
238
+ def _extract_features(self, wav: np.ndarray, fs: int) -> np.ndarray:
239
+ spec = self._extract_spectrogram(wav, fs)
240
+ if spec.shape[0] == 0:
241
+ return np.zeros((0, self.feature_size), dtype=np.float32)
242
 
243
+ spec_power = spec ** 2
244
+ fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None) # Clip at 1.0 before log (Phi4M)
245
+ log_fbank = np.log(fbank_power).astype(np.float32)
246
+ return log_fbank
247
 
248
+ def _compute_audio_embed_size(self, audio_frames: int) -> int:
249
+ # Phi4M's logic for compressed size
250
+ integer = audio_frames // self.compression_rate
251
+ remainder = audio_frames % self.compression_rate
252
+ result = integer if remainder == 0 else integer + 1
253
 
254
+ integer = result // self.qformer_compression_rate
255
+ remainder = result % self.qformer_compression_rate
256
+ result = integer if remainder == 0 else integer + 1
257
+ return result
 
258
 
 
259
 
260
+ # --- End of Refactored Audio Feature Extractor ---
 
 
261
 
262
 
263
  class Gemma3ImagesKwargs(ImagesKwargs):
 
290
  def __init__(
291
  self,
292
  image_processor=None,
293
+ audio_processor=None, # User can pass an instance of RefactoredGemma3... here
294
  tokenizer=None,
295
  chat_template=None,
296
  image_seq_length: int = 256,
 
313
  self.image_token = getattr(self.tokenizer, "image_token", "<image>")
314
  self.eoi_token = getattr(self.tokenizer, "eoi_token", "")
315
 
316
+ self.audio_token_str_from_user_code = "<audio_soft_token>" # Example
317
+ # Ensure this token is actually in the tokenizer vocab as a special token
318
  self.audio_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_token_str_from_user_code)
319
  if hasattr(self.tokenizer, "unk_token_id") and self.audio_token_id == self.tokenizer.unk_token_id:
320
  logger.warning(
 
330
  self.image_token = "<image>"
331
  self.eoi_token = ""
332
  self.audio_token_str_from_user_code = "<audio_soft_token>"
333
+ self.audio_token_id = -1 # Placeholder if tokenizer is missing
334
  self.full_image_sequence = ""
335
 
336
+ # These attributes are specific to Gemma3OmniProcessor for its internal _compute_audio_embed_size
337
+ self.prompt_audio_compression_rate = kwargs.pop("prompt_audio_compression_rate", DEFAULT_COMPRESSION_RATE)
338
+ self.prompt_audio_qformer_rate = kwargs.pop("prompt_audio_qformer_rate", DEFAULT_QFORMER_RATE)
339
+ # self.prompt_audio_feat_stride = kwargs.pop("prompt_audio_feat_stride", DEFAULT_FEAT_STRIDE) # Not used by its _compute_audio_embed_size
340
+
341
  self.audio_placeholder_token = kwargs.pop("audio_placeholder_token", "<|audio_placeholder|>")
342
 
343
  def _merge_kwargs(self, KwargsClassWithDefaults, tokenizer_init_kwargs, **kwargs_from_call):
 
352
  if modality_key_in_call in final_kwargs:
353
  if isinstance(modality_kwargs_in_call, dict):
354
  final_kwargs[modality_key_in_call].update(modality_kwargs_in_call)
355
+ elif isinstance(modality_kwargs_in_call, dict): # New modality not in defaults
356
  final_kwargs[modality_key_in_call] = modality_kwargs_in_call.copy()
357
 
358
+ if self.tokenizer: # Ensure tokenizer exists before accessing its attributes
359
  for modality_key in final_kwargs:
360
  modality_dict = final_kwargs[modality_key]
361
+ if isinstance(modality_dict, dict): # Check if it's a dictionary
362
+ for key_in_mod_dict in list(modality_dict.keys()): # Iterate over keys
363
  if key_in_mod_dict in tokenizer_init_kwargs:
364
  value = (
365
  getattr(self.tokenizer, key_in_mod_dict)
 
368
  )
369
  modality_dict[key_in_mod_dict] = value
370
 
371
+ if "text_kwargs" not in final_kwargs: final_kwargs["text_kwargs"] = {} # Ensure text_kwargs exists
 
372
  final_kwargs["text_kwargs"]["truncation"] = final_kwargs["text_kwargs"].get("truncation", False)
373
  final_kwargs["text_kwargs"]["max_length"] = final_kwargs["text_kwargs"].get("max_length", DEFAULT_MAX_LENGTH)
374
 
375
  return final_kwargs
376
 
377
  def _compute_audio_embed_size(self, audio_mel_frames: int) -> int:
378
+ # This method is part of Gemma3OmniProcessor.
379
+ # It calculates a number of soft tokens based on its own compression rates.
380
+ # Note: `audio_mel_frames` here is the number of raw Mel frames from the feature extractor's perspective
381
+ # if the attention mask sum is directly used before feat_stride scaling by the processor.
382
+ # However, if using the Refactored processor, audio_attention_mask.sum() will yield
383
+ # num_mel_frames * feat_stride. This method should then correctly compress that value.
384
+
385
+ # Using prompt_audio_compression_rate and prompt_audio_qformer_rate
386
+ # which are attributes of this Gemma3OmniProcessor class.
387
+
388
+ # First compression
389
+ # audio_mel_frames here should ideally be num_actual_mel_frames * feat_stride_of_the_audio_processor
390
+ # if trying to match the number of tokens from a Phi4M-style processor.
391
+ # The refactored audio processor does this scaling internally before its own _compute_audio_embed_size.
392
+ # If actual_mel_frames_per_sample (from sum of attention_mask) *is* already scaled by feat_stride
393
+ # (as it would be if using the refactored processor's attention_mask), then this calculation is correct.
394
+
395
+ integer = audio_mel_frames // self.prompt_audio_compression_rate
396
+ remainder = audio_mel_frames % self.prompt_audio_compression_rate
397
+ result = integer if remainder == 0 else integer + 1
398
+
399
+ # Second compression
400
+ integer = result // self.prompt_audio_qformer_rate
401
+ remainder = result % self.prompt_audio_qformer_rate
402
+ result = integer if remainder == 0 else integer + 1
403
+ return result
404
 
405
  def __call__(
406
  self,
407
  text: Union[str, List[str]] = None,
408
  images: Optional[Any] = None,
409
  audios: Optional[Union[AudioInput, List[AudioInput]]] = None,
410
+ sampling_rate: Optional[int] = None, # sampling_rate for raw audio arrays
411
  return_tensors: Optional[Union[str, TensorType]] = None,
412
  **kwargs: Any
413
  ) -> BatchFeature:
414
  if text is None and images is None and audios is None:
415
  raise ValueError("Provide at least one of `text`, `images`, or `audios`.")
416
 
417
+ final_rt = return_tensors # Store original return_tensors
418
+ # Properly merge kwargs for text, images, audio
419
  merged_call_kwargs = self._merge_kwargs(
420
+ Gemma3ProcessorKwargs, # The class defining _defaults
421
+ self.tokenizer.init_kwargs if hasattr(self.tokenizer, 'init_kwargs') else {}, # Tokenizer defaults
422
+ **kwargs # User-provided kwargs from the call
423
  )
424
 
425
+ # Determine final return_tensors, prioritizing call > text_kwargs > default
426
+ if final_rt is None: # If not specified in call
427
  final_rt = merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", TensorType.PYTORCH)
428
+ else: # If specified in call, remove from text_kwargs to avoid conflict
429
  merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", None)
430
 
431
+ if text is None: # If no text, create empty strings based on other inputs
432
  num_samples = 0
433
  if images is not None:
434
  _images_list = images if isinstance(images, list) and (
435
  not images or not isinstance(images[0], (int, float))) else [images]
436
  num_samples = len(_images_list)
437
  elif audios is not None:
438
+ _audios_list = audios if isinstance(audios, list) and not (
439
+ isinstance(audios[0], tuple) and isinstance(audios[0][0], (int, float))) else [
440
+ audios] # check if audios is list of items or list of (wave,sr)
441
  num_samples = len(_audios_list)
442
+ text = [""] * num_samples if num_samples > 0 else [""] # Default to one empty string if no inputs
443
 
444
+ if isinstance(text, str): text = [text] # Ensure text is a list
445
  if not (isinstance(text, list) and all(isinstance(t, str) for t in text)):
446
  raise ValueError("Input `text` must be a string or a list of strings.")
447
 
448
  image_features_dict = {}
449
  if images is not None:
450
  if self.image_processor is None: raise ValueError("Images provided but self.image_processor is None.")
451
+ # Ensure images are correctly batched
452
+ batched_images = make_nested_list_of_images(images) # handles various image input types
453
+
454
+ _img_kwargs = merged_call_kwargs.get("images_kwargs", {})
455
  _img_proc_output = self.image_processor(batched_images, return_tensors=None,
456
+ **_img_kwargs) # Pass None to handle tensors later
457
  image_features_dict = _img_proc_output.data if isinstance(_img_proc_output,
458
  BatchFeature) else _img_proc_output
459
 
460
+ if len(text) == 1 and text[0] == "" and len(
461
+ batched_images) > 0: # If text is default empty and images exist
462
+ text = [" ".join([self.boi_token] * len(img_batch)) for img_batch in batched_images]
463
+ elif len(batched_images) != len(text): # If text was provided, ensure consistency
464
+ raise ValueError(
465
+ f"Inconsistent batch: {len(batched_images)} image groups, {len(text)} texts. Ensure one text prompt per image group."
466
+ )
467
 
468
  num_crops_popped = image_features_dict.pop("num_crops", None)
469
  if num_crops_popped is not None:
470
  num_crops_all = to_py_obj(num_crops_popped)
471
  temp_text_img, current_crop_idx_offset = [], 0
472
  for batch_idx, (prompt, current_imgs_in_batch) in enumerate(zip(text, batched_images)):
473
+ crops_for_this_batch_sample = [] # Number of *additional* crops for each original image
474
+ if num_crops_all: # If num_crops_all is not None or empty
475
+ for _ in current_imgs_in_batch: # For each original image in the current batch sample
476
  if current_crop_idx_offset < len(num_crops_all):
477
+ # num_crops_all contains total items (original + crops) for each image
478
+ # We need number of *additional* crops. Assuming num_crops_all[i] >= 1
479
+ crops_for_this_batch_sample.append(max(0, num_crops_all[current_crop_idx_offset] - 1))
480
+ current_crop_idx_offset += 1
481
  else:
482
+ crops_for_this_batch_sample.append(0) # Should not happen if num_crops_all is correct
483
+
484
+ image_placeholders_in_prompt = [m.start() for m in re.finditer(re.escape(self.boi_token), prompt)]
485
  processed_prompt = prompt
486
+
487
+ # Iterate backwards to preserve indices for replacement
488
+ iter_count = min(len(crops_for_this_batch_sample), len(image_placeholders_in_prompt))
489
+ for i_placeholder_idx in range(iter_count - 1, -1, -1):
490
+ num_additional_crops_for_this_image = crops_for_this_batch_sample[i_placeholder_idx]
491
+ original_token_idx_in_prompt = image_placeholders_in_prompt[i_placeholder_idx]
492
+
493
+ if num_additional_crops_for_this_image > 0:
494
+ # Create replacement text: original image placeholder + placeholders for additional crops
495
+ replacement_text = self.boi_token + "".join(
496
+ [self.boi_token] * num_additional_crops_for_this_image)
497
+ # Replace the single original boi_token with the new sequence
498
+ processed_prompt = (
499
+ processed_prompt[:original_token_idx_in_prompt] +
500
+ replacement_text +
501
+ processed_prompt[original_token_idx_in_prompt + len(self.boi_token):]
502
+ )
503
  temp_text_img.append(processed_prompt)
504
  text = temp_text_img
505
+ # Replace all BOI tokens with the full image sequence (BOI + IMAGE*N + EOI)
506
+ # This step assumes that if additional crops were handled, self.boi_token still marks each image.
507
+ text = [p.replace(self.boi_token, self.full_image_sequence) for p in text]
508
 
509
  audio_features_dict = {}
510
  if audios is not None:
511
  if self.audio_processor is None: raise ValueError("Audios provided but self.audio_processor is None.")
512
+
513
  audio_call_kwargs = merged_call_kwargs.get("audio_kwargs", {})
514
+ # Pass sampling_rate from __call__ to audio_processor if provided (for raw arrays)
515
  if sampling_rate is not None: audio_call_kwargs["sampling_rate"] = sampling_rate
516
 
517
+ # The audio_processor (e.g., RefactoredGemma3...) will return its model_input_names
518
+ # e.g., {"input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"}
519
  _audio_proc_output = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
520
  audio_features_dict = _audio_proc_output.data
521
 
522
+ new_text_with_audio = []
523
+
524
+ # Determine the number of actual audio items processed by the audio_processor
525
+ # This should match len(text) if batching is consistent.
526
+ # The 'audio_attention_mask' or 'input_audio_embeds' can indicate this.
527
+ num_audio_samples_processed = audio_features_dict[self.audio_processor.model_input_names[0]].shape[0]
528
+
529
+ if num_audio_samples_processed != len(text):
530
+ raise ValueError(
531
+ f"Inconsistent batch for audio/text: {num_audio_samples_processed} audio samples processed, {len(text)} text prompts."
532
+ )
533
+
534
+ # If using Gemma3AudioFeatureExtractor,
535
+ # "audio_embed_sizes" is already computed correctly (num compressed tokens).
536
+ # The processor's own _compute_audio_embed_size is called to determine how many
537
+ # self.audio_token_str_from_user_code to insert. Ideally, this matches.
538
+
539
+ # Get the number of frames that the processor's _compute_audio_embed_size expects.
540
+ # If the audio_processor is RefactoredGemma3..., its attention_mask is over (num_mel_frames * feat_stride).
541
+ # So, sum of that mask gives the input for this processor's _compute_audio_embed_size.
542
+ frames_for_embed_size_calc = to_py_obj(audio_features_dict[self.audio_processor.model_input_names[2]].sum(
543
+ axis=-1)) # sum of audio_attention_mask
544
 
545
  for i, prompt in enumerate(text):
546
+ # num_soft_tokens should be the final number of audio tokens to insert in the text.
547
+ # This is calculated by the Gemma3OmniProcessor's own method.
548
+ num_soft_tokens = self._compute_audio_embed_size(frames_for_embed_size_calc[i])
549
+
550
  audio_token_sequence_str = self.audio_token_str_from_user_code * num_soft_tokens
551
 
552
  if self.audio_placeholder_token in prompt:
553
+ prompt = prompt.replace(self.audio_placeholder_token, audio_token_sequence_str,
554
+ 1) # Replace only first
555
  else:
556
+ prompt += audio_token_sequence_str # Append if no placeholder
557
  new_text_with_audio.append(prompt)
558
  text = new_text_with_audio
559
 
560
  text_tokenizer_kwargs = merged_call_kwargs.get("text_kwargs", {})
561
  text_features_dict = self.tokenizer(text=text, return_tensors=None,
562
+ **text_tokenizer_kwargs) # Pass None for tensors
563
 
564
+ # Create token_type_ids
565
  input_ids_list_of_lists = text_features_dict["input_ids"]
566
+ # Ensure it's a list of lists
567
  if not isinstance(input_ids_list_of_lists, list) or not (
568
  input_ids_list_of_lists and isinstance(input_ids_list_of_lists[0], list)):
569
  if isinstance(input_ids_list_of_lists, (torch.Tensor, np.ndarray)):
570
+ input_ids_list_of_lists = to_py_obj(input_ids_list_of_lists) # to nested python lists
571
  elif isinstance(input_ids_list_of_lists, list) and (
572
  not input_ids_list_of_lists or isinstance(input_ids_list_of_lists[0], int)):
573
+ input_ids_list_of_lists = [input_ids_list_of_lists] # wrap single list
574
 
575
  token_type_ids_list = []
576
  for ids_sample in input_ids_list_of_lists:
577
+ types = [0] * len(ids_sample) # 0 for text
578
  for j, token_id_val in enumerate(ids_sample):
579
  if self.image_token_id is not None and token_id_val == self.image_token_id:
580
+ types[j] = 1 # 1 for image
581
+ elif self.audio_token_id != -1 and token_id_val == self.audio_token_id: # Check if audio_token_id is valid
582
+ types[j] = 2 # 2 for audio
583
  token_type_ids_list.append(types)
584
  text_features_dict["token_type_ids"] = token_type_ids_list
585
 
 
587
  if image_features_dict: final_batch_data.update(image_features_dict)
588
  if audio_features_dict: final_batch_data.update(audio_features_dict)
589
 
590
+ # Convert all data to tensors if final_rt is specified
591
  return BatchFeature(data=final_batch_data, tensor_type=final_rt)
592
 
593
  def batch_decode(self, *args, **kwargs):
 
600
  def model_input_names(self) -> List[str]:
601
  input_names = set()
602
  if hasattr(self, 'tokenizer') and self.tokenizer is not None:
603
+ # Make sure model_input_names is a list/set before +
604
+ tokenizer_inputs = self.tokenizer.model_input_names
605
+ if isinstance(tokenizer_inputs, (list, set)):
606
+ input_names.update(tokenizer_inputs)
607
+ else: # Fallback if it's a single string
608
+ input_names.add(str(tokenizer_inputs))
609
+ input_names.add("token_type_ids")
610
 
611
  if hasattr(self, 'image_processor') and self.image_processor is not None:
612
+ # Similar check for image_processor
613
+ image_inputs = self.image_processor.model_input_names
614
+ if isinstance(image_inputs, (list, set)):
615
+ input_names.update(image_inputs)
616
+ else:
617
+ input_names.add(str(image_inputs))
618
+
619
+ if hasattr(self, 'audio_processor') and self.audio_processor is not None:
620
+ # Use model_input_names from the instantiated audio_processor
621
+ # This will correctly reflect the names from RefactoredGemma3... if it's used.
622
+ audio_inputs = self.audio_processor.model_input_names
623
+ if isinstance(audio_inputs, (list, set)):
624
+ input_names.update(audio_inputs)
625
+ else:
626
+ input_names.add(str(audio_inputs))
627
 
628
  return list(input_names)
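
A minimal usage sketch of the refactored audio path (not part of the commit), assuming the file is importable as processing_gemma3_omni and using illustrative rate values; per the new __call__, each audio item is a (waveform, sample_rate) tuple and the output carries input_audio_embeds, audio_embed_sizes and audio_attention_mask:

import numpy as np
from processing_gemma3_omni import Gemma3AudioFeatureExtractor

extractor = Gemma3AudioFeatureExtractor(
    audio_compression_rate=8,   # illustrative assumption, not fixed by this commit
    audio_downsample_rate=1,    # illustrative assumption
    audio_feat_stride=1,        # illustrative assumption
)

wave = np.zeros(16000, dtype=np.float32)               # 1 s of 16 kHz audio
features = extractor(audios=[(wave, 16000)], return_tensors="pt")

print(features["input_audio_embeds"].shape)    # (batch, mel_frames, 80)
print(features["audio_embed_sizes"])           # soft-token count per audio item
print(features["audio_attention_mask"].shape)  # (batch, mel_frames * feat_stride)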