voidful committed
Commit eaad0f5 · verified · 1 parent: 5fc5a97

Update processing_gemma3_omni.py

Files changed (1)
  1. processing_gemma3_omni.py +444 -214
processing_gemma3_omni.py CHANGED
@@ -6,11 +6,11 @@ import numpy as np
 import scipy.signal
 import torch
 from torch.nn.utils.rnn import pad_sequence
-from transformers.audio_utils import AudioInput
+from transformers.audio_utils import AudioInput  # type: ignore
 from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
 from transformers.feature_extraction_utils import BatchFeature
-from transformers.image_utils import make_nested_list_of_images
-from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs, Unpack
+from transformers.image_utils import make_nested_list_of_images  # If image processing is used
+from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs
 from transformers.utils import TensorType, to_py_obj, logging

 # Constants
@@ -19,12 +19,13 @@ DEFAULT_N_FFT = 512
 DEFAULT_WIN_LENGTH = 400
 DEFAULT_HOP_LENGTH = 160
 DEFAULT_N_MELS = 80
-DEFAULT_COMPRESSION_RATE = 4
-DEFAULT_QFORMER_RATE = 2
-DEFAULT_FEAT_STRIDE = 4
-IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
-AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
-DEFAULT_MAX_LENGTH = 16384
+DEFAULT_COMPRESSION_RATE = 4  # For _calculate_embed_length
+DEFAULT_QFORMER_RATE = 2  # For _calculate_embed_length
+DEFAULT_FEAT_STRIDE = 4  # For _calculate_embed_length / 'frames'
+IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"  # Not used in this file directly
+AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"  # Not used in this file directly
+DEFAULT_MAX_LENGTH = 16384  # For tokenizer default
+LOG_MEL_CLIP_EPSILON = 1e-5

 logger = logging.get_logger(__name__)

@@ -32,25 +33,48 @@ logger = logging.get_logger(__name__)
 def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
                           fmax: Optional[float] = None) -> np.ndarray:
     """Create Mel filterbank for audio processing."""
-    fmax = fmax or sampling_rate / 2
+    fmax = fmax or sampling_rate / 2.0

-    def hz_to_mel(f: float) -> float:
+    def hz_to_mel(f: float) -> float:  # User's formula
         return 1127.0 * math.log(1 + f / 700.0)

+    if fmin >= fmax:
+        raise ValueError(f"fmin ({fmin}) must be smaller than fmax ({fmax}).")
+
     mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
-    freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1)
-    bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)
+    freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1)  # Inverse of user's hz_to_mel

-    filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
-    for m in range(1, n_mels + 1):
-        left, center, right = bins[m - 1:m + 2]
-        filterbank[m - 1, left:center] = (np.arange(left, center) - left) / (center - left)
-        filterbank[m - 1, center:right] = (right - np.arange(center, right)) / (right - center)
+    freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
+    bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(
+        int)  # (n_fft+1) or n_fft/2 ? Librosa uses n_fft//2 * hz / sr_nyquist
+    bins = np.clip(bins, 0, n_fft // 2)  # Max index for rfft output is n_fft//2

+    filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
+    for m_idx in range(n_mels):
+        left, center, right = bins[m_idx], bins[m_idx + 1], bins[m_idx + 2]
+
+        if center > left:  # Rising slope
+            filterbank[m_idx, left:center + 1] = (np.arange(left, center + 1) - left) / (center - left)
+        if right > center:  # Falling slope
+            # Need to ensure the peak is 1 if center was part of rising slope
+            # If left==center, this part creates the full triangle (rising is skipped)
+            filterbank[m_idx, center:right + 1] = (right - np.arange(center, right + 1)) / (right - center)
+
+        # Ensure the peak at 'center' is 1.0 if it's a valid point.
+        # This handles cases where left=center or center=right if the slopes don't perfectly set it.
+        if left <= center <= right:
+            if filterbank.shape[1] > center:  # Check bounds for center index
+                if (center > left and filterbank[m_idx, center] < 1.0) or \
+                        (center < right and filterbank[m_idx, center] < 1.0) or \
+                        (left == center and center < right) or \
+                        (right == center and left < center):
+                    filterbank[m_idx, center] = 1.0
     return filterbank


 class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
+    model_input_names = ["audio_values", "audio_attention_mask"]
+
     def __init__(
         self,
         compression_rate: int = DEFAULT_COMPRESSION_RATE,
@@ -58,89 +82,191 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
         feat_stride: int = DEFAULT_FEAT_STRIDE,
         sampling_rate: int = DEFAULT_SAMPLING_RATE,
         n_fft: int = DEFAULT_N_FFT,
-        win_length: int = DEFAULT_WIN_LENGTH,
-        hop_length: int = DEFAULT_HOP_LENGTH,
+        win_length: Optional[int] = None,
+        hop_length: Optional[int] = None,
         n_mels: int = DEFAULT_N_MELS,
+        f_min: float = 0.0,
+        f_max: Optional[float] = None,
+        padding_value: float = 0.0,
         **kwargs
     ):
-        kwargs.pop("feature_size", None)
-        kwargs.pop("sampling_rate", None)
-        kwargs.pop("padding_value", None)
+        _win_length = win_length if win_length is not None else n_fft
+        _hop_length = hop_length if hop_length is not None else _win_length // 4

         super().__init__(
-            feature_size=n_mels,
-            sampling_rate=sampling_rate,
-            padding_value=0.0,
+            feature_size=n_mels,  # This is num_mel_bins
+            sampling_rate=sampling_rate,  # This is the target sampling rate for featurization
+            padding_value=padding_value,
             **kwargs
         )

         self.compression_rate = compression_rate
         self.qformer_rate = qformer_rate
         self.feat_stride = feat_stride
-        self.sampling_rate = sampling_rate
+        # self.sampling_rate is set by super() to the target rate

-        self.window = np.hamming(win_length).astype(np.float32)
-        self.mel_filterbank = create_mel_filterbank(sampling_rate, n_fft, n_mels).T
         self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.win_length = win_length
+        self.win_length = _win_length
+        self.hop_length = _hop_length
+        self.n_mels = n_mels
+        self.f_min = f_min
+        self.f_max = f_max if f_max is not None else self.sampling_rate / 2.0
+
+        if self.win_length > self.n_fft:
+            logger.warning(
+                f"win_length ({self.win_length}) is greater than n_fft ({self.n_fft}). "
+                "Window will be applied, then data zero-padded/truncated to n_fft by np.fft.rfft."
+            )
+        self.window = np.hamming(self.win_length).astype(np.float32)
+        self.mel_filterbank = create_mel_filterbank(
+            self.sampling_rate, self.n_fft, self.n_mels, fmin=self.f_min, fmax=self.f_max
+        ).T

     def __call__(
         self,
-        audios: List[AudioInput],
+        audios: Union[AudioInput, List[AudioInput]],
+        sampling_rate: Optional[int] = None,  # SR of input raw audio arrays
         return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
     ) -> BatchFeature:
-        features, sizes, frames = [], [], []

-        for wav in audios:
-            processed_wav = self._preprocess_audio(wav, 22500)
-            mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav)
-            feature_tensor = torch.tensor(mel_spectrogram, dtype=torch.float32)
-            features.append(feature_tensor)
-            sizes.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
-            frames.append(feature_tensor.shape[0] * self.feat_stride)
+        if not isinstance(audios, list):
+            audios = [audios]
+
+        processed_mels: List[torch.Tensor] = []
+        actual_mel_lengths: List[int] = []
+
+        # These lists are from your original code; their values might be used by Gemma3OmniProcessor later.
+        sizes_for_downstream_calc: List[torch.Tensor] = []
+        frames_scaled_for_downstream_calc: List[int] = []
+
+        for audio_item in audios:
+            current_wav_array: np.ndarray
+            source_sr: int  # Original sampling rate of the current_wav_array
+
+            if isinstance(audio_item, tuple) and len(audio_item) == 2 and isinstance(audio_item[1], int):
+                current_wav_array, source_sr = audio_item
+                current_wav_array = np.asarray(current_wav_array, dtype=np.float32)
+            elif isinstance(audio_item, (np.ndarray, list)):  # Raw waveform as array/list
+                current_wav_array = np.asarray(audio_item, dtype=np.float32)
+                if sampling_rate is None:
+                    raise ValueError(
+                        "sampling_rate argument must be provided to __call__ if 'audios' items "
+                        "are raw numpy arrays or lists (without embedded sampling rate info)."
+                    )
+                source_sr = sampling_rate
+            else:
+                # If you expect to load from paths/bytes, you'd use transformers.audio_utils.load_audio here
+                raise TypeError(
+                    f"Unsupported audio_item type: {type(audio_item)}. Expected np.ndarray, list of floats, "
+                    "or Tuple[np.ndarray, int (sampling_rate)]."
+                )
+
+            logger.debug(
+                f"Gemma3AudioFeatureExtractor: Processing audio item with original shape {current_wav_array.shape}, source_sr {source_sr}")
+
+            # 1. Preprocess: convert to mono, resample to self.sampling_rate, normalize
+            processed_wav_for_mel = self._preprocess_audio(current_wav_array, source_sr)
+
+            # 2. Compute Log-Mel Spectrogram: results in (NumFrames, self.n_mels)
+            mel_spectrogram_np = self._compute_log_mel_spectrogram(processed_wav_for_mel)
+            logger.debug(f"Gemma3AudioFeatureExtractor: Computed mel_spectrogram shape: {mel_spectrogram_np.shape}")
+
+            if not (mel_spectrogram_np.ndim == 2 and mel_spectrogram_np.shape[1] == self.n_mels):
+                # This check is important if _compute_log_mel_spectrogram could return variable shapes
+                logger.error(
+                    f"Mel spectrogram computation resulted in unexpected shape {mel_spectrogram_np.shape}. Expected (NumFrames, {self.n_mels})")
+                # Fallback to a zero-feature tensor of correct feature dimension but zero time, or handle error
+                # This indicates a problem in _compute_log_mel_spectrogram or very unusual input
+                # For now, let it proceed, but this would be an issue.
+                # If num_frames was 0, shape would be (0, n_mels), which is valid.

-        audio_embeds = pad_sequence(features, batch_first=True)
-        size_tensor = torch.stack(sizes)
+            feature_tensor = torch.from_numpy(mel_spectrogram_np)  # Already float32
+            processed_mels.append(feature_tensor)
+            actual_mel_lengths.append(feature_tensor.shape[0])  # Number of time frames

-        attention_mask = None
-        if len(audios) > 1:
-            frame_lengths = torch.tensor(frames)
-            attention_mask = torch.arange(frame_lengths.max()).unsqueeze(0) < frame_lengths.unsqueeze(1)
+            # Original logic for 'sizes' and 'frames' (kept for compatibility with your processor)
+            sizes_for_downstream_calc.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
+            frames_scaled_for_downstream_calc.append(feature_tensor.shape[0] * self.feat_stride)
+
+        # Pad the list of 2D Mel spectrograms to form a 3D batch
+        # Output shape: (Batch, MaxNumFrames, NumMels)
+        audio_values_batched = pad_sequence(processed_mels, batch_first=True, padding_value=self.padding_value)
+
+        # Create attention mask for the padded batch
+        max_t_mel_in_batch = audio_values_batched.shape[1]
+
+        attention_mask_batched = torch.zeros(len(audios), max_t_mel_in_batch, dtype=torch.bool)
+        for i, length in enumerate(actual_mel_lengths):
+            attention_mask_batched[i, :length] = True

         output_data = {
-            "audio_values": audio_embeds,
-            "audio_values_sizes": size_tensor
+            "audio_values": audio_values_batched,  # Expected by model as (B, T, F)
+            "audio_attention_mask": attention_mask_batched  # Mask for "audio_values"
         }
-        if attention_mask is not None:
-            output_data["audio_attention_mask"] = attention_mask

+        if sizes_for_downstream_calc:  # If these are used by the OmniProcessor
+            output_data["audio_values_sizes"] = torch.stack(sizes_for_downstream_calc)
+
+        logger.info(
+            f"Gemma3AudioFeatureExtractor: Final 'audio_values' batch shape: {output_data['audio_values'].shape}")
         return BatchFeature(data=output_data, tensor_type=return_tensors)

     def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
-        wav = torch.as_tensor(wav).float().numpy()
+        if wav.dtype not in [np.float32, np.float64]:
+            if np.issubdtype(wav.dtype, np.integer):
+                max_val = np.iinfo(wav.dtype).max if wav.size > 0 else 1.0
+                wav = wav.astype(np.float32) / max_val
+            else:
+                wav = wav.astype(np.float32)
+        elif wav.dtype == np.float64:
+            wav = wav.astype(np.float32)
+
         if wav.ndim > 1:
             wav = wav.mean(axis=0)
+
         if source_sr != self.sampling_rate:
-            wav = scipy.signal.resample_poly(wav, self.sampling_rate, source_sr)
-        return wav / max(np.abs(wav).max(), 1e-6)
+            # logger.info(f"Resampling audio from {source_sr} Hz to {self.sampling_rate} Hz.")
+            common_divisor = math.gcd(self.sampling_rate, source_sr)
+            up_factor = self.sampling_rate // common_divisor
+            down_factor = source_sr // common_divisor
+            if up_factor != down_factor:
+                wav = scipy.signal.resample_poly(wav, up=up_factor, down=down_factor)
+
+        max_abs_val = np.abs(wav).max()
+        if max_abs_val > 1e-7:  # Avoid division by zero/small numbers for silent/near-silent audio
+            wav = wav / max_abs_val
+        return wav

     def _compute_log_mel_spectrogram(self, wav: np.ndarray) -> np.ndarray:
-        frame_count = 1 + (len(wav) - self.win_length) // self.hop_length
-        strides = wav.strides[0]
-        frames = np.lib.stride_tricks.as_strided(
+        if len(wav) < self.win_length:
+            padding = self.win_length - len(wav)
+            wav = np.pad(wav, (0, padding), mode='constant', constant_values=0.0)
+
+        if len(wav) >= self.win_length:
+            num_frames = 1 + (len(wav) - self.win_length) // self.hop_length
+        else:
+            num_frames = 0  # Should be caught by the padding above, but defensive.
+
+        if num_frames <= 0:
+            # logger.warning(...)
+            return np.zeros((0, self.n_mels), dtype=np.float32)  # Return shape (0, N_Mels)
+
+        frames_view = np.lib.stride_tricks.as_strided(
             wav,
-            shape=(frame_count, self.win_length),
-            strides=(strides * self.hop_length, strides),
+            shape=(num_frames, self.win_length),
+            strides=(wav.strides[0] * self.hop_length, wav.strides[0]),
             writeable=False
-        ).copy()
-        frames *= self.window
+        )
+        frames_data = frames_view.copy()  # Ensure it's a copy before in-place modification
+        frames_data *= self.window

-        spectrum = np.fft.rfft(frames, n=self.n_fft).astype(np.complex64)
+        spectrum = np.fft.rfft(frames_data, n=self.n_fft, axis=-1).astype(np.complex64)
         power = np.abs(spectrum) ** 2
         mel_spectrogram = np.dot(power, self.mel_filterbank)
-        mel_spectrogram = np.clip(mel_spectrogram, 1.0, None)
-        return np.log(mel_spectrogram, dtype=np.float32)
+        mel_spectrogram = np.clip(mel_spectrogram, LOG_MEL_CLIP_EPSILON, None)
+        log_mel_spectrogram = np.log(mel_spectrogram)
+
+        return log_mel_spectrogram.astype(np.float32)

     def _calculate_embed_length(self, frame_count: int) -> int:
         compressed = math.ceil(frame_count / self.compression_rate)
@@ -156,8 +282,9 @@ class Gemma3ImagesKwargs(ImagesKwargs):


 class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
-    images_kwargs: Dict[str, Any]
-    audio_kwargs: Dict[str, Any]
+    images_kwargs: Optional[Dict[str, Any]] = None
+    audio_kwargs: Optional[Dict[str, Any]] = None
+    text_kwargs: Optional[Dict[str, Any]] = None
     _defaults = {
         "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
         "images_kwargs": {},
@@ -168,38 +295,23 @@ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
 class Gemma3OmniProcessor(ProcessorMixin):
     attributes = ["image_processor", "audio_processor", "tokenizer"]
     valid_kwargs = ["chat_template", "image_seq_length"]
+
     image_processor_class = "AutoImageProcessor"
-    audio_processor_class = "AutoFeatureExtractor"
+    audio_processor_class = "AutoFeatureExtractor"  # CRITICAL: Must be string name of your custom class
     tokenizer_class = "AutoTokenizer"

     def __init__(
         self,
-        image_processor,
-        audio_processor,
-        tokenizer,
+        image_processor=None,
+        audio_processor=None,
+        tokenizer=None,
         chat_template=None,
         image_seq_length: int = 256,
         **kwargs
     ):
-        self.image_seq_length = image_seq_length
-        self.image_token_id = tokenizer.image_token_id
-        self.boi_token = tokenizer.boi_token
-        self.image_token = tokenizer.image_token
-        self.audio_token = "<audio_soft_token>"
-        self.expected_audio_token_id = 262143
-        self.full_image_sequence = f"\n\n{tokenizer.boi_token}{''.join([tokenizer.image_token] * image_seq_length)}{tokenizer.eoi_token}\n\n"
-
-        self.compression_rate = 8
-        self.qformer_compression_rate = 1
-        self.feat_stride = 1
-
-        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
-        if self.audio_token_id != self.expected_audio_token_id:
-            logger.warning(
-                f"Assigned ID {self.audio_token_id} for '{self.audio_token}' does not match expected ID {self.expected_audio_token_id}. "
-                "Using assigned ID. Model embedding layer may need resizing."
-            )
-
+        # ProcessorMixin.__init__ handles instantiation of audio_processor, image_processor, tokenizer
+        # if they are None when passed to it, using the *_class attributes defined above.
+        # If actual instances are passed (e.g., from from_pretrained), they will be used.
         super().__init__(
             image_processor=image_processor,
             audio_processor=audio_processor,
@@ -208,136 +320,243 @@ class Gemma3OmniProcessor(ProcessorMixin):
             **kwargs
         )

-    def _merge_kwargs(self, ModelProcessorKwargs, tokenizer_init_kwargs, **kwargs):
-        default_kwargs = {}
-        for modality in ModelProcessorKwargs._defaults:
-            default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()
-
-        # Update defaults with tokenizer init kwargs
-        for modality in default_kwargs:
-            modality_kwargs = default_kwargs[modality]
-            for key in modality_kwargs:
-                if key in tokenizer_init_kwargs:
-                    value = (
-                        getattr(self.tokenizer, key)
-                        if hasattr(self.tokenizer, key)
-                        else tokenizer_init_kwargs[key]
-                    )
-                    modality_kwargs[key] = value
-
-        # Update with user-provided kwargs
-        for modality in default_kwargs:
-            if modality in kwargs:
-                default_kwargs[modality].update(kwargs[modality])
-
-        # Ensure text_kwargs has truncation=False and large max_length
-        default_kwargs["text_kwargs"]["truncation"] = False
-        default_kwargs["text_kwargs"]["max_length"] = default_kwargs["text_kwargs"].get("max_length",
-                                                                                        DEFAULT_MAX_LENGTH)
-
-        return default_kwargs
-
-    def _compute_audio_embed_size(self, audio_frames: int) -> int:
-        result = math.ceil(audio_frames / self.compression_rate)
-        return math.ceil(result / self.qformer_compression_rate)
+        # These attributes depend on self.tokenizer being properly initialized by super()
+        self.image_seq_length = image_seq_length
+        if self.tokenizer is not None:
+            # Use getattr for robustness, providing defaults if attributes are missing
+            self.image_token_id = getattr(self.tokenizer, "image_token_id",
+                                          self.tokenizer.unk_token_id if hasattr(self.tokenizer,
+                                                                                 "unk_token_id") else None)
+            self.boi_token = getattr(self.tokenizer, "boi_token", "<image>")  # More common default
+            self.image_token = getattr(self.tokenizer, "image_token", "<image>")
+            self.eoi_token = getattr(self.tokenizer, "eoi_token", "")  # Default to empty if not present
+
+            # User's original attributes for audio tokens
+            self.audio_token_str_from_user_code = "<audio_soft_token>"
+            # self.expected_audio_token_id = 262143  # User's reference, keep commented for minimal change
+
+            self.audio_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_token_str_from_user_code)
+            if hasattr(self.tokenizer, "unk_token_id") and self.audio_token_id == self.tokenizer.unk_token_id:
+                logger.warning(
+                    f"The audio token string '{self.audio_token_str_from_user_code}' maps to the UNK token. "
+                    "Please ensure it is added to the tokenizer's vocabulary as a special token."
+                )
+            self.full_image_sequence = f"\n\n{self.boi_token}{''.join([self.image_token] * image_seq_length)}{self.eoi_token}\n\n"
+        else:
+            # This state (tokenizer is None after super init) should ideally not occur if from_pretrained works.
+            logger.error(
+                "Gemma3OmniProcessor initialized, but self.tokenizer is None. Token-dependent attributes will use placeholders or defaults.")
+            self.image_token_id = None
+            self.boi_token = "<image>"
+            self.image_token = "<image>"
+            self.eoi_token = ""
+            self.audio_token_str_from_user_code = "<audio_soft_token>"
+            self.audio_token_id = -1  # Placeholder
+            self.full_image_sequence = ""
+
+        # These are parameters for this processor's logic for number of audio tokens in prompt
+        self.prompt_audio_compression_rate = kwargs.pop("audio_prompt_compression_rate", 8)
+        self.prompt_audio_qformer_rate = kwargs.pop("audio_prompt_qformer_rate", 1)
+        self.prompt_audio_feat_stride = kwargs.pop("audio_prompt_feat_stride", 1)
+        self.audio_placeholder_token = kwargs.pop("audio_placeholder_token", "<|audio_placeholder|>")
+
+    def _merge_kwargs(self, KwargsClassWithDefaults, tokenizer_init_kwargs, **kwargs_from_call):
+        # This method merges default kwargs, tokenizer init kwargs, and call-specific kwargs
+        final_kwargs = {}
+        _defaults = getattr(KwargsClassWithDefaults, "_defaults", {})
+        if not isinstance(_defaults, dict): _defaults = {}
+
+        for modality_key, default_modality_kwargs in _defaults.items():
+            final_kwargs[modality_key] = default_modality_kwargs.copy()
+
+        for modality_key_in_call, modality_kwargs_in_call in kwargs_from_call.items():
+            if modality_key_in_call in final_kwargs:  # e.g. "text_kwargs"
+                if isinstance(modality_kwargs_in_call, dict):
+                    final_kwargs[modality_key_in_call].update(modality_kwargs_in_call)
+            elif isinstance(modality_kwargs_in_call, dict):  # New modality not in _defaults (e.g. "video_kwargs")
+                final_kwargs[modality_key_in_call] = modality_kwargs_in_call.copy()
+
+        if self.tokenizer:  # Ensure tokenizer is available for its init_kwargs
+            for modality_key in final_kwargs:
+                modality_dict = final_kwargs[modality_key]
+                if isinstance(modality_dict, dict):
+                    for key_in_mod_dict in list(modality_dict.keys()):
+                        if key_in_mod_dict in tokenizer_init_kwargs:  # tokenizer_init_kwargs from self.tokenizer.init_kwargs
+                            value = (
+                                getattr(self.tokenizer, key_in_mod_dict)  # Check actual tokenizer attribute first
+                                if hasattr(self.tokenizer, key_in_mod_dict)
+                                else tokenizer_init_kwargs[key_in_mod_dict]
+                            )
+                            modality_dict[key_in_mod_dict] = value
+
+        if "text_kwargs" not in final_kwargs:
+            final_kwargs["text_kwargs"] = {}
+        # Ensure these text_kwargs have defaults if not set otherwise
+        final_kwargs["text_kwargs"]["truncation"] = final_kwargs["text_kwargs"].get("truncation", False)
+        final_kwargs["text_kwargs"]["max_length"] = final_kwargs["text_kwargs"].get("max_length", DEFAULT_MAX_LENGTH)
+
+        return final_kwargs
+
+    def _compute_audio_embed_size(self, audio_mel_frames: int) -> int:
+        scaled_frames = audio_mel_frames * self.prompt_audio_feat_stride
+        result = math.ceil(scaled_frames / self.prompt_audio_compression_rate)
+        return math.ceil(result / self.prompt_audio_qformer_rate)

     def __call__(
         self,
-        images=None,
-        text=None,
-        videos=None,
-        audio=None,
-        **kwargs: Unpack[Gemma3ProcessorKwargs]
+        text: Union[str, List[str]] = None,
+        images: Optional[Any] = None,
+        audios: Optional[Union[AudioInput, List[AudioInput]]] = None,
+        sampling_rate: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs: Any
     ) -> BatchFeature:
-        if text is None and images is None:
-            raise ValueError("Provide at least one of `text` or `images`.")
-
-        output_kwargs = self._merge_kwargs(
-            Gemma3ProcessorKwargs,
-            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+        if text is None and images is None and audios is None:
+            raise ValueError("Provide at least one of `text`, `images`, or `audios`.")
+
+        # Determine final return_tensors strategy (explicit __call__ arg > from text_kwargs > default)
+        final_rt = return_tensors
+        # _merge_kwargs uses Gemma3ProcessorKwargs to structure the **kwargs from __call__
+        merged_call_kwargs = self._merge_kwargs(
+            Gemma3ProcessorKwargs,  # Class defining _defaults structure
+            self.tokenizer.init_kwargs if hasattr(self.tokenizer, 'init_kwargs') else {},
            **kwargs
        )

-        if isinstance(text, str):
-            text = [text]
-        elif not isinstance(text, list) or not all(isinstance(t, str) for t in text):
-            raise ValueError("Input text must be a string or list of strings")
-
-        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", "pt")
-        image_inputs = {}
+        if final_rt is None:  # If not passed directly to __call__
+            # Get from merged_call_kwargs (which would have picked it up from kwargs['text_kwargs'])
+            # and remove it to prevent passing twice to tokenizer
+            final_rt = merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", TensorType.PYTORCH)
+        else:  # If passed directly, ensure it's removed from text_kwargs to avoid conflict
+            merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", None)
+
+        if text is None:  # Default text if only other modalities are provided
+            num_samples = 0
+            if images is not None:
+                _images_list = images if isinstance(images, list) and (
+                    not images or not isinstance(images[0], (int, float))) else [images]
+                num_samples = len(_images_list)
+            elif audios is not None:
+                _audios_list = audios if isinstance(audios, list) else [audios]
+                num_samples = len(_audios_list)
+            text = [""] * num_samples if num_samples > 0 else [""]  # Create empty strings or one if no samples
+
+        if isinstance(text, str): text = [text]
+        if not (isinstance(text, list) and all(isinstance(t, str) for t in text)):
+            raise ValueError("Input `text` must be a string or a list of strings.")
+
+        # --- Image Processing (User's structure) ---
+        image_features_dict = {}
         if images is not None:
+            if self.image_processor is None: raise ValueError("Images provided but self.image_processor is None.")
             batched_images = make_nested_list_of_images(images)
-            image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])
-
-            if not text:
-                text = [" ".join([self.boi_token] * len(images)) for images in batched_images]
-
-            if len(batched_images) != len(text):
-                raise ValueError(
-                    f"Inconsistent batch sizes: {len(batched_images)} images, {len(text)} texts"
-                )
-
-            num_crops = to_py_obj(image_inputs.pop("num_crops"))
-            batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
-
-            for batch_idx, (prompt, images, crops) in enumerate(zip(text, batched_images, batch_num_crops)):
-                image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
-                if len(images) != len(image_indexes):
-                    raise ValueError(
-                        f"Prompt has {len(image_indexes)} image tokens but received {len(images)} images"
-                    )
-
-                for num, idx in reversed(list(zip(crops, image_indexes))):
-                    if num:
-                        formatted_image_text = (
-                            f"Here is the original image {self.boi_token} and here are some crops to help you see better "
-                            + " ".join([self.boi_token] * num)
-                        )
-                        prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token):]
-                        text[batch_idx] = prompt
-
-            text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
-
-        audio_inputs = {}
-        if audio is not None:
-            audio_inputs = self.audio_processor(audio, return_tensors)
-            audio_embeds = audio_inputs['audio_values']
-            audio_frames = audio_embeds.shape[1] * self.feat_stride
-            audio_seq_length = self._compute_audio_embed_size(audio_frames)
-
-            audio_tokens = {
-                "boa_token": "<start_of_audio>",
-                "eoa_token": "<end_of_audio>",
-                "audio_token": "<audio_soft_token>",
-                "boa_token_id": 256001,
-                "eoa_token_id": 256002,
-                "audio_token_id": self.audio_token_id  # Use dynamic ID
-            }
-
-            audio_sequence = f"\n\n{audio_tokens['boa_token']}{''.join([audio_tokens['audio_token']] * audio_seq_length)}{audio_tokens['eoa_token']}\n\n"
-            text = [prompt.replace(audio_tokens['boa_token'], audio_sequence) for prompt in text]
-
-        text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors=return_tensors)
-
-        # Debug: Log text and token counts before validation
-        for i, (txt, ids) in enumerate(zip(text, text_inputs["input_ids"])):
-            audio_text_count = txt.count(self.audio_token)
-            audio_ids_count = list(ids).count(self.audio_token_id)
-            logger.debug(
-                f"Sample {i}: Audio tokens in text={audio_text_count}, in input_ids={audio_ids_count}, "
-                f"Text length={len(txt)}, Input IDs length={len(ids)}"
-            )
-
-        array_ids = text_inputs["input_ids"]
-        if return_tensors == "pt":
-            mm_token_type_ids = torch.zeros_like(array_ids)
-        else:
-            mm_token_type_ids = np.zeros_like(array_ids)
-        mm_token_type_ids[array_ids == self.image_token_id] = 1  # Image token type
-        mm_token_type_ids[array_ids == self.audio_token_id] = 2  # Audio token type
-        text_inputs["token_type_ids"] = mm_token_type_ids
-
-        return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)
+            _img_proc_output = self.image_processor(batched_images, return_tensors=None,
+                                                    **merged_call_kwargs.get("images_kwargs", {}))
+            image_features_dict = _img_proc_output.data if isinstance(_img_proc_output,
+                                                                      BatchFeature) else _img_proc_output
+
+            # Adjust text based on images (user's original logic)
+            if len(text) == 0 and len(batched_images) > 0: text = [" ".join([self.boi_token] * len(img_batch)) for
+                                                                   img_batch in batched_images]
+            if len(batched_images) != len(text): raise ValueError(
+                f"Inconsistent batch: {len(batched_images)} images, {len(text)} texts")
+
+            num_crops_popped = image_features_dict.pop("num_crops", None)
+            if num_crops_popped is not None:
+                num_crops_all = to_py_obj(num_crops_popped)
+                temp_text_img, current_crop_idx_offset = [], 0
+                for batch_idx, (prompt, current_imgs_in_batch) in enumerate(zip(text, batched_images)):
+                    crops_for_this_batch_sample = []
+                    if num_crops_all:
+                        for _ in current_imgs_in_batch:
+                            if current_crop_idx_offset < len(num_crops_all):
+                                crops_for_this_batch_sample.append(
+                                    num_crops_all[current_crop_idx_offset]); current_crop_idx_offset += 1
+                            else:
+                                crops_for_this_batch_sample.append(0)
+                    image_indexes = [m.start() for m in re.finditer(re.escape(self.boi_token), prompt)]
+                    processed_prompt = prompt
+                    iter_count = min(len(crops_for_this_batch_sample), len(image_indexes))
+                    for i_crop_idx in range(iter_count - 1, -1, -1):
+                        num_additional_crops = crops_for_this_batch_sample[i_crop_idx]
+                        original_token_idx = image_indexes[i_crop_idx]
+                        if num_additional_crops > 0:
+                            replacement_text = (
+                                f"Here is the original image {self.boi_token} and here are some crops to help you see better " + " ".join(
+                                    [self.boi_token] * num_additional_crops))
+                            processed_prompt = processed_prompt[
+                                               :original_token_idx] + replacement_text + processed_prompt[
+                                                                                         original_token_idx + len(
+                                                                                             self.boi_token):]
+                    temp_text_img.append(processed_prompt)
+                text = temp_text_img
+            text = [p.replace(self.boi_token, self.full_image_sequence) for p in text]
+
+        # --- Audio Processing ---
+        audio_features_dict = {}
+        if audios is not None:
+            if self.audio_processor is None: raise ValueError("Audios provided but self.audio_processor is None.")
+            audio_call_kwargs = merged_call_kwargs.get("audio_kwargs", {})
+            if sampling_rate is not None: audio_call_kwargs[
+                "sampling_rate"] = sampling_rate  # Pass SR to feature extractor
+
+            _audio_proc_output = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
+            audio_features_dict = _audio_proc_output.data
+            logger.info(
+                f"Gemma3OmniProcessor: 'audio_values' shape from Feature Extractor: {audio_features_dict['audio_values'].shape}")
+
+            new_text_with_audio, actual_mel_frames_per_sample = [], to_py_obj(
+                audio_features_dict["audio_attention_mask"].sum(axis=-1))
+            if len(actual_mel_frames_per_sample) != len(text): raise ValueError(
+                f"Inconsistent batch for audio/text: {len(actual_mel_frames_per_sample)} audio, {len(text)} text.")
+
+            for i, prompt in enumerate(text):
+                num_soft_tokens = self._compute_audio_embed_size(actual_mel_frames_per_sample[i])
+                audio_token_sequence_str = self.audio_token_str_from_user_code * num_soft_tokens  # e.g. "<audio_soft_token>" * N
+
+                # User's original boa_token for replacement was " ", which is risky. Using defined placeholder.
+                if self.audio_placeholder_token in prompt:
+                    prompt = prompt.replace(self.audio_placeholder_token, audio_token_sequence_str, 1)
+                else:
+                    prompt += audio_token_sequence_str
+                new_text_with_audio.append(prompt)
+            text = new_text_with_audio
+
+        # --- Text Tokenization ---
+        text_tokenizer_kwargs = merged_call_kwargs.get("text_kwargs", {})
+        text_features_dict = self.tokenizer(text=text, return_tensors=None,
+                                            **text_tokenizer_kwargs)  # Get lists/np.arrays
+
+        input_ids_list_of_lists = text_features_dict["input_ids"]
+        if not isinstance(input_ids_list_of_lists, list) or not (
+                input_ids_list_of_lists and isinstance(input_ids_list_of_lists[0], list)):
+            if isinstance(input_ids_list_of_lists, (torch.Tensor, np.ndarray)):
+                input_ids_list_of_lists = to_py_obj(input_ids_list_of_lists)
+            elif isinstance(input_ids_list_of_lists, list) and (
+                    not input_ids_list_of_lists or isinstance(input_ids_list_of_lists[0], int)):
+                input_ids_list_of_lists = [input_ids_list_of_lists]

+        token_type_ids_list = []
+        for ids_sample in input_ids_list_of_lists:
+            types = [0] * len(ids_sample)
+            for j, token_id_val in enumerate(ids_sample):
+                if self.image_token_id is not None and token_id_val == self.image_token_id:
+                    types[j] = 1
+                elif self.audio_token_id != -1 and token_id_val == self.audio_token_id:
+                    types[j] = 2
+            token_type_ids_list.append(types)
+        text_features_dict["token_type_ids"] = token_type_ids_list
+
+        # Ensure text_features_dict also has 'attention_mask' if tokenizer applied padding
+        # If tokenizer was called with padding=True/strategy, it would add 'attention_mask'
+        # If called with padding=False (default), 'attention_mask' might be missing or all 1s.
+        # BatchFeature will handle final tensor conversion and padding based on final_rt.
+
+        final_batch_data = {**text_features_dict}
+        if image_features_dict: final_batch_data.update(image_features_dict)
+        if audio_features_dict: final_batch_data.update(audio_features_dict)
+
+        return BatchFeature(data=final_batch_data, tensor_type=final_rt)

     def batch_decode(self, *args, **kwargs):
         return self.tokenizer.batch_decode(*args, **kwargs)
@@ -346,8 +565,19 @@ class Gemma3OmniProcessor(ProcessorMixin):
         return self.tokenizer.decode(*args, **kwargs)

     @property
-    def model_input_names(self):
-        tokenizer_inputs = self.tokenizer.model_input_names + ["token_type_ids"]
-        image_processor_inputs = self.image_processor.model_input_names
-        audio_processor_inputs = self.audio_processor.model_input_names
-        return list(dict.fromkeys(tokenizer_inputs + image_processor_inputs + audio_processor_inputs))
+    def model_input_names(self) -> List[str]:
+        input_names = set()
+        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
+            input_names.update(self.tokenizer.model_input_names + ["token_type_ids"])
+
+        if hasattr(self, 'image_processor') and self.image_processor is not None:
+            input_names.update(self.image_processor.model_input_names)
+
+        if hasattr(self, 'audio_processor') and self.audio_processor is not None and \
+                hasattr(self.audio_processor, 'model_input_names'):
+            input_names.update(self.audio_processor.model_input_names)
+        elif hasattr(self,
+                     'audio_processor') and self.audio_processor is not None:  # Fallback if model_input_names not on custom audio_processor
+            input_names.update(["audio_values", "audio_attention_mask"])
+
+        return list(input_names)
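Below is a minimal usage sketch for the new audio path, not part of the commit. It assumes the updated file is importable as processing_gemma3_omni and that the transformers imports at the top of the file resolve in the installed version; the 440 Hz tone and the 16 kHz rate are illustrative only.

# Hypothetical smoke test for the new Gemma3AudioFeatureExtractor (not part of the commit).
import numpy as np
from processing_gemma3_omni import Gemma3AudioFeatureExtractor  # assumed module path

# One second of a 440 Hz tone at 16 kHz; any mono float32 waveform works.
sr = 16000
t = np.linspace(0.0, 1.0, sr, endpoint=False)
wav = 0.5 * np.sin(2 * np.pi * 440.0 * t).astype(np.float32)

extractor = Gemma3AudioFeatureExtractor(sampling_rate=16000)  # n_fft=512, n_mels=80 defaults
batch = extractor(audios=[(wav, sr)], return_tensors="pt")

# (batch, frames, n_mels) log-mel features plus a boolean frame mask and per-sample embed sizes
print(batch["audio_values"].shape)
print(batch["audio_attention_mask"].shape)
print(batch["audio_values_sizes"])

The full Gemma3OmniProcessor path additionally needs an image processor and a tokenizer that defines the <audio_soft_token> special token, so it is not exercised in this sketch.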