voidful committed
Commit 0173a9f · verified · 1 parent: 9faac02

Update processing_gemma3_omni.py

Files changed (1)
  1. processing_gemma3_omni.py +136 -138
processing_gemma3_omni.py CHANGED
@@ -7,11 +7,12 @@ import scipy.signal
7
  import torch
8
  from torch.nn.utils.rnn import pad_sequence
9
  # Using the original AudioInput for minimal change from your provided code
10
- from transformers.audio_utils import AudioInput # type: ignore
11
  from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
12
  from transformers.feature_extraction_utils import BatchFeature
13
  from transformers.image_utils import make_nested_list_of_images
14
- from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs # Removed Unpack as it's not standard
15
  from transformers.utils import TensorType, to_py_obj, logging
16
 
17
  # Constants
@@ -26,7 +27,7 @@ DEFAULT_FEAT_STRIDE = 4
26
  IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
27
  AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
28
  DEFAULT_MAX_LENGTH = 16384
29
- LOG_MEL_CLIP_EPSILON = 1e-5 # Epsilon for log mel clipping
30
 
31
  logger = logging.get_logger(__name__)
32
 
@@ -34,19 +35,18 @@ logger = logging.get_logger(__name__)
34
  def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
35
  fmax: Optional[float] = None) -> np.ndarray:
36
  """Create Mel filterbank for audio processing. (User's version)"""
37
- fmax = fmax or sampling_rate / 2.0 # Ensure float division
38
 
39
  # User's Mel scale formula
40
  def hz_to_mel(f: float) -> float:
41
  return 1127.0 * math.log(1 + f / 700.0)
42
-
43
- def mel_to_hz(mel: float) -> float: # Added for completeness if needed
44
- return 700.0 * (math.exp(mel / 1127.0) - 1)
45
 
46
 
47
  mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
48
  # freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1) # Original
49
- freq_points = mel_to_hz(mel_points) # Using the inverse function
50
 
51
  # Clip freq_points to be within [0, sampling_rate/2]
52
  freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
@@ -55,12 +55,11 @@ def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: flo
55
  # Ensure bins are within valid range for rfft output indices
56
  bins = np.clip(bins, 0, n_fft // 2)
57
 
58
-
59
  filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
60
- for m_idx in range(n_mels): # Loop from 0 to n_mels-1 to fill filterbank[m_idx]
61
  # Bins for (m_idx)-th filter are bins[m_idx], bins[m_idx+1], bins[m_idx+2]
62
  left, center, right = bins[m_idx], bins[m_idx + 1], bins[m_idx + 2]
63
-
64
  # Original logic for applying triangular filter
65
  # Ensure no division by zero if points coincide
66
  if center > left:
@@ -69,9 +68,8 @@ def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: flo
69
  filterbank[m_idx, center:right] = (right - np.arange(center, right)) / (right - center)
70
  # If left=center or center=right, the corresponding slope is zero, which is implicitly handled.
71
  # Ensure peak is 1.0 if center is a valid point within a slope.
72
- if left <= center < right and center > left : # If center forms a peak of a valid triangle part
73
- filterbank[m_idx, center] = 1.0
74
-
75
 
76
  return filterbank
77
 
@@ -84,14 +82,14 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
84
  compression_rate: int = DEFAULT_COMPRESSION_RATE,
85
  qformer_rate: int = DEFAULT_QFORMER_RATE,
86
  feat_stride: int = DEFAULT_FEAT_STRIDE,
87
- sampling_rate: int = DEFAULT_SAMPLING_RATE, # Target sampling rate
88
  n_fft: int = DEFAULT_N_FFT,
89
  win_length: Optional[int] = None,
90
  hop_length: Optional[int] = None,
91
  n_mels: int = DEFAULT_N_MELS,
92
- f_min: float = 0.0, # Added for mel filterbank control
93
- f_max: Optional[float] = None, # Added for mel filterbank control
94
- padding_value: float = 0.0, # Explicitly define for clarity
95
  **kwargs
96
  ):
97
  _win_length = win_length if win_length is not None else n_fft
@@ -100,7 +98,7 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
100
  # feature_size is n_mels for the superclass
101
  super().__init__(
102
  feature_size=n_mels,
103
- sampling_rate=sampling_rate, # This sets self.sampling_rate
104
  padding_value=padding_value,
105
  **kwargs
106
  )
@@ -115,32 +113,32 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
115
  self.hop_length = _hop_length
116
  self.n_mels = n_mels
117
  self.f_min = f_min
118
- self.f_max = f_max # Will be sampling_rate/2 if None in create_mel_filterbank call
119
 
120
  if self.win_length > self.n_fft:
121
  logger.warning(
122
  f"win_length ({self.win_length}) is greater than n_fft ({self.n_fft}). "
123
  "Window will be applied, then data will be zero-padded/truncated to n_fft by np.fft.rfft."
124
  )
125
- self.window = np.hamming(self.win_length).astype(np.float32) # Or scipy.signal.get_window("hann", self.win_length)
126
  self.mel_filterbank = create_mel_filterbank(
127
  self.sampling_rate, self.n_fft, self.n_mels, fmin=self.f_min, fmax=self.f_max
128
- ).T # Transpose for dot product: (n_fft // 2 + 1, n_mels)
129
-
130
 
131
  def __call__(
132
  self,
133
- audios: Union[AudioInput, List[AudioInput]], # Accept single or list
134
- sampling_rate: Optional[int] = None, # To specify SR if audios are raw arrays
135
  return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
136
  ) -> BatchFeature:
137
-
138
  if not isinstance(audios, list):
139
  audios = [audios]
140
 
141
  processed_mels: List[torch.Tensor] = []
142
  actual_mel_lengths: List[int] = []
143
-
144
  # Kept from user's code - their purpose might be for token calculation downstream
145
  sizes_for_embed_length: List[torch.Tensor] = []
146
  frames_scaled_by_feat_stride: List[int] = []
@@ -151,7 +149,7 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
151
 
152
  if isinstance(audio_item, tuple) and len(audio_item) == 2 and isinstance(audio_item[1], int):
153
  current_wav, source_sr = audio_item
154
- current_wav = np.asarray(current_wav, dtype=np.float32) # Ensure float32 numpy array
155
  elif isinstance(audio_item, (np.ndarray, list)):
156
  current_wav = np.asarray(audio_item, dtype=np.float32)
157
  if sampling_rate is None:
@@ -170,13 +168,13 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
170
  f"Unsupported audio input type: {type(audio_item)}. "
171
  "Expected np.ndarray, list of floats, or Tuple[np.ndarray, int]."
172
  )
173
-
174
  processed_wav_array = self._preprocess_audio(current_wav, source_sr)
