voidful committed on
Commit 9faac02 · verified · 1 Parent(s): f1bb3f9

Update processing_gemma3_omni.py

Files changed (1)
  1. processing_gemma3_omni.py +406 -305
processing_gemma3_omni.py CHANGED
@@ -1,30 +1,24 @@
1
  import re
2
- from typing import List, Optional, Union, Dict, Any, Tuple
3
 
4
  import math
5
  import numpy as np
6
  import scipy.signal
7
  import torch
8
  from torch.nn.utils.rnn import pad_sequence
9
-
10
- # Assuming AudioInput might be from transformers.audio_utils for full robustness,
11
- # but for now, let's define a clear supported set.
12
- # from transformers.audio_utils import AudioInput as HfAudioInput, load_audio
13
- # For this fix, we define AudioInput locally for clarity on what's handled.
14
- AudioInput = Union[np.ndarray, List[float], Tuple[np.ndarray, int]]
15
-
16
  from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
17
  from transformers.feature_extraction_utils import BatchFeature
18
- from transformers.processing_utils import ProcessorMixin, ProcessingKwargs
 
19
  from transformers.utils import TensorType, to_py_obj, logging
20
- # For AutoImageProcessor, AutoTokenizer if needed for default loading
21
- from transformers import AutoImageProcessor, AutoTokenizer
22
 
23
- # Constants (as defined before)
24
  DEFAULT_SAMPLING_RATE = 16000
25
  DEFAULT_N_FFT = 512
26
- DEFAULT_WIN_LENGTH = 400 # Will be n_fft if None in __init__
27
- DEFAULT_HOP_LENGTH = 160 # Will be win_length // 4 if None in __init__
28
  DEFAULT_N_MELS = 80
29
  DEFAULT_COMPRESSION_RATE = 4
30
  DEFAULT_QFORMER_RATE = 2
@@ -32,59 +26,56 @@ DEFAULT_FEAT_STRIDE = 4
32
  IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
33
  AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
34
  DEFAULT_MAX_LENGTH = 16384
35
- LOG_MEL_CLIP_EPSILON = 1e-5
36
 
37
  logger = logging.get_logger(__name__)
38
 
39
 
40
- # create_mel_filterbank function (assuming it's correctly defined from previous response)
41
- # ... (create_mel_filterbank function from the previous corrected response) ...
42
  def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
43
  fmax: Optional[float] = None) -> np.ndarray:
44
- """Create Mel filterbank for audio processing."""
45
- fmax = fmax or sampling_rate / 2.0
46
 
47
- if fmin >= fmax:
48
- raise ValueError(f"fmin ({fmin}) must be smaller than fmax ({fmax}).")
 
 
 
 
49
 
50
- def hz_to_mel(f: float) -> float: # Using HTK formula (as in librosa default)
51
- return 2595.0 * math.log10(1 + f / 700.0)
52
-
53
- def mel_to_hz(mel: float) -> float:
54
- return 700.0 * (10 ** (mel / 2595.0) - 1)
55
 
56
  mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
57
- freq_points = mel_to_hz(mel_points)
 
58
 
 
59
  freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
60
- bins = np.floor((n_fft / 2.0) * freq_points / (sampling_rate / 2.0)).astype(int)
 
 
61
  bins = np.clip(bins, 0, n_fft // 2)
62
 
63
- filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
64
- for m in range(n_mels):
65
- left, center, right = bins[m], bins[m + 1], bins[m + 2]
66
 
67
- # Simplified triangle creation logic (more robust versions exist in libraries like librosa)
68
  if center > left:
69
- filterbank[m, left:center + 1] = (np.arange(left, center + 1) - left) / (center - left)
70
  if right > center:
71
- filterbank[m, center:right + 1] = (right - np.arange(center, right + 1)) / (right - center)
72
- # Ensure peak is 1 if multiple points coincide at center (can happen with narrow filters/low resolution)
73
- if left <= center <= right and filterbank[m, center] < 1.0 and (
74
- center > left or center < right): # check if it's a valid point for a peak
75
- # if filterbank[m,center] is not properly set to 1 by slopes (e.g. left==center or right==center)
76
- filterbank[m, center] = 1.0
77
- if left == center and right > center: # only falling slope
78
- # Ensure it doesn't double-dip if already set
79
- pass
80
- elif right == center and left < center: # only rising slope
81
- pass
82
 
83
  return filterbank
84
 
85
 
86
- # Gemma3AudioFeatureExtractor class (assuming it's correctly defined from previous response)
87
- # ... (Gemma3AudioFeatureExtractor class from the previous corrected response) ...
88
  class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
89
  model_input_names = ["audio_values", "audio_attention_mask"]
90
 
@@ -93,168 +84,221 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
93
  compression_rate: int = DEFAULT_COMPRESSION_RATE,
94
  qformer_rate: int = DEFAULT_QFORMER_RATE,
95
  feat_stride: int = DEFAULT_FEAT_STRIDE,
96
- sampling_rate: int = DEFAULT_SAMPLING_RATE,
97
  n_fft: int = DEFAULT_N_FFT,
98
  win_length: Optional[int] = None,
99
  hop_length: Optional[int] = None,
100
  n_mels: int = DEFAULT_N_MELS,
101
- f_min: float = 0.0,
102
- f_max: Optional[float] = None,
103
- padding_value: float = 0.0,
104
  **kwargs
105
  ):
106
- kwargs.pop("feature_size", None)
107
- kwargs.pop("sampling_rate", None)
108
- kwargs.pop("padding_value", None)
109
 
 
110
  super().__init__(
111
  feature_size=n_mels,
112
- sampling_rate=sampling_rate,
113
- padding_value=0.0,
114
  **kwargs
115
  )
116
-
117
  self.compression_rate = compression_rate
118
  self.qformer_rate = qformer_rate
119
  self.feat_stride = feat_stride
 
 
120
  self.n_fft = n_fft
121
- self.win_length = win_length if win_length is not None else n_fft
122
- self.hop_length = hop_length if hop_length is not None else self.win_length // 4
123
  self.n_mels = n_mels
124
  self.f_min = f_min
125
- self.f_max = f_max if f_max is not None else self.sampling_rate / 2.0
126
 
127
  if self.win_length > self.n_fft:
128
  logger.warning(
129
  f"win_length ({self.win_length}) is greater than n_fft ({self.n_fft}). "
130
- f"For FFT computation, the window will effectively be truncated or the signal zero-padded to n_fft length."
131
  )
132
- self.window = scipy.signal.get_window("hann", self.win_length).astype(np.float32)
133
  self.mel_filterbank = create_mel_filterbank(
134
  self.sampling_rate, self.n_fft, self.n_mels, fmin=self.f_min, fmax=self.f_max
135
- ).T
 
136
 
137
  def __call__(
138
  self,
139
- audios: Union[AudioInput, List[AudioInput]],
140
- sampling_rate: Optional[int] = None,
141
  return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
142
  ) -> BatchFeature:
143
-
144
  if not isinstance(audios, list):
145
  audios = [audios]
146
 
147
- processed_mel_spectrograms: List[torch.Tensor] = []
148
  actual_mel_lengths: List[int] = []
149
- downstream_sizes_for_token_calc: List[torch.Tensor] = []
150
- downstream_frames_scaled_for_token_calc: List[int] = []
 
 
151
 
152
- for audio_input_item in audios:
153
- current_wav_array: np.ndarray
154
  source_sr: int
155
 
156
- if isinstance(audio_input_item, tuple):
157
- current_wav_array, source_sr = audio_input_item
158
- current_wav_array = np.asarray(current_wav_array, dtype=np.float32)
159
- elif isinstance(audio_input_item, (np.ndarray, list)):
160
- current_wav_array = np.asarray(audio_input_item, dtype=np.float32)
161
  if sampling_rate is None:
162
  raise ValueError(
163
- "sampling_rate must be provided if audio inputs are raw numpy arrays or lists."
164
  )
165
  source_sr = sampling_rate
166
  else:
167
  raise TypeError(
168
- f"Unsupported audio input type: {type(audio_input_item)}. "
169
- "This extractor expects np.ndarray, list of floats, or Tuple[np.ndarray, int indicating SR]."
170
  )
171
-
172
- processed_wav = self._preprocess_audio(current_wav_array, source_sr)
173
- mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav)
174
-
175
- feature_tensor = torch.from_numpy(mel_spectrogram)
176
- processed_mel_spectrograms.append(feature_tensor)
177
- actual_mel_lengths.append(feature_tensor.shape[0])
178
-
179
- downstream_sizes_for_token_calc.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
180
- downstream_frames_scaled_for_token_calc.append(feature_tensor.shape[0] * self.feat_stride)
181
-
182
- audio_values = pad_sequence(processed_mel_spectrograms, batch_first=True, padding_value=self.padding_value)
183
- max_mel_len = audio_values.shape[1]
184
- lengths_tensor = torch.tensor(actual_mel_lengths, dtype=torch.long)
185
- audio_attention_mask = torch.arange(max_mel_len).unsqueeze(0).expand(len(audios),
186
- -1) < lengths_tensor.unsqueeze(1)
187
-
188
  output_data = {
189
- "audio_values": audio_values,
190
- "audio_attention_mask": audio_attention_mask
191
  }
192
 
193
- if downstream_sizes_for_token_calc:
194
- output_data["audio_token_calc_sizes"] = torch.stack(downstream_sizes_for_token_calc)
 
 
195
 
196
  return BatchFeature(data=output_data, tensor_type=return_tensors)
197
 
198
  def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
 
199
  if wav.dtype not in [np.float32, np.float64]:
200
  if np.issubdtype(wav.dtype, np.integer):
201
- max_val = np.iinfo(wav.dtype).max
202
  wav = wav.astype(np.float32) / max_val
203
  else:
204
  wav = wav.astype(np.float32)
 
 
205
 
206
  if wav.ndim > 1:
207
- wav = wav.mean(axis=0)
208
-
209
  if source_sr != self.sampling_rate:
210
- gcd = math.gcd(self.sampling_rate, source_sr)
211
- up_factor = self.sampling_rate // gcd
212
- down_factor = source_sr // gcd
213
- if up_factor != down_factor:
214
- logger.info(f"Resampling audio from {source_sr} Hz to {self.sampling_rate} Hz.")
 
215
  wav = scipy.signal.resample_poly(wav, up=up_factor, down=down_factor)
216
-
217
- norm_factor = np.abs(wav).max()
218
- if norm_factor > 1e-9:
219
- wav = wav / norm_factor
 
220
  return wav
221
 
222
  def _compute_log_mel_spectrogram(self, wav: np.ndarray) -> np.ndarray:
223
  if len(wav) < self.win_length:
 
224
  padding = self.win_length - len(wav)
225
  wav = np.pad(wav, (0, padding), mode='constant', constant_values=0.0)
226
 
227
- num_frames = 1 + (len(wav) - self.win_length) // self.hop_length
228
  if num_frames <= 0:
229
- logger.warning(
230
- f"Audio of length {len(wav)} is too short to produce frames with win_length {self.win_length} and hop_length {self.hop_length}. Returning empty mel spectrogram.")
 
231
  return np.zeros((0, self.n_mels), dtype=np.float32)
232
 
233
- frames = np.lib.stride_tricks.as_strided(
 
 
234
  wav,
235
  shape=(num_frames, self.win_length),
236
- strides=(wav.strides[0] * self.hop_length, wav.strides[0]),
237
  writeable=False
238
  )
 
 
 
239
 
240
- windowed_frames = frames * self.window
241
- stft_matrix = np.fft.rfft(windowed_frames, n=self.n_fft, axis=-1)
242
- powers = np.abs(stft_matrix) ** 2
243
- mel_spectrogram = np.dot(powers, self.mel_filterbank)
244
- mel_spectrogram = np.clip(mel_spectrogram, LOG_MEL_CLIP_EPSILON, None)
 
 
 
 
245
  log_mel_spectrogram = np.log(mel_spectrogram)
246
-
247
  return log_mel_spectrogram.astype(np.float32)
248
 
249
  def _calculate_embed_length(self, frame_count: int) -> int:
 
250
  compressed = math.ceil(frame_count / self.compression_rate)
251
  return math.ceil(compressed / self.qformer_rate)
252
 
253
 
254
- class Gemma3DummyProcessorKwargs(ProcessingKwargs, total=False): # Dummy for testing structure
255
  images_kwargs: Dict[str, Any]
256
  audio_kwargs: Dict[str, Any]
257
- text_kwargs: Dict[str, Any]
 
258
  _defaults = {
259
  "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
260
  "images_kwargs": {},
@@ -264,230 +308,282 @@ class Gemma3DummyProcessorKwargs(ProcessingKwargs, total=False): # Dummy for te
264
 
265
  class Gemma3OmniProcessor(ProcessorMixin):
266
  attributes = ["image_processor", "audio_processor", "tokenizer"]
267
- valid_kwargs = ["chat_template", "image_seq_length"]
268
- image_processor_class = "AutoImageProcessor"
269
- audio_processor_class = "AutoFeatureExtractor"
270
- tokenizer_class = "AutoTokenizer"
271
 
272
- # valid_kwargs was in user's code, its role depends on ProcessorMixin internal usage
273
- valid_kwargs = ["chat_template", "image_seq_length"]
 
 
274
 
275
  def __init__(
276
  self,
277
- tokenizer,
278
- audio_processor: Optional[Union[Gemma3AudioFeatureExtractor, Dict]] = None,
279
- image_processor=None,
280
  chat_template=None,
281
  image_seq_length: int = 256,
282
- audio_prompt_compression_rate: int = 8,
283
- audio_prompt_qformer_rate: int = 1,
284
- audio_prompt_feat_stride: int = 1,
285
- audio_placeholder_token: str = "<|audio_placeholder|>",
286
- audio_soft_token_str: str = "<audio_soft_token>",
287
- **kwargs
288
  ):
289
- # Instantiate audio_processor if config dict is passed or if None (use defaults)
290
- if audio_processor is None:
291
- logger.info("Initializing Gemma3AudioFeatureExtractor with default parameters for Gemma3OmniProcessor.")
292
- audio_processor = Gemma3AudioFeatureExtractor()
293
- elif isinstance(audio_processor, Dict):
294
- audio_processor = Gemma3AudioFeatureExtractor(**audio_processor)
295
- elif not isinstance(audio_processor, Gemma3AudioFeatureExtractor): # Check type if instance is passed
296
- raise TypeError(
297
- f"audio_processor must be an instance of Gemma3AudioFeatureExtractor or a config dict, got {type(audio_processor)}")
298
-
299
- # Handle image_processor similarly if it can be None or a dict
300
- if image_processor is None and self.image_processor_class:
301
- # This is a basic way; from_pretrained usually handles complex loading
302
- if isinstance(self.image_processor_class, str) and self.image_processor_class == "AutoImageProcessor":
303
- logger.info(
304
- f"Attempting to load a default {self.image_processor_class}. This might require a default model name or fail.")
305
- # image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32") # Example default
306
- # else if self.image_processor_class is an actual class, instantiate it.
307
- elif isinstance(image_processor, Dict):
308
- # image_processor = AutoImageProcessor.from_config(config_class(**image_processor)) # Example
309
- pass # Actual instantiation from dict would be more complex
310
-
311
- # Ensure tokenizer is an instantiated object
312
- if isinstance(tokenizer, str): # If tokenizer is a string (model name/path)
313
- logger.info(f"Loading tokenizer from {tokenizer}")
314
- # tokenizer = AutoTokenizer.from_pretrained(tokenizer) # This is how it's usually done
315
- elif tokenizer is None:
316
- raise ValueError("A tokenizer instance or identifier must be provided.")
317
 
318
  super().__init__(
319
  image_processor=image_processor,
320
  audio_processor=audio_processor,
321
  tokenizer=tokenizer,
322
  chat_template=chat_template,
323
- **kwargs # Pass other kwargs to super
324
  )
325
 
326
- self.image_seq_length = image_seq_length
327
- self.image_token_id = getattr(self.tokenizer, "image_token_id",
328
- self.tokenizer.unk_token_id if hasattr(self.tokenizer, "unk_token_id") else None)
329
- self.boi_token = getattr(self.tokenizer, "boi_token", "<|image|>")
330
- self.image_token = getattr(self.tokenizer, "image_token", "<|image|>")
331
- self.eoi_token = getattr(self.tokenizer, "eoi_token", "")
332
-
333
- self.audio_placeholder_token = audio_placeholder_token
334
- self.audio_soft_token_str = audio_soft_token_str
335
 
336
- self.audio_soft_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_soft_token_str)
337
- if self.audio_soft_token_id == self.tokenizer.unk_token_id: # Check if UNK
338
- logger.warning(
339
- f"The audio soft token string '{self.audio_soft_token_str}' maps to UNK token (ID: {self.audio_soft_token_id}). "
340
- "Ensure it is added to the tokenizer's vocabulary as a special token."
341
- )
342
 
343
- self.full_image_sequence_str = f"\n\n{self.boi_token}{''.join([self.image_token] * self.image_seq_length)}{self.eoi_token}\n\n"
344
-
345
- self.audio_prompt_compression_rate = audio_prompt_compression_rate
346
- self.audio_prompt_qformer_rate = audio_prompt_qformer_rate
347
- self.audio_prompt_feat_stride = audio_prompt_feat_stride
348
-
349
- def _merge_kwargs(self, KwargsClassWithDefaults, tokenizer_init_kwargs, **kwargs_passed_to_call):
350
- final_kwargs = {}
351
- # Initialize with _defaults from the Kwargs class
352
- # Ensure KwargsClassWithDefaults has a _defaults attribute
353
- _defaults = getattr(KwargsClassWithDefaults, "_defaults", {})
354
- for modality_key, default_modality_kwargs in _defaults.items():
355
- final_kwargs[modality_key] = default_modality_kwargs.copy()
356
-
357
- # Override with tokenizer's init_kwargs if they exist for a given key
358
- for modality_key, modality_dict in final_kwargs.items():
359
- for key in list(modality_dict.keys()):
360
- if key in tokenizer_init_kwargs:
361
- modality_dict[key] = tokenizer_init_kwargs[key]
362
-
363
- # Override with kwargs passed directly to __call__
364
- for modality_key_from_call, modality_dict_from_call in kwargs_passed_to_call.items():
365
- if modality_key_from_call in final_kwargs and isinstance(modality_dict_from_call, dict):
366
- final_kwargs[modality_key_from_call].update(modality_dict_from_call)
367
- # If a new modality_kwargs (e.g., "video_kwargs") is passed, add it
368
- elif modality_key_from_call not in final_kwargs and isinstance(modality_dict_from_call, dict):
369
- final_kwargs[modality_key_from_call] = modality_dict_from_call.copy()
370
-
371
- # Specific handling for text_kwargs
372
- if "text_kwargs" not in final_kwargs:
373
- final_kwargs["text_kwargs"] = {} # Ensure it exists
374
- final_kwargs["text_kwargs"]["truncation"] = final_kwargs["text_kwargs"].get("truncation", False)
375
- final_kwargs["text_kwargs"]["max_length"] = final_kwargs["text_kwargs"].get("max_length", DEFAULT_MAX_LENGTH)
376
-
377
- return final_kwargs
378
-
379
- def _compute_audio_prompt_token_count(self, actual_mel_frames_count: int) -> int:
380
- scaled_frames = actual_mel_frames_count * self.audio_prompt_feat_stride
381
- compressed_once = math.ceil(scaled_frames / self.audio_prompt_compression_rate)
382
- compressed_twice = math.ceil(compressed_once / self.audio_prompt_qformer_rate)
383
- return compressed_twice
384
 
385
  def __call__(
386
  self,
387
- text: Union[str, List[str]] = None,
388
- images: Optional[Any] = None,
 
389
  audios: Optional[Union[AudioInput, List[AudioInput]]] = None,
390
- sampling_rate: Optional[int] = None,
391
  return_tensors: Optional[Union[str, TensorType]] = None,
392
- **kwargs: Any
393
  ) -> BatchFeature:
394
-
395
- if text is None and images is None and audios is None:
396
  raise ValueError("Provide at least one of `text`, `images`, or `audios`.")
397
 
398
  # Determine final return_tensors strategy
399
- # Priority: 1. Explicit return_tensors, 2. from text_kwargs in **kwargs, 3. Default (PT)
400
  final_rt = return_tensors
 
 
 
401
  merged_call_kwargs = self._merge_kwargs(
402
- Gemma3DummyProcessorKwargs, # Using dummy for _defaults structure
403
- self.tokenizer.init_kwargs if hasattr(self.tokenizer, 'init_kwargs') else {},
404
  **kwargs
405
  )
406
-
407
- if final_rt is None: # If not passed directly to __call__
 
 
 
408
  final_rt = merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", TensorType.PYTORCH)
409
- else: # If passed directly, remove from text_kwargs to avoid conflict
410
  merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", None)
411
 
412
- if text is None:
 
 
413
  num_samples = 0
414
  if images is not None:
415
- _images_list = images if isinstance(images, list) and (
416
- not images or not isinstance(images[0], (int, float))) else [images]
417
  num_samples = len(_images_list)
418
  elif audios is not None:
419
  _audios_list = audios if isinstance(audios, list) else [audios]
420
  num_samples = len(_audios_list)
421
- text = [""] * num_samples if num_samples > 0 else [""]
422
 
423
  if isinstance(text, str):
424
  text = [text]
425
- if not (isinstance(text, list) and all(isinstance(t, str) for t in text)):
426
- raise ValueError("Input `text` must be a string or a list of strings.")
427
 
 
428
  image_features_dict = {}
429
  if images is not None and self.image_processor is not None:
430
- logger.info("Processing images...")
431
- # image_features_dict = self.image_processor(images, return_tensors=None, **merged_call_kwargs.get("images_kwargs", {}))
432
- # Simplified: Actual image token replacement logic for `text` would go here.
433
- # text = self._handle_image_text_replacement(text, images, image_features_dict)
434
- pass
435
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  audio_features_dict = {}
437
  if audios is not None and self.audio_processor is not None:
438
- logger.info("Processing audio...")
439
  audio_call_kwargs = merged_call_kwargs.get("audio_kwargs", {})
440
- if sampling_rate:
441
- audio_call_kwargs["sampling_rate"] = sampling_rate
 
 
 
 
442
 
443
- # audio_processor.__call__ returns BatchFeature, we need its .data attribute
444
- audio_features_batch_feature = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
445
- audio_features_dict = audio_features_batch_feature.data # Get the dict
 
446
 
447
- new_text_with_audio = []
448
- # audio_attention_mask shape is (B, Max_T_mel)
449
- audio_sample_mel_lengths = to_py_obj(audio_features_dict["audio_attention_mask"].sum(axis=-1))
450
 
451
  for i, prompt in enumerate(text):
452
- num_soft_tokens = self._compute_audio_prompt_token_count(audio_sample_mel_lengths[i])
453
- audio_token_sequence_str = self.audio_soft_token_str * num_soft_tokens
454
-
455
- if self.audio_placeholder_token in prompt:
456
- prompt = prompt.replace(self.audio_placeholder_token, audio_token_sequence_str, 1)
457
- else:
458
- prompt += audio_token_sequence_str
459
- new_text_with_audio.append(prompt)
460
- text = new_text_with_audio
461
-
462
- logger.info("Tokenizing text...")
463
- text_call_kwargs = merged_call_kwargs.get("text_kwargs", {})
464
- text_features_dict = self.tokenizer(text, return_tensors=None, **text_call_kwargs)
465
-
466
- input_ids_list = text_features_dict["input_ids"]
467
- if not isinstance(input_ids_list, list) or not (input_ids_list and isinstance(input_ids_list[0], list)):
468
- if isinstance(input_ids_list, (torch.Tensor, np.ndarray)):
469
- input_ids_list = to_py_obj(input_ids_list) # Convert tensor/np.array to list of lists
470
- elif isinstance(input_ids_list, list) and (not input_ids_list or isinstance(input_ids_list[0], int)):
471
- input_ids_list = [input_ids_list]
472
-
473
- token_type_ids_list = []
474
- for ids_sample in input_ids_list:
475
- types = [0] * len(ids_sample)
476
- for j, token_id in enumerate(ids_sample):
477
- if self.image_token_id is not None and token_id == self.image_token_id:
478
- types[j] = 1
479
- elif token_id == self.audio_soft_token_id:
480
- types[j] = 2
481
- token_type_ids_list.append(types)
482
- text_features_dict["token_type_ids"] = token_type_ids_list
483
-
484
- combined_features = {**text_features_dict}
485
- if image_features_dict:
486
- combined_features.update(image_features_dict)
487
- if audio_features_dict:
488
- combined_features.update(audio_features_dict)
489
-
490
- return BatchFeature(data=combined_features, tensor_type=final_rt)
491
 
492
  def batch_decode(self, *args, **kwargs):
493
  return self.tokenizer.batch_decode(*args, **kwargs)
@@ -496,12 +592,17 @@ class Gemma3OmniProcessor(ProcessorMixin):
496
  return self.tokenizer.decode(*args, **kwargs)
497
 
498
  @property
499
- def model_input_names(self) -> List[str]:
500
- input_names = set(self.tokenizer.model_input_names + ["token_type_ids"])
501
- if self.image_processor is not None:
502
- input_names.update(self.image_processor.model_input_names)
503
- if self.audio_processor is not None:
504
- # From Gemma3AudioFeatureExtractor's output_data keys
505
- input_names.update(["audio_values", "audio_attention_mask"])
506
- # "audio_token_calc_sizes" is internal to processor, not model.
507
- return list(input_names)
 
 
 
 
 
 
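The updated file follows. For orientation, the new extractor's log-mel front end frames the waveform, windows each frame, takes an rfft power spectrum, projects it through the mel filterbank, and takes the log of the clipped result. Below is a minimal standalone sketch of that pipeline using the module's defaults (16 kHz, n_fft=512, win_length=400, hop_length=160, 80 mels); the function name and the pre-built filterbank argument are illustrative only, not part of the module's API.

import numpy as np

def log_mel_sketch(wav, filterbank, win_length=400, hop_length=160, n_fft=512):
    # Frame the (already padded) waveform with a sliding window, matching the as_strided framing below.
    num_frames = 1 + (len(wav) - win_length) // hop_length
    frames = np.stack([wav[i * hop_length : i * hop_length + win_length] for i in range(num_frames)])
    frames = frames * np.hamming(win_length)            # window each frame
    power = np.abs(np.fft.rfft(frames, n=n_fft)) ** 2   # power spectrum, shape (num_frames, n_fft // 2 + 1)
    mel = power @ filterbank                            # filterbank assumed shape (n_fft // 2 + 1, n_mels)
    return np.log(np.clip(mel, 1e-5, None)).astype(np.float32)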
1
  import re
2
+ from typing import List, Optional, Union, Dict, Any
3
 
4
  import math
5
  import numpy as np
6
  import scipy.signal
7
  import torch
8
  from torch.nn.utils.rnn import pad_sequence
9
+ # Keeping the original AudioInput type from transformers for minimal change to the existing code
10
+ from transformers.audio_utils import AudioInput # type: ignore
11
  from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
12
  from transformers.feature_extraction_utils import BatchFeature
13
+ from transformers.image_utils import make_nested_list_of_images
14
+ from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs # Removed Unpack as it's not standard
15
  from transformers.utils import TensorType, to_py_obj, logging
 
 
16
 
17
+ # Constants
18
  DEFAULT_SAMPLING_RATE = 16000
19
  DEFAULT_N_FFT = 512
20
+ DEFAULT_WIN_LENGTH = 400
21
+ DEFAULT_HOP_LENGTH = 160
22
  DEFAULT_N_MELS = 80
23
  DEFAULT_COMPRESSION_RATE = 4
24
  DEFAULT_QFORMER_RATE = 2
 
26
  IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
27
  AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
28
  DEFAULT_MAX_LENGTH = 16384
29
+ LOG_MEL_CLIP_EPSILON = 1e-5 # Epsilon for log mel clipping
30
 
31
  logger = logging.get_logger(__name__)
32
 
33
 
 
 
34
  def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
35
  fmax: Optional[float] = None) -> np.ndarray:
36
+ """Create Mel filterbank for audio processing. (User's version)"""
37
+ fmax = fmax or sampling_rate / 2.0 # Ensure float division
38
 
39
+ # User's Mel scale formula
40
+ def hz_to_mel(f: float) -> float:
41
+ return 1127.0 * math.log(1 + f / 700.0)
42
+
43
+ def mel_to_hz(mel: float) -> float: # Added for completeness if needed
44
+ return 700.0 * (math.exp(mel / 1127.0) - 1)
45
 
46
 
47
  mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
48
+ # freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1) # Original
49
+ freq_points = mel_to_hz(mel_points) # Using the inverse function
50
 
51
+ # Clip freq_points to be within [0, sampling_rate/2]
52
  freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
53
+
54
+ bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)
55
+ # Ensure bins are within valid range for rfft output indices
56
  bins = np.clip(bins, 0, n_fft // 2)
57
 
 
 
 
58
 
59
+ filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
60
+ for m_idx in range(n_mels): # Loop from 0 to n_mels-1 to fill filterbank[m_idx]
61
+ # Bins for (m_idx)-th filter are bins[m_idx], bins[m_idx+1], bins[m_idx+2]
62
+ left, center, right = bins[m_idx], bins[m_idx + 1], bins[m_idx + 2]
63
+
64
+ # Original logic for applying triangular filter
65
+ # Ensure no division by zero if points coincide
66
  if center > left:
67
+ filterbank[m_idx, left:center] = (np.arange(left, center) - left) / (center - left)
68
  if right > center:
69
+ filterbank[m_idx, center:right] = (right - np.arange(center, right)) / (right - center)
70
+ # If left=center or center=right, the corresponding slope is zero, which is implicitly handled.
71
+ # Ensure peak is 1.0 if center is a valid point within a slope.
72
+ if left <= center < right and center > left : # If center forms a peak of a valid triangle part
73
+ filterbank[m_idx, center] = 1.0
74
+
75
 
76
  return filterbank
77
 
78
 
 
 
79
  class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
80
  model_input_names = ["audio_values", "audio_attention_mask"]
81
 
 
84
  compression_rate: int = DEFAULT_COMPRESSION_RATE,
85
  qformer_rate: int = DEFAULT_QFORMER_RATE,
86
  feat_stride: int = DEFAULT_FEAT_STRIDE,
87
+ sampling_rate: int = DEFAULT_SAMPLING_RATE, # Target sampling rate
88
  n_fft: int = DEFAULT_N_FFT,
89
  win_length: Optional[int] = None,
90
  hop_length: Optional[int] = None,
91
  n_mels: int = DEFAULT_N_MELS,
92
+ f_min: float = 0.0, # Added for mel filterbank control
93
+ f_max: Optional[float] = None, # Added for mel filterbank control
94
+ padding_value: float = 0.0, # Explicitly define for clarity
95
  **kwargs
96
  ):
97
+ _win_length = win_length if win_length is not None else n_fft
98
+ _hop_length = hop_length if hop_length is not None else _win_length // 4
 
99
 
100
+ # feature_size is n_mels for the superclass
101
  super().__init__(
102
  feature_size=n_mels,
103
+ sampling_rate=sampling_rate, # This sets self.sampling_rate
104
+ padding_value=padding_value,
105
  **kwargs
106
  )
107
+
108
  self.compression_rate = compression_rate
109
  self.qformer_rate = qformer_rate
110
  self.feat_stride = feat_stride
111
+ # self.sampling_rate is now set by super()
112
+
113
  self.n_fft = n_fft
114
+ self.win_length = _win_length
115
+ self.hop_length = _hop_length
116
  self.n_mels = n_mels
117
  self.f_min = f_min
118
+ self.f_max = f_max # Will be sampling_rate/2 if None in create_mel_filterbank call
119
 
120
  if self.win_length > self.n_fft:
121
  logger.warning(
122
  f"win_length ({self.win_length}) is greater than n_fft ({self.n_fft}). "
123
+ "Window will be applied, then data will be zero-padded/truncated to n_fft by np.fft.rfft."
124
  )
125
+ self.window = np.hamming(self.win_length).astype(np.float32) # Or scipy.signal.get_window("hann", self.win_length)
126
  self.mel_filterbank = create_mel_filterbank(
127
  self.sampling_rate, self.n_fft, self.n_mels, fmin=self.f_min, fmax=self.f_max
128
+ ).T # Transpose for dot product: (n_fft // 2 + 1, n_mels)
129
+
130
 
131
  def __call__(
132
  self,
133
+ audios: Union[AudioInput, List[AudioInput]], # Accept single or list
134
+ sampling_rate: Optional[int] = None, # To specify SR if audios are raw arrays
135
  return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
136
  ) -> BatchFeature:
137
+
138
  if not isinstance(audios, list):
139
  audios = [audios]
140
 
141
+ processed_mels: List[torch.Tensor] = []
142
  actual_mel_lengths: List[int] = []
143
+
144
+ # Kept from user's code - their purpose might be for token calculation downstream
145
+ sizes_for_embed_length: List[torch.Tensor] = []
146
+ frames_scaled_by_feat_stride: List[int] = []
147
 
148
+ for audio_item in audios:
149
+ current_wav: np.ndarray
150
  source_sr: int
151
 
152
+ if isinstance(audio_item, tuple) and len(audio_item) == 2 and isinstance(audio_item[1], int):
153
+ current_wav, source_sr = audio_item
154
+ current_wav = np.asarray(current_wav, dtype=np.float32) # Ensure float32 numpy array
155
+ elif isinstance(audio_item, (np.ndarray, list)):
156
+ current_wav = np.asarray(audio_item, dtype=np.float32)
157
  if sampling_rate is None:
158
  raise ValueError(
159
+ "sampling_rate must be provided if audio inputs are raw numpy arrays or lists without sr."
160
  )
161
  source_sr = sampling_rate
162
+ # Add more robust loading for paths/bytes if transformers.audio_utils.load_audio is permissible
163
+ # Example:
164
+ # elif isinstance(audio_item, (str, bytes, Path)): # Path would need to be imported from pathlib
165
+ # current_wav, sr_dict = load_audio(audio_item) # Uses librosa or soundfile
166
+ # source_sr = sr_dict["sampling_rate"]
167
+ # current_wav = current_wav.astype(np.float32)
168
  else:
169
  raise TypeError(
170
+ f"Unsupported audio input type: {type(audio_item)}. "
171
+ "Expected np.ndarray, list of floats, or Tuple[np.ndarray, int]."
172
  )
173
+
174
+ processed_wav_array = self._preprocess_audio(current_wav, source_sr)
175
+ mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav_array) # Shape: (T_mel, N_Mels)
176
+
177
+ feature_tensor = torch.from_numpy(mel_spectrogram) # Already float32
178
+ processed_mels.append(feature_tensor)
179
+ actual_mel_lengths.append(feature_tensor.shape[0]) # T_mel for this item
180
+
181
+ # User's original logic for 'sizes' and 'frames'
182
+ sizes_for_embed_length.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
183
+ frames_scaled_by_feat_stride.append(feature_tensor.shape[0] * self.feat_stride)
184
+
185
+ # Pad the mel spectrograms to form a batch
186
+ audio_embeds = pad_sequence(processed_mels, batch_first=True, padding_value=self.padding_value)
187
+ # audio_embeds shape: (Batch, Max_T_mel, N_Mels)
188
+
189
+ # Create attention mask corresponding to the actual lengths of mel spectrograms
190
+ max_t_mel_in_batch = audio_embeds.shape[1]
191
+ current_device = audio_embeds.device # Get device from padded tensor if using PyTorch tensors earlier
192
+
193
+ # Create attention mask directly based on actual_mel_lengths
194
+ attention_mask = torch.zeros(len(audios), max_t_mel_in_batch, dtype=torch.bool, device=current_device)
195
+ for i, length in enumerate(actual_mel_lengths):
196
+ attention_mask[i, :length] = True
197
+
198
  output_data = {
199
+ "audio_values": audio_embeds,
200
+ "audio_attention_mask": attention_mask # Correctly shaped mask for audio_values
201
  }
202
 
203
+ # Include user's 'sizes' if they are needed downstream
204
+ if sizes_for_embed_length:
205
+ output_data["audio_values_sizes"] = torch.stack(sizes_for_embed_length)
206
+ # Note: 'frames_scaled_by_feat_stride' is a list of ints, handle conversion if needed in BatchFeature
207
 
208
  return BatchFeature(data=output_data, tensor_type=return_tensors)
209
 
210
  def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
211
+ # Ensure wav is float32
212
  if wav.dtype not in [np.float32, np.float64]:
213
  if np.issubdtype(wav.dtype, np.integer):
214
+ max_val = np.iinfo(wav.dtype).max if wav.size > 0 else 1.0 # Avoid error on empty array
215
  wav = wav.astype(np.float32) / max_val
216
  else:
217
  wav = wav.astype(np.float32)
218
+ elif wav.dtype == np.float64:
219
+ wav = wav.astype(np.float32)
220
 
221
  if wav.ndim > 1:
222
+ wav = wav.mean(axis=0) # Convert to mono
223
+
224
  if source_sr != self.sampling_rate:
225
+ logger.info(f"Resampling audio from {source_sr} Hz to {self.sampling_rate} Hz.")
226
+ # Calculate integer up/down factors for resample_poly
227
+ common_divisor = math.gcd(self.sampling_rate, source_sr)
228
+ up_factor = self.sampling_rate // common_divisor
229
+ down_factor = source_sr // common_divisor
230
+ if up_factor != down_factor : # Only if actual resampling is needed
231
  wav = scipy.signal.resample_poly(wav, up=up_factor, down=down_factor)
232
+
233
+ # Normalize amplitude to roughly [-1, 1]
234
+ max_abs_val = np.abs(wav).max()
235
+ if max_abs_val > 1e-7: # Avoid division by zero or tiny numbers
236
+ wav = wav / max_abs_val
237
  return wav
238
 
239
  def _compute_log_mel_spectrogram(self, wav: np.ndarray) -> np.ndarray:
240
  if len(wav) < self.win_length:
241
+ # Pad if audio is shorter than one window
242
  padding = self.win_length - len(wav)
243
  wav = np.pad(wav, (0, padding), mode='constant', constant_values=0.0)
244
 
245
+ # Calculate number of frames
246
+ # This calculation ensures at least one frame if len(wav) == self.win_length
247
+ if len(wav) >= self.win_length:
248
+ num_frames = 1 + (len(wav) - self.win_length) // self.hop_length
249
+ else: # Should be covered by the padding above, but kept as a safeguard
250
+ num_frames = 0
251
+
252
  if num_frames <= 0:
253
+ logger.warning(f"Audio is too short (length {len(wav)}) to produce any frames "
254
+ f"with win_length {self.win_length} and hop_length {self.hop_length}. "
255
+ "Returning empty mel spectrogram.")
256
  return np.zeros((0, self.n_mels), dtype=np.float32)
257
 
258
+ # Framing using stride_tricks
259
+ strides = wav.strides[0]
260
+ frames_view = np.lib.stride_tricks.as_strided(
261
  wav,
262
  shape=(num_frames, self.win_length),
263
+ strides=(strides * self.hop_length, strides),
264
  writeable=False
265
  )
266
+ frames_data = frames_view.copy() # Important: copy after as_strided if modifying
267
+
268
+ frames_data *= self.window # Apply window in-place on the copy
269
 
270
+ # Compute STFT (rfft for real inputs)
271
+ # n_fft determines zero-padding or truncation for FFT input from each frame
272
+ spectrum = np.fft.rfft(frames_data, n=self.n_fft, axis=-1).astype(np.complex64)
273
+ power = np.abs(spectrum)**2
274
+
275
+ mel_spectrogram = np.dot(power, self.mel_filterbank) # (num_frames, n_mels)
276
+
277
+ # Clip and take log
278
+ mel_spectrogram = np.clip(mel_spectrogram, LOG_MEL_CLIP_EPSILON, None) # Use defined epsilon
279
  log_mel_spectrogram = np.log(mel_spectrogram)
280
+
281
  return log_mel_spectrogram.astype(np.float32)
282
 
283
  def _calculate_embed_length(self, frame_count: int) -> int:
284
+ # User's original function
285
  compressed = math.ceil(frame_count / self.compression_rate)
286
  return math.ceil(compressed / self.qformer_rate)
287
 
288
 
289
+ class Gemma3ImagesKwargs(ImagesKwargs): # User's definition
290
+ do_pan_and_scan: Optional[bool]
291
+ pan_and_scan_min_crop_size: Optional[int]
292
+ pan_and_scan_max_num_crops: Optional[int]
293
+ pan_and_scan_min_ratio_to_activate: Optional[float]
294
+ do_convert_rgb: Optional[bool]
295
+
296
+
297
+ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False): # User's definition
298
  images_kwargs: Dict[str, Any]
299
  audio_kwargs: Dict[str, Any]
300
+ # Added text_kwargs as it's commonly part of such structures
301
+ text_kwargs: Optional[Dict[str, Any]] = None
302
  _defaults = {
303
  "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
304
  "images_kwargs": {},
 
308
 
309
  class Gemma3OmniProcessor(ProcessorMixin):
310
  attributes = ["image_processor", "audio_processor", "tokenizer"]
311
+ valid_kwargs = ["chat_template", "image_seq_length"] # From user's code
 
 
 
312
 
313
+ # --- FIXED CLASS ATTRIBUTES ---
314
+ image_processor_class = "AutoImageProcessor" # As in user's original code
315
+ audio_processor_class = Gemma3AudioFeatureExtractor # Corrected to custom class
316
+ tokenizer_class = "AutoTokenizer" # As in user's original code
317
 
318
  def __init__(
319
  self,
320
+ image_processor=None, # Allow None; ProcessorMixin/from_pretrained can load it via the *_class attributes
321
+ audio_processor=None, # Allow None or instance
322
+ tokenizer=None, # Allow None or instance
323
  chat_template=None,
324
  image_seq_length: int = 256,
325
+ **kwargs
326
  ):
327
+ # The ProcessorMixin's __init__ will handle instantiating these if they are None,
328
+ # using the respective *_class attributes.
329
+ # If specific instances are passed, they will be used.
330
+
331
+ # Retaining user's specific logic for setting attributes if needed,
332
+ # though much of this might be handled by super() or better placed after super()
333
+ self.image_seq_length = image_seq_length
334
+
335
+ # These tokenizer-dependent attributes should be set *after* super().__init__
336
+ # ensures self.tokenizer is populated, or if tokenizer is passed directly.
337
+ # If tokenizer is None and loaded by super(), these need to be set post-super().
338
+ # Assuming tokenizer is passed as an instantiated object for this snippet for now.
339
+ if tokenizer is None:
340
+ # This is a basic placeholder; HF's from_pretrained mechanism is more robust for loading
341
+ # For now, we'll assume if tokenizer is None, super() handles it or it's an error later.
342
+ pass
343
+ else: # Tokenizer was provided
344
+ self.image_token_id = getattr(tokenizer, "image_token_id", None) # More robust with getattr
345
+ self.boi_token = getattr(tokenizer, "boi_token", "<|image|>") # Defaulting if not present
346
+ self.image_token = getattr(tokenizer, "image_token", "<|image|>")
347
+ self.eoi_token = getattr(tokenizer, "eoi_token", "") # Added eoi_token as it was used
348
+
349
+ self.audio_token = "<audio_soft_token>" # User's definition
350
+ # self.expected_audio_token_id = 262143 # User's reference
351
+ # The existence of this token should be ensured when the tokenizer is prepared/saved.
352
+ self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
353
+ # if self.audio_token_id != self.expected_audio_token_id: # User's warning
354
+ # logger.warning(...)
355
+ if self.audio_token_id == tokenizer.unk_token_id:
356
+ logger.warning(f"Audio token '{self.audio_token}' not found in tokenizer, maps to UNK. Ensure it's added.")
357
+
358
+
359
+ self.full_image_sequence = f"\n\n{self.boi_token}{''.join([self.image_token] * image_seq_length)}{self.eoi_token if hasattr(tokenizer, 'eoi_token') else ''}\n\n"
360
+
361
+
362
+ # These seem specific to this processor's logic for determining audio token sequence length
363
+ # It's better to initialize them here.
364
+ self.audio_prompt_compression_rate = kwargs.pop("audio_prompt_compression_rate", 8)
365
+ self.audio_prompt_qformer_rate = kwargs.pop("audio_prompt_qformer_rate", 1)
366
+ self.audio_prompt_feat_stride = kwargs.pop("audio_prompt_feat_stride", 1)
367
+
368
 
369
  super().__init__(
370
  image_processor=image_processor,
371
  audio_processor=audio_processor,
372
  tokenizer=tokenizer,
373
  chat_template=chat_template,
374
+ **kwargs # Pass remaining kwargs to super
375
  )
376
+
377
+ # If tokenizer was loaded by super(), set tokenizer-dependent attributes now
378
+ if not hasattr(self, 'image_token_id') and self.tokenizer is not None:
379
+ self.image_token_id = getattr(self.tokenizer, "image_token_id", self.tokenizer.unk_token_id if hasattr(self.tokenizer, "unk_token_id") else None)
380
+ self.boi_token = getattr(self.tokenizer, "boi_token", "<|image|>")
381
+ self.image_token = getattr(self.tokenizer, "image_token", "<|image|>")
382
+ self.eoi_token = getattr(self.tokenizer, "eoi_token", "")
383
+ self.audio_token = "<audio_soft_token>"
384
+ self.audio_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_token)
385
+ if self.audio_token_id == self.tokenizer.unk_token_id:
386
+ logger.warning(f"Audio token '{self.audio_token}' not found in tokenizer (post-super), maps to UNK. Ensure it's added.")
387
+ self.full_image_sequence = f"\n\n{self.boi_token}{''.join([self.image_token] * self.image_seq_length)}{self.eoi_token}\n\n"
388
+
389
+
390
+ def _merge_kwargs(self, ModelProcessorKwargs, tokenizer_init_kwargs, **kwargs_from_call):
391
+ # User's original _merge_kwargs logic
392
+ default_kwargs = {}
393
+ # Ensure ModelProcessorKwargs._defaults exists and is a dict
394
+ _defaults_attr = getattr(ModelProcessorKwargs, "_defaults", {})
395
+ if not isinstance(_defaults_attr, dict):
396
+ _defaults_attr = {}
397
+
398
+ for modality in _defaults_attr:
399
+ default_kwargs[modality] = _defaults_attr.get(modality, {}).copy()
400
+
401
+ for modality_key_in_call, modality_kwargs_in_call in kwargs_from_call.items():
402
+ if modality_key_in_call in default_kwargs:
403
+ if isinstance(modality_kwargs_in_call, dict):
404
+ default_kwargs[modality_key_in_call].update(modality_kwargs_in_call)
405
+ elif isinstance(modality_kwargs_in_call, dict): # New modality not in defaults
406
+ default_kwargs[modality_key_in_call] = modality_kwargs_in_call.copy()
407
+
408
+
409
+ # Update defaults with tokenizer init kwargs (original logic)
410
+ for modality_key in default_kwargs: # Iterate over current keys in default_kwargs
411
+ modality_dict = default_kwargs[modality_key]
412
+ if isinstance(modality_dict, dict): # Ensure it's a dict before trying to access keys
413
+ for key_in_mod_dict in list(modality_dict.keys()): # Iterate over copy of keys
414
+ if key_in_mod_dict in tokenizer_init_kwargs:
415
+ value = (
416
+ getattr(self.tokenizer, key_in_mod_dict)
417
+ if hasattr(self.tokenizer, key_in_mod_dict)
418
+ else tokenizer_init_kwargs[key_in_mod_dict]
419
+ )
420
+ modality_dict[key_in_mod_dict] = value
421
+
422
+ # Ensure text_kwargs processing (original logic)
423
+ if "text_kwargs" not in default_kwargs: # Ensure text_kwargs exists
424
+ default_kwargs["text_kwargs"] = {}
425
+ default_kwargs["text_kwargs"]["truncation"] = default_kwargs["text_kwargs"].get("truncation", False)
426
+ default_kwargs["text_kwargs"]["max_length"] = default_kwargs["text_kwargs"].get("max_length", DEFAULT_MAX_LENGTH)
427
428
 
429
+ return default_kwargs
430
 
431
+ def _compute_audio_embed_size(self, audio_mel_frames: int) -> int:
432
+ # Using processor's own rates for this calculation
433
+ result = math.ceil((audio_mel_frames * self.audio_prompt_feat_stride) / self.audio_prompt_compression_rate)
434
+ return math.ceil(result / self.audio_prompt_qformer_rate)
435
 
436
  def __call__(
437
  self,
438
+ images=None,
439
+ text: Union[str, List[str]] = None, # text is optional but often primary
440
+ # videos=None, # Removed 'videos' as it's not handled
441
  audios: Optional[Union[AudioInput, List[AudioInput]]] = None,
442
+ sampling_rate: Optional[int] = None, # For audio_processor if audios are raw arrays
443
  return_tensors: Optional[Union[str, TensorType]] = None,
444
+ **kwargs: Any # Replaced Unpack for broader compatibility here
445
  ) -> BatchFeature:
446
+ if text is None and images is None and audios is None: # Added audios to check
 
447
  raise ValueError("Provide at least one of `text`, `images`, or `audios`.")
448
 
449
  # Determine final return_tensors strategy
 
450
  final_rt = return_tensors
451
+ # Using Gemma3ProcessorKwargs as the class that holds _defaults structure
452
+ # This call to _merge_kwargs primarily populates kwargs for each modality if passed in __call__
453
+ # e.g. if user calls proc(..., text_kwargs={...})
454
  merged_call_kwargs = self._merge_kwargs(
455
+ Gemma3ProcessorKwargs,
456
+ self.tokenizer.init_kwargs if hasattr(self.tokenizer, "init_kwargs") else {},
457
  **kwargs
458
  )
459
+
460
+ # If return_tensors wasn't passed to __call__, try to get it from merged text_kwargs
461
+ # and remove it from there to avoid passing it twice to tokenizer.
462
+ # Default to PYTORCH if still None.
463
+ if final_rt is None:
464
  final_rt = merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", TensorType.PYTORCH)
465
+ else:
466
  merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", None)
467
 
468
+
469
+ # Standardize text input
470
+ if text is None: # If no text given, create dummy text based on other modalities
471
  num_samples = 0
472
  if images is not None:
473
+ _images_list = images if isinstance(images, list) and (not images or not isinstance(images[0], (int,float))) else [images]
 
474
  num_samples = len(_images_list)
475
  elif audios is not None:
476
  _audios_list = audios if isinstance(audios, list) else [audios]
477
  num_samples = len(_audios_list)
478
+ text = [""] * num_samples if num_samples > 0 else [""] # Fallback for safety
479
 
480
  if isinstance(text, str):
481
  text = [text]
482
+ elif not (isinstance(text, list) and all(isinstance(t, str) for t in text)):
483
+ raise ValueError("Input text must be a string or list of strings")
484
 
485
+ # --- Image Processing ---
486
  image_features_dict = {}
487
  if images is not None and self.image_processor is not None:
488
+ batched_images = make_nested_list_of_images(images) # HF utility
489
+ # Assuming image_processor returns a dict or BatchFeature. If BatchFeature, get .data
490
+ _img_proc_output = self.image_processor(batched_images, return_tensors=None, **merged_call_kwargs.get("images_kwargs", {}))
491
+ image_features_dict = _img_proc_output.data if isinstance(_img_proc_output, BatchFeature) else _img_proc_output
492
+
493
+
494
+ if len(batched_images) != len(text): # Validate batch consistency
495
+ raise ValueError(f"Inconsistent batch sizes: {len(batched_images)} images, {len(text)} texts")
496
+
497
+ # User's original image token replacement logic (complex, depends on num_crops etc from image_processor output)
498
+ # This part needs to be carefully adapted based on actual image_processor output structure
499
+ # For now, a simplified placeholder for the concept:
500
+ if "num_crops" in image_features_dict: # Example check
501
+ num_crops_list = to_py_obj(image_features_dict.pop("num_crops"))
502
+ # ... user's original logic for text modification with self.full_image_sequence ...
503
+ # This was: text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
504
+ # Need to adapt it if multiple images/crops per text sample.
505
+ # For simplicity, assuming one image sequence per text for now if an image is present.
506
+ temp_text = []
507
+ for i, prompt in enumerate(text):
508
+ if i < len(batched_images): # if this text sample has corresponding images
509
+ # Replace first boi_token or append if not found
510
+ if self.boi_token in prompt:
511
+ temp_text.append(prompt.replace(self.boi_token, self.full_image_sequence, 1))
512
+ else:
513
+ temp_text.append(prompt + self.full_image_sequence)
514
+ else:
515
+ temp_text.append(prompt)
516
+ text = temp_text
517
+
518
+
519
+ # --- Audio Processing ---
520
  audio_features_dict = {}
521
  if audios is not None and self.audio_processor is not None:
 
522
  audio_call_kwargs = merged_call_kwargs.get("audio_kwargs", {})
523
+ if sampling_rate is not None:
524
+ audio_call_kwargs["sampling_rate"] = sampling_rate
525
+
526
+ # audio_processor.__call__ returns BatchFeature, get its .data attribute for the dict
527
+ _audio_proc_output = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
528
+ audio_features_dict = _audio_proc_output.data
529
 
530
+ # Modify text to include audio soft tokens based on actual mel lengths
531
+ new_text_with_audio_tokens = []
532
+ # audio_attention_mask is (B, Max_T_mel)
533
+ actual_mel_frames_per_sample = to_py_obj(audio_features_dict["audio_attention_mask"].sum(axis=-1))
534
 
535
+ if len(actual_mel_frames_per_sample) != len(text):
536
+ raise ValueError(f"Inconsistent batch sizes for audio and text: {len(actual_mel_frames_per_sample)} audio samples, {len(text)} texts.")
 
537
 
538
  for i, prompt in enumerate(text):
539
+ num_soft_tokens = self._compute_audio_embed_size(actual_mel_frames_per_sample[i])
540
+ audio_token_sequence_str = self.audio_token * num_soft_tokens # Repeat the soft token string (this class defines self.audio_token)
541
+
542
+ # Replace a placeholder or append
543
+ placeholder = getattr(self, "audio_placeholder_token", "<|audio|>") # Use defined placeholder
544
+ if placeholder in prompt:
545
+ prompt_with_audio = prompt.replace(placeholder, audio_token_sequence_str, 1)
546
+ else:
547
+ prompt_with_audio = prompt + audio_token_sequence_str
548
+ new_text_with_audio_tokens.append(prompt_with_audio)
549
+ text = new_text_with_audio_tokens
550
+
551
+ # --- Text Tokenization ---
552
+ text_tokenizer_kwargs = merged_call_kwargs.get("text_kwargs", {})
553
+ # Tokenize the (potentially modified) text, request lists/np arrays
554
+ text_features_dict = self.tokenizer(text=text, return_tensors=None, **text_tokenizer_kwargs)
555
+
556
+ # Create token_type_ids
557
+ input_ids_list_of_lists = text_features_dict["input_ids"]
558
+ # Ensure it's a list of lists
559
+ if not (isinstance(input_ids_list_of_lists, list) and \
560
+ input_ids_list_of_lists and \
561
+ isinstance(input_ids_list_of_lists[0], list)):
562
+ if isinstance(input_ids_list_of_lists, (torch.Tensor, np.ndarray)):
563
+ input_ids_list_of_lists = to_py_obj(input_ids_list_of_lists)
564
+ elif isinstance(input_ids_list_of_lists, list) and \
565
+ (not input_ids_list_of_lists or isinstance(input_ids_list_of_lists[0], int)):
566
+ input_ids_list_of_lists = [input_ids_list_of_lists] # Batch of 1
567
+
568
+ mm_token_type_ids_list = []
569
+ for ids_sample in input_ids_list_of_lists:
570
+ type_ids_sample = [0] * len(ids_sample) # Default type 0 (text)
571
+ for idx, token_id_val in enumerate(ids_sample):
572
+ if self.image_token_id is not None and token_id_val == self.image_token_id:
573
+ type_ids_sample[idx] = 1 # Image token type
574
+ elif token_id_val == self.audio_token_id: # Compare with ID of <audio_soft_token>
575
+ type_ids_sample[idx] = 2 # Audio token type
576
+ mm_token_type_ids_list.append(type_ids_sample)
577
+ text_features_dict["token_type_ids"] = mm_token_type_ids_list
578
+
579
+ # Combine all features
580
+ final_batch_data = {**text_features_dict}
581
+ if image_features_dict:
582
+ final_batch_data.update(image_features_dict)
583
+ if audio_features_dict:
584
+ final_batch_data.update(audio_features_dict)
585
+
586
+ return BatchFeature(data=final_batch_data, tensor_type=final_rt) # Use determined final_rt
587
 
588
  def batch_decode(self, *args, **kwargs):
589
  return self.tokenizer.batch_decode(*args, **kwargs)
 
592
  return self.tokenizer.decode(*args, **kwargs)
593
 
594
  @property
595
+ def model_input_names(self):
596
+ tokenizer_inputs = self.tokenizer.model_input_names + ["token_type_ids"]
597
+ image_processor_inputs = []
598
+ if self.image_processor is not None: # Check if image_processor exists
599
+ image_processor_inputs = self.image_processor.model_input_names
600
+
601
+ audio_processor_inputs = []
602
+ if self.audio_processor is not None: # Check if audio_processor exists
603
+ # These are the keys Gemma3AudioFeatureExtractor puts in its output BatchFeature.data
604
+ audio_processor_inputs = ["audio_values", "audio_attention_mask"]
605
+ # "audio_values_sizes" is also emitted by Gemma3AudioFeatureExtractor (it was "audio_token_calc_sizes"
606
+ # in the previous version); it is processor-internal, so only add it here if the model consumes it.
607
+
608
+ return list(dict.fromkeys(tokenizer_inputs + image_processor_inputs + audio_processor_inputs))
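
A usage sketch of the updated feature extractor, assuming this file is importable as processing_gemma3_omni (adjust the import path to however the module is actually loaded); the output keys and shapes follow from the __call__ implementation above, and the dummy waveform is illustrative only.

import numpy as np
from processing_gemma3_omni import Gemma3AudioFeatureExtractor  # assumed import path

extractor = Gemma3AudioFeatureExtractor()           # defaults: 16 kHz, n_fft=512, 80 mels
wav = np.random.randn(32000).astype(np.float32)     # ~2 s of dummy audio at 16 kHz
features = extractor(audios=[wav], sampling_rate=16000)

print(features["audio_values"].shape)           # (1, T_mel, 80) padded log-mel frames
print(features["audio_attention_mask"].shape)   # (1, T_mel) boolean mask of real frames
print(features["audio_values_sizes"])           # per-sample embed lengths used for soft-token counts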