175
- mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav_array) # Shape: (T_mel, N_Mels)
176
-
177
- feature_tensor = torch.from_numpy(mel_spectrogram) # Already float32
178
  processed_mels.append(feature_tensor)
179
- actual_mel_lengths.append(feature_tensor.shape[0]) # T_mel for this item
180
 
181
  # User's original logic for 'sizes' and 'frames'
182
  sizes_for_embed_length.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
@@ -188,16 +186,16 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
188
 
189
  # Create attention mask corresponding to the actual lengths of mel spectrograms
190
  max_t_mel_in_batch = audio_embeds.shape[1]
191
- current_device = audio_embeds.device # Get device from padded tensor if using PyTorch tensors earlier
192
-
193
  # Create attention mask directly based on actual_mel_lengths
194
  attention_mask = torch.zeros(len(audios), max_t_mel_in_batch, dtype=torch.bool, device=current_device)
195
  for i, length in enumerate(actual_mel_lengths):
196
  attention_mask[i, :length] = True
197
-
198
  output_data = {
199
  "audio_values": audio_embeds,
200
- "audio_attention_mask": attention_mask # Correctly shaped mask for audio_values
201
  }
202
 
203
  # Include user's 'sizes' if they are needed downstream
@@ -211,7 +209,7 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
211
  # Ensure wav is float32
212
  if wav.dtype not in [np.float32, np.float64]:
213
  if np.issubdtype(wav.dtype, np.integer):
214
- max_val = np.iinfo(wav.dtype).max if wav.size > 0 else 1.0 # Avoid error on empty array
215
  wav = wav.astype(np.float32) / max_val
216
  else:
217
  wav = wav.astype(np.float32)
@@ -219,20 +217,20 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
219
  wav = wav.astype(np.float32)
220
 
221
  if wav.ndim > 1:
222
- wav = wav.mean(axis=0) # Convert to mono
223
-
224
  if source_sr != self.sampling_rate:
225
  logger.info(f"Resampling audio from {source_sr} Hz to {self.sampling_rate} Hz.")
226
  # Calculate integer up/down factors for resample_poly
227
  common_divisor = math.gcd(self.sampling_rate, source_sr)
228
  up_factor = self.sampling_rate // common_divisor
229
  down_factor = source_sr // common_divisor
230
- if up_factor != down_factor : # Only if actual resampling is needed
231
  wav = scipy.signal.resample_poly(wav, up=up_factor, down=down_factor)
232
-
233
  # Normalize amplitude to roughly [-1, 1]
234
  max_abs_val = np.abs(wav).max()
235
- if max_abs_val > 1e-7: # Avoid division by zero or tiny numbers
236
  wav = wav / max_abs_val
237
  return wav
238
 
@@ -245,10 +243,10 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
245
  # Calculate number of frames
246
  # This calculation ensures at least one frame if len(wav) == self.win_length
247
  if len(wav) >= self.win_length:
248
- num_frames = 1 + (len(wav) - self.win_length) // self.hop_length
249
- else: # Should be covered by padding, but as safeguard
250
- num_frames = 0
251
-
252
  if num_frames <= 0:
253
  logger.warning(f"Audio is too short (length {len(wav)}) to produce any frames "
254
  f"with win_length {self.win_length} and hop_length {self.hop_length}. "
@@ -263,21 +261,21 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
263
  strides=(strides * self.hop_length, strides),
264
  writeable=False
265
  )
266
- frames_data = frames_view.copy() # Important: copy after as_strided if modifying
267
-
268
- frames_data *= self.window # Apply window in-place on the copy
269
 
270
  # Compute STFT (rfft for real inputs)
271
  # n_fft determines zero-padding or truncation for FFT input from each frame
272
  spectrum = np.fft.rfft(frames_data, n=self.n_fft, axis=-1).astype(np.complex64)
273
- power = np.abs(spectrum)**2
274
-
275
- mel_spectrogram = np.dot(power, self.mel_filterbank) # (num_frames, n_mels)
276
-
277
  # Clip and take log
278
- mel_spectrogram = np.clip(mel_spectrogram, LOG_MEL_CLIP_EPSILON, None) # Use defined epsilon
279
  log_mel_spectrogram = np.log(mel_spectrogram)
280
-
281
  return log_mel_spectrogram.astype(np.float32)
282
 
283
  def _calculate_embed_length(self, frame_count: int) -> int:
@@ -286,7 +284,7 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
286
  return math.ceil(compressed / self.qformer_rate)
287
 
288
 
289
- class Gemma3ImagesKwargs(ImagesKwargs): # User's definition
290
  do_pan_and_scan: Optional[bool]
291
  pan_and_scan_min_crop_size: Optional[int]
292
  pan_and_scan_max_num_crops: Optional[int]
@@ -294,11 +292,11 @@ class Gemma3ImagesKwargs(ImagesKwargs): # User's definition
294
  do_convert_rgb: Optional[bool]
295
 
296
 
297
- class Gemma3ProcessorKwargs(ProcessingKwargs, total=False): # User's definition
298
  images_kwargs: Dict[str, Any]
299
  audio_kwargs: Dict[str, Any]
300
  # Added text_kwargs as it's commonly part of such structures
301
- text_kwargs: Optional[Dict[str, Any]] = None
302
  _defaults = {
303
  "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
304
  "images_kwargs": {},
@@ -308,30 +306,30 @@ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False): # User's definition
308
 
309
  class Gemma3OmniProcessor(ProcessorMixin):
310
  attributes = ["image_processor", "audio_processor", "tokenizer"]
311
- valid_kwargs = ["chat_template", "image_seq_length"] # From user's code
312
 
313
  # --- FIXED CLASS ATTRIBUTES ---
314
- image_processor_class = "AutoImageProcessor" # As in user's original code
315
- audio_processor_class = Gemma3AudioFeatureExtractor # Corrected to custom class
316
- tokenizer_class = "AutoTokenizer" # As in user's original code
317
 
318
  def __init__(
319
  self,
320
- image_processor=None, # Allow None, superclass or from_pretrained handles loading via _class
321
- audio_processor=None, # Allow None or instance
322
- tokenizer=None, # Allow None or instance
323
  chat_template=None,
324
  image_seq_length: int = 256,
325
- **kwargs
326
  ):
327
  # The ProcessorMixin's __init__ will handle instantiating these if they are None,
328
  # using the respective *_class attributes.
329
  # If specific instances are passed, they will be used.
330
-
331
  # Retaining user's specific logic for setting attributes if needed,
332
  # though much of this might be handled by super() or better placed after super()
333
  self.image_seq_length = image_seq_length
334
-
335
  # These tokenizer-dependent attributes should be set *after* super().__init__
336
  # ensures self.tokenizer is populated, or if tokenizer is passed directly.
337
  # If tokenizer is None and loaded by super(), these need to be set post-super().
@@ -340,53 +338,53 @@ class Gemma3OmniProcessor(ProcessorMixin):
340
  # This is a basic placeholder; HF's from_pretrained mechanism is more robust for loading
341
  # For now, we'll assume if tokenizer is None, super() handles it or it's an error later.
342
  pass
343
- else: # Tokenizer was provided
344
- self.image_token_id = getattr(tokenizer, "image_token_id", None) # More robust with getattr
345
- self.boi_token = getattr(tokenizer, "boi_token", "<|image|>") # Defaulting if not present
346
  self.image_token = getattr(tokenizer, "image_token", "<|image|>")
347
- self.eoi_token = getattr(tokenizer, "eoi_token", "") # Added eoi_token as it was used
348
 
349
- self.audio_token = "<audio_soft_token>" # User's definition
350
  # self.expected_audio_token_id = 262143 # User's reference
351
  # The existence of this token should be ensured when the tokenizer is prepared/saved.
352
- self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
353
  # if self.audio_token_id != self.expected_audio_token_id: # User's warning
354
  # logger.warning(...)
355
  if self.audio_token_id == tokenizer.unk_token_id:
356
- logger.warning(f"Audio token '{self.audio_token}' not found in tokenizer, maps to UNK. Ensure it's added.")
357
-
358
 
359
  self.full_image_sequence = f"\n\n{self.boi_token}{''.join([self.image_token] * image_seq_length)}{self.eoi_token if hasattr(tokenizer, 'eoi_token') else ''}\n\n"
360
 
361
-
362
  # These seem specific to this processor's logic for determining audio token sequence length
363
  # It's better to initialize them here.
364
  self.audio_prompt_compression_rate = kwargs.pop("audio_prompt_compression_rate", 8)
365
  self.audio_prompt_qformer_rate = kwargs.pop("audio_prompt_qformer_rate", 1)
366
  self.audio_prompt_feat_stride = kwargs.pop("audio_prompt_feat_stride", 1)
367
 
368
-
369
  super().__init__(
370
  image_processor=image_processor,
371
  audio_processor=audio_processor,
372
  tokenizer=tokenizer,
373
  chat_template=chat_template,
374
- **kwargs # Pass remaining kwargs to super
375
  )
376
-
377
  # If tokenizer was loaded by super(), set tokenizer-dependent attributes now
378
  if not hasattr(self, 'image_token_id') and self.tokenizer is not None:
379
- self.image_token_id = getattr(self.tokenizer, "image_token_id", self.tokenizer.unk_token_id if hasattr(self.tokenizer, "unk_token_id") else None)
380
  self.boi_token = getattr(self.tokenizer, "boi_token", "<|image|>")
381
  self.image_token = getattr(self.tokenizer, "image_token", "<|image|>")
382
  self.eoi_token = getattr(self.tokenizer, "eoi_token", "")
383
  self.audio_token = "<audio_soft_token>"
384
  self.audio_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_token)
385
  if self.audio_token_id == self.tokenizer.unk_token_id:
386
- logger.warning(f"Audio token '{self.audio_token}' not found in tokenizer (post-super), maps to UNK. Ensure it's added.")
 
387
  self.full_image_sequence = f"\n\n{self.boi_token}{''.join([self.image_token] * self.image_seq_length)}{self.eoi_token}\n\n"
388
 
389
-
390
  def _merge_kwargs(self, ModelProcessorKwargs, tokenizer_init_kwargs, **kwargs_from_call):
391
  # User's original _merge_kwargs logic
392
  default_kwargs = {}
@@ -400,17 +398,16 @@ class Gemma3OmniProcessor(ProcessorMixin):
400
 
401
  for modality_key_in_call, modality_kwargs_in_call in kwargs_from_call.items():
402
  if modality_key_in_call in default_kwargs:
403
- if isinstance(modality_kwargs_in_call, dict):
404
  default_kwargs[modality_key_in_call].update(modality_kwargs_in_call)
405
- elif isinstance(modality_kwargs_in_call, dict): # New modality not in defaults
406
- default_kwargs[modality_key_in_call] = modality_kwargs_in_call.copy()
407
-
408
 
409
  # Update defaults with tokenizer init kwargs (original logic)
410
- for modality_key in default_kwargs: # Iterate over current keys in default_kwargs
411
  modality_dict = default_kwargs[modality_key]
412
- if isinstance(modality_dict, dict): # Ensure it's a dict before trying to access keys
413
- for key_in_mod_dict in list(modality_dict.keys()): # Iterate over copy of keys
414
  if key_in_mod_dict in tokenizer_init_kwargs:
415
  value = (
416
  getattr(self.tokenizer, key_in_mod_dict)
@@ -418,13 +415,13 @@ class Gemma3OmniProcessor(ProcessorMixin):
418
  else tokenizer_init_kwargs[key_in_mod_dict]
419
  )
420
  modality_dict[key_in_mod_dict] = value
421
-
422
  # Ensure text_kwargs processing (original logic)
423
- if "text_kwargs" not in default_kwargs: # Ensure text_kwargs exists
424
  default_kwargs["text_kwargs"] = {}
425
  default_kwargs["text_kwargs"]["truncation"] = default_kwargs["text_kwargs"].get("truncation", False)
426
- default_kwargs["text_kwargs"]["max_length"] = default_kwargs["text_kwargs"].get("max_length", DEFAULT_MAX_LENGTH)
427
-
428
 
429
  return default_kwargs
430
 
@@ -436,14 +433,14 @@ class Gemma3OmniProcessor(ProcessorMixin):
436
  def __call__(
437
  self,
438
  images=None,
439
- text:Union[str, List[str]]=None, # text is optional but often primary
440
  # videos=None, # Removed 'videos' as it's not handled
441
  audios: Optional[Union[AudioInput, List[AudioInput]]] = None,
442
- sampling_rate: Optional[int] = None, # For audio_processor if audios are raw arrays
443
  return_tensors: Optional[Union[str, TensorType]] = None,
444
- **kwargs: Any # Replaced Unpack for broader compatibility here
445
  ) -> BatchFeature:
446
- if text is None and images is None and audios is None: # Added audios to check
447
  raise ValueError("Provide at least one of `text`, `images`, or `audios`.")
448
 
449
  # Determine final return_tensors strategy
@@ -452,11 +449,11 @@ class Gemma3OmniProcessor(ProcessorMixin):
452
  # This call to _merge_kwargs primarily populates kwargs for each modality if passed in __call__
453
  # e.g. if user calls proc(..., text_kwargs={...})
454
  merged_call_kwargs = self._merge_kwargs(
455
- Gemma3ProcessorKwargs,
456
  self.tokenizer.init_kwargs if hasattr(self.tokenizer, "init_kwargs") else {},
457
  **kwargs
458
  )
459
-
460
  # If return_tensors wasn't passed to __call__, try to get it from merged text_kwargs
461
  # and remove it from there to avoid passing it twice to tokenizer.
462
  # Default to PYTORCH if still None.
@@ -465,17 +462,17 @@ class Gemma3OmniProcessor(ProcessorMixin):
465
  else:
466
  merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", None)
467
 
468
-
469
  # Standardize text input
470
- if text is None: # If no text given, create dummy text based on other modalities
471
  num_samples = 0
472
  if images is not None:
473
- _images_list = images if isinstance(images, list) and (not images or not isinstance(images[0], (int,float))) else [images]
474
  num_samples = len(_images_list)
475
  elif audios is not None:
476
  _audios_list = audios if isinstance(audios, list) else [audios]
477
  num_samples = len(_audios_list)
478
- text = [""] * num_samples if num_samples > 0 else [""] # Fallback for safety
479
 
480
  if isinstance(text, str):
481
  text = [text]
@@ -485,19 +482,20 @@ class Gemma3OmniProcessor(ProcessorMixin):
485
  # --- Image Processing ---
486
  image_features_dict = {}
487
  if images is not None and self.image_processor is not None:
488
- batched_images = make_nested_list_of_images(images) # HF utility
489
  # Assuming image_processor returns a dict or BatchFeature. If BatchFeature, get .data
490
- _img_proc_output = self.image_processor(batched_images, return_tensors=None, **merged_call_kwargs.get("images_kwargs", {}))
491
- image_features_dict = _img_proc_output.data if isinstance(_img_proc_output, BatchFeature) else _img_proc_output
492
-
 
493
 
494
- if len(batched_images) != len(text): # Validate batch consistency
495
  raise ValueError(f"Inconsistent batch sizes: {len(batched_images)} images, {len(text)} texts")
496
 
497
  # User's original image token replacement logic (complex, depends on num_crops etc from image_processor output)
498
  # This part needs to be carefully adapted based on actual image_processor output structure
499
  # For now, a simplified placeholder for the concept:
500
- if "num_crops" in image_features_dict: # Example check
501
  num_crops_list = to_py_obj(image_features_dict.pop("num_crops"))
502
  # ... user's original logic for text modification with self.full_image_sequence ...
503
  # This was: text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
@@ -505,8 +503,8 @@ class Gemma3OmniProcessor(ProcessorMixin):
505
  # For simplicity, assuming one image sequence per text for now if an image is present.
506
  temp_text = []
507
  for i, prompt in enumerate(text):
508
- if i < len(batched_images): # if this text sample has corresponding images
509
- # Replace first boi_token or append if not found
510
  if self.boi_token in prompt:
511
  temp_text.append(prompt.replace(self.boi_token, self.full_image_sequence, 1))
512
  else:
@@ -515,14 +513,13 @@ class Gemma3OmniProcessor(ProcessorMixin):
515
  temp_text.append(prompt)
516
  text = temp_text
517
 
518
-
519
  # --- Audio Processing ---
520
  audio_features_dict = {}
521
  if audios is not None and self.audio_processor is not None:
522
  audio_call_kwargs = merged_call_kwargs.get("audio_kwargs", {})
523
  if sampling_rate is not None:
524
- audio_call_kwargs["sampling_rate"] = sampling_rate
525
-
526
  # audio_processor.__call__ returns BatchFeature, get its .data attribute for the dict
527
  _audio_proc_output = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
528
  audio_features_dict = _audio_proc_output.data
@@ -533,21 +530,22 @@ class Gemma3OmniProcessor(ProcessorMixin):
533
  actual_mel_frames_per_sample = to_py_obj(audio_features_dict["audio_attention_mask"].sum(axis=-1))
534
 
535
  if len(actual_mel_frames_per_sample) != len(text):
536
- raise ValueError(f"Inconsistent batch sizes for audio and text: {len(actual_mel_frames_per_sample)} audio samples, {len(text)} texts.")
 
537
 
538
  for i, prompt in enumerate(text):
539
  num_soft_tokens = self._compute_audio_embed_size(actual_mel_frames_per_sample[i])
540
- audio_token_sequence_str = self.audio_soft_token_str * num_soft_tokens # Repeat soft token string
541
-
542
  # Replace a placeholder or append
543
- placeholder = getattr(self, "audio_placeholder_token", "<|audio|>") # Use defined placeholder
544
  if placeholder in prompt:
545
  prompt_with_audio = prompt.replace(placeholder, audio_token_sequence_str, 1)
546
- else:
547
- prompt_with_audio = prompt + audio_token_sequence_str
548
  new_text_with_audio_tokens.append(prompt_with_audio)
549
  text = new_text_with_audio_tokens
550
-
551
  # --- Text Tokenization ---
552
  text_tokenizer_kwargs = merged_call_kwargs.get("text_kwargs", {})
553
  # Tokenize the (potentially modified) text, request lists/np arrays
@@ -562,28 +560,28 @@ class Gemma3OmniProcessor(ProcessorMixin):
562
  if isinstance(input_ids_list_of_lists, (torch.Tensor, np.ndarray)):
563
  input_ids_list_of_lists = to_py_obj(input_ids_list_of_lists)
564
  elif isinstance(input_ids_list_of_lists, list) and \
565
- (not input_ids_list_of_lists or isinstance(input_ids_list_of_lists[0], int)):
566
- input_ids_list_of_lists = [input_ids_list_of_lists] # Batch of 1
567
 
568
  mm_token_type_ids_list = []
569
  for ids_sample in input_ids_list_of_lists:
570
- type_ids_sample = [0] * len(ids_sample) # Default type 0 (text)
571
  for idx, token_id_val in enumerate(ids_sample):
572
  if self.image_token_id is not None and token_id_val == self.image_token_id:
573
- type_ids_sample[idx] = 1 # Image token type
574
- elif token_id_val == self.audio_token_id: # Compare with ID of <audio_soft_token>
575
- type_ids_sample[idx] = 2 # Audio token type
576
  mm_token_type_ids_list.append(type_ids_sample)
577
  text_features_dict["token_type_ids"] = mm_token_type_ids_list
578
-
579
  # Combine all features
580
  final_batch_data = {**text_features_dict}
581
- if image_features_dict:
582
  final_batch_data.update(image_features_dict)
583
- if audio_features_dict:
584
  final_batch_data.update(audio_features_dict)
585
-
586
- return BatchFeature(data=final_batch_data, tensor_type=final_rt) # Use determined final_rt
587
 
588
  def batch_decode(self, *args, **kwargs):
589
  return self.tokenizer.batch_decode(*args, **kwargs)
@@ -595,11 +593,11 @@ class Gemma3OmniProcessor(ProcessorMixin):
595
  def model_input_names(self):
596
  tokenizer_inputs = self.tokenizer.model_input_names + ["token_type_ids"]
597
  image_processor_inputs = []
598
- if self.image_processor is not None: # Check if image_processor exists
599
- image_processor_inputs = self.image_processor.model_input_names
600
-
601
  audio_processor_inputs = []
602
- if self.audio_processor is not None: # Check if audio_processor exists
603
  # These are the keys Gemma3AudioFeatureExtractor puts in its output BatchFeature.data
604
  audio_processor_inputs = ["audio_values", "audio_attention_mask"]
605
  # "audio_values_sizes" was in user's original Gemma3AudioFeatureExtractor output,
 
7
  import torch
8
  from torch.nn.utils.rnn import pad_sequence
9
  # Using the original AudioInput for minimal change from your provided code
10
+ from transformers.audio_utils import AudioInput # type: ignore
11
  from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
12
  from transformers.feature_extraction_utils import BatchFeature
13
  from transformers.image_utils import make_nested_list_of_images
14
+ from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, \
15
+ ImagesKwargs # Removed Unpack as it's not standard
16
  from transformers.utils import TensorType, to_py_obj, logging
17
 
18
  # Constants
 
27
  IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
28
  AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
29
  DEFAULT_MAX_LENGTH = 16384
30
+ LOG_MEL_CLIP_EPSILON = 1e-5 # Epsilon for log mel clipping
31
 
32
  logger = logging.get_logger(__name__)
33
 
 
35
  def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
36
  fmax: Optional[float] = None) -> np.ndarray:
37
  """Create Mel filterbank for audio processing. (User's version)"""
38
+ fmax = fmax or sampling_rate / 2.0 # Ensure float division
39
 
40
  # User's Mel scale formula
41
  def hz_to_mel(f: float) -> float:
42
  return 1127.0 * math.log(1 + f / 700.0)
43
 
44
+ def mel_to_hz(mel: float) -> float: # Added for completeness if needed
45
+ return 700.0 * (math.exp(mel / 1127.0) - 1)
46
 
47
  mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
48
  # freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1) # Original
49
+ freq_points = mel_to_hz(mel_points) # Using the inverse function
50
 
51
  # Clip freq_points to be within [0, sampling_rate/2]
52
  freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
 
55
  # Ensure bins are within valid range for rfft output indices
56
  bins = np.clip(bins, 0, n_fft // 2)
57
 
 
58
  filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
59
+ for m_idx in range(n_mels): # Loop from 0 to n_mels-1 to fill filterbank[m_idx]
60
  # Bins for (m_idx)-th filter are bins[m_idx], bins[m_idx+1], bins[m_idx+2]
61
  left, center, right = bins[m_idx], bins[m_idx + 1], bins[m_idx + 2]
62
+
63
  # Original logic for applying triangular filter
64
  # Ensure no division by zero if points coincide
65
  if center > left:
 
68
  filterbank[m_idx, center:right] = (right - np.arange(center, right)) / (right - center)
69
  # If left=center or center=right, the corresponding slope is zero, which is implicitly handled.
70
  # Ensure peak is 1.0 if center is a valid point within a slope.
71
+ if left <= center < right and center > left: # If center forms a peak of a valid triangle part
72
+ filterbank[m_idx, center] = 1.0
73
 
74
  return filterbank
75
 
 
82
  compression_rate: int = DEFAULT_COMPRESSION_RATE,
83
  qformer_rate: int = DEFAULT_QFORMER_RATE,
84
  feat_stride: int = DEFAULT_FEAT_STRIDE,
85
+ sampling_rate: int = DEFAULT_SAMPLING_RATE, # Target sampling rate
86
  n_fft: int = DEFAULT_N_FFT,
87
  win_length: Optional[int] = None,
88
  hop_length: Optional[int] = None,
89
  n_mels: int = DEFAULT_N_MELS,
90
+ f_min: float = 0.0, # Added for mel filterbank control
91
+ f_max: Optional[float] = None, # Added for mel filterbank control
92
+ padding_value: float = 0.0, # Explicitly define for clarity
93
  **kwargs
94
  ):
95
  _win_length = win_length if win_length is not None else n_fft
 
98
  # feature_size is n_mels for the superclass
99
  super().__init__(
100
  feature_size=n_mels,
101
+ sampling_rate=sampling_rate, # This sets self.sampling_rate
102
  padding_value=padding_value,
103
  **kwargs
104
  )
 
113
  self.hop_length = _hop_length
114
  self.n_mels = n_mels
115
  self.f_min = f_min
116
+ self.f_max = f_max # Will be sampling_rate/2 if None in create_mel_filterbank call
117
 
118
  if self.win_length > self.n_fft:
119
  logger.warning(
120
  f"win_length ({self.win_length}) is greater than n_fft ({self.n_fft}). "
121
  "Window will be applied, then data will be zero-padded/truncated to n_fft by np.fft.rfft."
122
  )
123
+ self.window = np.hamming(self.win_length).astype(
124
+ np.float32) # Or scipy.signal.get_window("hann", self.win_length)
125
  self.mel_filterbank = create_mel_filterbank(
126
  self.sampling_rate, self.n_fft, self.n_mels, fmin=self.f_min, fmax=self.f_max
127
+ ).T # Transpose for dot product: (n_fft // 2 + 1, n_mels)
128
 
129
  def __call__(
130
  self,
131
+ audios: Union[AudioInput, List[AudioInput]], # Accept single or list
132
+ sampling_rate: Optional[int] = None, # To specify SR if audios are raw arrays
133
  return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
134
  ) -> BatchFeature:
135
+
136
  if not isinstance(audios, list):
137
  audios = [audios]
138
 
139
  processed_mels: List[torch.Tensor] = []
140
  actual_mel_lengths: List[int] = []
141
+
142
  # Kept from user's code - their purpose might be for token calculation downstream
143
  sizes_for_embed_length: List[torch.Tensor] = []
144
  frames_scaled_by_feat_stride: List[int] = []
 
149
 
150
  if isinstance(audio_item, tuple) and len(audio_item) == 2 and isinstance(audio_item[1], int):
151
  current_wav, source_sr = audio_item
152
+ current_wav = np.asarray(current_wav, dtype=np.float32) # Ensure float32 numpy array
153
  elif isinstance(audio_item, (np.ndarray, list)):
154
  current_wav = np.asarray(audio_item, dtype=np.float32)
155
  if sampling_rate is None:
 
168
  f"Unsupported audio input type: {type(audio_item)}. "
169
  "Expected np.ndarray, list of floats, or Tuple[np.ndarray, int]."
170
  )
171
+
172
  processed_wav_array = self._preprocess_audio(current_wav, source_sr)
173
+ mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav_array) # Shape: (T_mel, N_Mels)
174
+
175
+ feature_tensor = torch.from_numpy(mel_spectrogram) # Already float32
176
  processed_mels.append(feature_tensor)
177
+ actual_mel_lengths.append(feature_tensor.shape[0]) # T_mel for this item
178
 
179
  # User's original logic for 'sizes' and 'frames'
180
  sizes_for_embed_length.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
 
186
 
187
  # Create attention mask corresponding to the actual lengths of mel spectrograms
188
  max_t_mel_in_batch = audio_embeds.shape[1]
189
+ current_device = audio_embeds.device # Get device from padded tensor if using PyTorch tensors earlier
190
+
191
  # Create attention mask directly based on actual_mel_lengths
192
  attention_mask = torch.zeros(len(audios), max_t_mel_in_batch, dtype=torch.bool, device=current_device)
193
  for i, length in enumerate(actual_mel_lengths):
194
  attention_mask[i, :length] = True
195
+
196
  output_data = {
197
  "audio_values": audio_embeds,
198
+ "audio_attention_mask": attention_mask # Correctly shaped mask for audio_values
199
  }
200
 
201
  # Include user's 'sizes' if they are needed downstream
 
209
  # Ensure wav is float32
210
  if wav.dtype not in [np.float32, np.float64]:
211
  if np.issubdtype(wav.dtype, np.integer):
212
+ max_val = np.iinfo(wav.dtype).max if wav.size > 0 else 1.0 # Avoid error on empty array
213
  wav = wav.astype(np.float32) / max_val
214
  else:
215
  wav = wav.astype(np.float32)
 
217
  wav = wav.astype(np.float32)
218
 
219
  if wav.ndim > 1:
220
+ wav = wav.mean(axis=0) # Convert to mono
221
+
222
  if source_sr != self.sampling_rate:
223
  logger.info(f"Resampling audio from {source_sr} Hz to {self.sampling_rate} Hz.")
224
  # Calculate integer up/down factors for resample_poly
225
  common_divisor = math.gcd(self.sampling_rate, source_sr)
226
  up_factor = self.sampling_rate // common_divisor
227
  down_factor = source_sr // common_divisor
228
+ if up_factor != down_factor: # Only if actual resampling is needed
229
  wav = scipy.signal.resample_poly(wav, up=up_factor, down=down_factor)
230
+
231
  # Normalize amplitude to roughly [-1, 1]
232
  max_abs_val = np.abs(wav).max()
233
+ if max_abs_val > 1e-7: # Avoid division by zero or tiny numbers
234
  wav = wav / max_abs_val
235
  return wav
236
 
 
243
  # Calculate number of frames
244
  # This calculation ensures at least one frame if len(wav) == self.win_length
245
  if len(wav) >= self.win_length:
246
+ num_frames = 1 + (len(wav) - self.win_length) // self.hop_length
247
+ else: # Should be covered by padding, but as safeguard
248
+ num_frames = 0
249
+
250
  if num_frames <= 0:
251
  logger.warning(f"Audio is too short (length {len(wav)}) to produce any frames "
252
  f"with win_length {self.win_length} and hop_length {self.hop_length}. "
 
261
  strides=(strides * self.hop_length, strides),
262
  writeable=False
263
  )
264
+ frames_data = frames_view.copy() # Important: copy after as_strided if modifying
265
+
266
+ frames_data *= self.window # Apply window in-place on the copy
267
 
268
  # Compute STFT (rfft for real inputs)
269
  # n_fft determines zero-padding or truncation for FFT input from each frame
270
  spectrum = np.fft.rfft(frames_data, n=self.n_fft, axis=-1).astype(np.complex64)
271
+ power = np.abs(spectrum) ** 2
272
+
273
+ mel_spectrogram = np.dot(power, self.mel_filterbank) # (num_frames, n_mels)
274
+
275
  # Clip and take log
276
+ mel_spectrogram = np.clip(mel_spectrogram, LOG_MEL_CLIP_EPSILON, None) # Use defined epsilon
277
  log_mel_spectrogram = np.log(mel_spectrogram)
278
+
279
  return log_mel_spectrogram.astype(np.float32)
280
 
281
  def _calculate_embed_length(self, frame_count: int) -> int:
 
284
  return math.ceil(compressed / self.qformer_rate)
285
 
286
 
287
+ class Gemma3ImagesKwargs(ImagesKwargs): # User's definition
288
  do_pan_and_scan: Optional[bool]
289
  pan_and_scan_min_crop_size: Optional[int]
290
  pan_and_scan_max_num_crops: Optional[int]
 
292
  do_convert_rgb: Optional[bool]
293
 
294
 
295
+ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False): # User's definition
296
  images_kwargs: Dict[str, Any]
297
  audio_kwargs: Dict[str, Any]
298
  # Added text_kwargs as it's commonly part of such structures
299
+ text_kwargs: Optional[Dict[str, Any]] = None
300
  _defaults = {
301
  "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
302
  "images_kwargs": {},
 
306
 
307
  class Gemma3OmniProcessor(ProcessorMixin):
308
  attributes = ["image_processor", "audio_processor", "tokenizer"]
309
+ valid_kwargs = ["chat_template", "image_seq_length"] # From user's code
310
 
311
  # --- FIXED CLASS ATTRIBUTES ---
312
+ image_processor_class = "AutoImageProcessor" # As in user's original code
313
+ audio_processor_class = "AutoFeatureExtractor"
314
+ tokenizer_class = "AutoTokenizer" # As in user's original code
315
 
316
  def __init__(
317
  self,
318
+ image_processor=None, # Allow None, superclass or from_pretrained handles loading via _class
319
+ audio_processor=None, # Allow None or instance
320
+ tokenizer=None, # Allow None or instance
321
  chat_template=None,
322
  image_seq_length: int = 256,
323
+ **kwargs
324
  ):
325
  # The ProcessorMixin's __init__ will handle instantiating these if they are None,
326
  # using the respective *_class attributes.
327
  # If specific instances are passed, they will be used.
328
+
329
  # Retaining user's specific logic for setting attributes if needed,
330
  # though much of this might be handled by super() or better placed after super()
331
  self.image_seq_length = image_seq_length
332
+
333
  # These tokenizer-dependent attributes should be set *after* super().__init__
334
  # ensures self.tokenizer is populated, or if tokenizer is passed directly.
335
  # If tokenizer is None and loaded by super(), these need to be set post-super().
 
338
  # This is a basic placeholder; HF's from_pretrained mechanism is more robust for loading
339
  # For now, we'll assume if tokenizer is None, super() handles it or it's an error later.
340
  pass
341
+ else: # Tokenizer was provided
342
+ self.image_token_id = getattr(tokenizer, "image_token_id", None) # More robust with getattr
343
+ self.boi_token = getattr(tokenizer, "boi_token", "<|image|>") # Defaulting if not present
344
  self.image_token = getattr(tokenizer, "image_token", "<|image|>")
345
+ self.eoi_token = getattr(tokenizer, "eoi_token", "") # Added eoi_token as it was used
346
 
347
+ self.audio_token = "<audio_soft_token>" # User's definition
348
  # self.expected_audio_token_id = 262143 # User's reference
349
  # The existence of this token should be ensured when the tokenizer is prepared/saved.
350
+ self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
351
  # if self.audio_token_id != self.expected_audio_token_id: # User's warning
352
  # logger.warning(...)
353
  if self.audio_token_id == tokenizer.unk_token_id:
354
+ logger.warning(
355
+ f"Audio token '{self.audio_token}' not found in tokenizer, maps to UNK. Ensure it's added.")
356
 
357
  self.full_image_sequence = f"\n\n{self.boi_token}{''.join([self.image_token] * image_seq_length)}{self.eoi_token if hasattr(tokenizer, 'eoi_token') else ''}\n\n"
358
 
 
359
  # These seem specific to this processor's logic for determining audio token sequence length
360
  # It's better to initialize them here.
361
  self.audio_prompt_compression_rate = kwargs.pop("audio_prompt_compression_rate", 8)
362
  self.audio_prompt_qformer_rate = kwargs.pop("audio_prompt_qformer_rate", 1)
363
  self.audio_prompt_feat_stride = kwargs.pop("audio_prompt_feat_stride", 1)
364
 
 
365
  super().__init__(
366
  image_processor=image_processor,
367
  audio_processor=audio_processor,
368
  tokenizer=tokenizer,
369
  chat_template=chat_template,
370
+ **kwargs # Pass remaining kwargs to super
371
  )
372
+
373
  # If tokenizer was loaded by super(), set tokenizer-dependent attributes now
374
  if not hasattr(self, 'image_token_id') and self.tokenizer is not None:
375
+ self.image_token_id = getattr(self.tokenizer, "image_token_id",
376
+ self.tokenizer.unk_token_id if hasattr(self.tokenizer,
377
+ "unk_token_id") else None)
378
  self.boi_token = getattr(self.tokenizer, "boi_token", "<|image|>")
379
  self.image_token = getattr(self.tokenizer, "image_token", "<|image|>")
380
  self.eoi_token = getattr(self.tokenizer, "eoi_token", "")
381
  self.audio_token = "<audio_soft_token>"
382
  self.audio_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_token)
383
  if self.audio_token_id == self.tokenizer.unk_token_id:
384
+ logger.warning(
385
+ f"Audio token '{self.audio_token}' not found in tokenizer (post-super), maps to UNK. Ensure it's added.")
386
  self.full_image_sequence = f"\n\n{self.boi_token}{''.join([self.image_token] * self.image_seq_length)}{self.eoi_token}\n\n"
387
 
 
388
  def _merge_kwargs(self, ModelProcessorKwargs, tokenizer_init_kwargs, **kwargs_from_call):
389
  # User's original _merge_kwargs logic
390
  default_kwargs = {}
 
398
 
399
  for modality_key_in_call, modality_kwargs_in_call in kwargs_from_call.items():
400
  if modality_key_in_call in default_kwargs:
401
+ if isinstance(modality_kwargs_in_call, dict):
402
  default_kwargs[modality_key_in_call].update(modality_kwargs_in_call)
403
+ elif isinstance(modality_kwargs_in_call, dict): # New modality not in defaults
404
+ default_kwargs[modality_key_in_call] = modality_kwargs_in_call.copy()
 
405
 
406
  # Update defaults with tokenizer init kwargs (original logic)
407
+ for modality_key in default_kwargs: # Iterate over current keys in default_kwargs
408
  modality_dict = default_kwargs[modality_key]
409
+ if isinstance(modality_dict, dict): # Ensure it's a dict before trying to access keys
410
+ for key_in_mod_dict in list(modality_dict.keys()): # Iterate over copy of keys
411
  if key_in_mod_dict in tokenizer_init_kwargs:
412
  value = (
413
  getattr(self.tokenizer, key_in_mod_dict)
 
415
  else tokenizer_init_kwargs[key_in_mod_dict]
416
  )
417
  modality_dict[key_in_mod_dict] = value
418
+
419
  # Ensure text_kwargs processing (original logic)
420
+ if "text_kwargs" not in default_kwargs: # Ensure text_kwargs exists
421
  default_kwargs["text_kwargs"] = {}
422
  default_kwargs["text_kwargs"]["truncation"] = default_kwargs["text_kwargs"].get("truncation", False)
423
+ default_kwargs["text_kwargs"]["max_length"] = default_kwargs["text_kwargs"].get("max_length",
424
+ DEFAULT_MAX_LENGTH)
425
 
426
  return default_kwargs
427
 
 
433
  def __call__(
434
  self,
435
  images=None,
436
+ text: Union[str, List[str]] = None, # text is optional but often primary
437
  # videos=None, # Removed 'videos' as it's not handled
438
  audios: Optional[Union[AudioInput, List[AudioInput]]] = None,
439
+ sampling_rate: Optional[int] = None, # For audio_processor if audios are raw arrays
440
  return_tensors: Optional[Union[str, TensorType]] = None,
441
+ **kwargs: Any # Replaced Unpack for broader compatibility here
442
  ) -> BatchFeature:
443
+ if text is None and images is None and audios is None: # Added audios to check
444
  raise ValueError("Provide at least one of `text`, `images`, or `audios`.")
445
 
446
  # Determine final return_tensors strategy
 
449
  # This call to _merge_kwargs primarily populates kwargs for each modality if passed in __call__
450
  # e.g. if user calls proc(..., text_kwargs={...})
451
  merged_call_kwargs = self._merge_kwargs(
452
+ Gemma3ProcessorKwargs,
453
  self.tokenizer.init_kwargs if hasattr(self.tokenizer, "init_kwargs") else {},
454
  **kwargs
455
  )
456
+
457
  # If return_tensors wasn't passed to __call__, try to get it from merged text_kwargs
458
  # and remove it from there to avoid passing it twice to tokenizer.
459
  # Default to PYTORCH if still None.
 
462
  else:
463
  merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", None)
464
 
 
465
  # Standardize text input
466
+ if text is None: # If no text given, create dummy text based on other modalities
467
  num_samples = 0
468
  if images is not None:
469
+ _images_list = images if isinstance(images, list) and (
470
+ not images or not isinstance(images[0], (int, float))) else [images]
471
  num_samples = len(_images_list)
472
  elif audios is not None:
473
  _audios_list = audios if isinstance(audios, list) else [audios]
474
  num_samples = len(_audios_list)
475
+ text = [""] * num_samples if num_samples > 0 else [""] # Fallback for safety
476
 
477
  if isinstance(text, str):
478
  text = [text]
 
482
  # --- Image Processing ---
483
  image_features_dict = {}
484
  if images is not None and self.image_processor is not None:
485
+ batched_images = make_nested_list_of_images(images) # HF utility
486
  # Assuming image_processor returns a dict or BatchFeature. If BatchFeature, get .data
487
+ _img_proc_output = self.image_processor(batched_images, return_tensors=None,
488
+ **merged_call_kwargs.get("images_kwargs", {}))
489
+ image_features_dict = _img_proc_output.data if isinstance(_img_proc_output,
490
+ BatchFeature) else _img_proc_output
491
 
492
+ if len(batched_images) != len(text): # Validate batch consistency
493
  raise ValueError(f"Inconsistent batch sizes: {len(batched_images)} images, {len(text)} texts")
494
 
495
  # User's original image token replacement logic (complex, depends on num_crops etc from image_processor output)
496
  # This part needs to be carefully adapted based on actual image_processor output structure
497
  # For now, a simplified placeholder for the concept:
498
+ if "num_crops" in image_features_dict: # Example check
499
  num_crops_list = to_py_obj(image_features_dict.pop("num_crops"))
500
  # ... user's original logic for text modification with self.full_image_sequence ...
501
  # This was: text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
 
503
  # For simplicity, assuming one image sequence per text for now if an image is present.
504
  temp_text = []
505
  for i, prompt in enumerate(text):
506
+ if i < len(batched_images): # if this text sample has corresponding images
507
+ # Replace first boi_token or append if not found
508
  if self.boi_token in prompt:
509
  temp_text.append(prompt.replace(self.boi_token, self.full_image_sequence, 1))
510
  else:
 
513
  temp_text.append(prompt)
514
  text = temp_text
515
 
 
516
  # --- Audio Processing ---
517
  audio_features_dict = {}
518
  if audios is not None and self.audio_processor is not None:
519
  audio_call_kwargs = merged_call_kwargs.get("audio_kwargs", {})
520
  if sampling_rate is not None:
521
+ audio_call_kwargs["sampling_rate"] = sampling_rate
522
+
523
  # audio_processor.__call__ returns BatchFeature, get its .data attribute for the dict
524
  _audio_proc_output = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
525
  audio_features_dict = _audio_proc_output.data
 
530
  actual_mel_frames_per_sample = to_py_obj(audio_features_dict["audio_attention_mask"].sum(axis=-1))
531
 
532
  if len(actual_mel_frames_per_sample) != len(text):
533
+ raise ValueError(
534
+ f"Inconsistent batch sizes for audio and text: {len(actual_mel_frames_per_sample)} audio samples, {len(text)} texts.")
535
 
536
  for i, prompt in enumerate(text):
537
  num_soft_tokens = self._compute_audio_embed_size(actual_mel_frames_per_sample[i])
538
+ audio_token_sequence_str = self.audio_soft_token_str * num_soft_tokens # Repeat soft token string
539
+
540
  # Replace a placeholder or append
541
+ placeholder = getattr(self, "audio_placeholder_token", "<|audio|>") # Use defined placeholder
542
  if placeholder in prompt:
543
  prompt_with_audio = prompt.replace(placeholder, audio_token_sequence_str, 1)
544
+ else:
545
+ prompt_with_audio = prompt + audio_token_sequence_str
546
  new_text_with_audio_tokens.append(prompt_with_audio)
547
  text = new_text_with_audio_tokens
548
+
549
  # --- Text Tokenization ---
550
  text_tokenizer_kwargs = merged_call_kwargs.get("text_kwargs", {})
551
  # Tokenize the (potentially modified) text, request lists/np arrays
 
560
  if isinstance(input_ids_list_of_lists, (torch.Tensor, np.ndarray)):
561
  input_ids_list_of_lists = to_py_obj(input_ids_list_of_lists)
562
  elif isinstance(input_ids_list_of_lists, list) and \
563
+ (not input_ids_list_of_lists or isinstance(input_ids_list_of_lists[0], int)):
564
+ input_ids_list_of_lists = [input_ids_list_of_lists] # Batch of 1
565
 
566
  mm_token_type_ids_list = []
567
  for ids_sample in input_ids_list_of_lists:
568
+ type_ids_sample = [0] * len(ids_sample) # Default type 0 (text)
569
  for idx, token_id_val in enumerate(ids_sample):
570
  if self.image_token_id is not None and token_id_val == self.image_token_id:
571
+ type_ids_sample[idx] = 1 # Image token type
572
+ elif token_id_val == self.audio_token_id: # Compare with ID of <audio_soft_token>
573
+ type_ids_sample[idx] = 2 # Audio token type
574
  mm_token_type_ids_list.append(type_ids_sample)
575
  text_features_dict["token_type_ids"] = mm_token_type_ids_list
576
+
577
  # Combine all features
578
  final_batch_data = {**text_features_dict}
579
+ if image_features_dict:
580
  final_batch_data.update(image_features_dict)
581
+ if audio_features_dict:
582
  final_batch_data.update(audio_features_dict)
583
+
584
+ return BatchFeature(data=final_batch_data, tensor_type=final_rt) # Use determined final_rt
585
 
586
  def batch_decode(self, *args, **kwargs):
587
  return self.tokenizer.batch_decode(*args, **kwargs)
 
593
  def model_input_names(self):
594
  tokenizer_inputs = self.tokenizer.model_input_names + ["token_type_ids"]
595
  image_processor_inputs = []
596
+ if self.image_processor is not None: # Check if image_processor exists
597
+ image_processor_inputs = self.image_processor.model_input_names
598
+
599
  audio_processor_inputs = []
600
+ if self.audio_processor is not None: # Check if audio_processor exists
601
  # These are the keys Gemma3AudioFeatureExtractor puts in its output BatchFeature.data
602
  audio_processor_inputs = ["audio_values", "audio_attention_mask"]
603
  # "audio_values_sizes" was in user's original Gemma3AudioFeatureExtractor output,