Update processing_gemma3_omni.py

processing_gemma3_omni.py (CHANGED, +347 −250)
Hunks: @@ -1,253 +1,263 @@, @@ -280,7 +290,7 @@, @@ -303,7 +313,8 @@, @@ -319,12 +330,14 @@, @@ -339,14 +352,14 @@, @@ -355,148 +368,218 @@, @@ -504,6 +587,7 @@, @@ -516,16 +600,29 @@ (all within class Gemma3OmniProcessor(ProcessorMixin) after the first hunk).

Removed in this change (previous implementation, abridged; many deleted lines are truncated in the diff view, and unchanged context lines appear once in the updated version below):

- Several lines were only re-added with inline comments (the typing import, the AudioInput and make_nested_list_of_images imports, the DEFAULT_* constants, and parts of create_mel_filterbank); LOG_MEL_CLIP_EPSILON = 1e-5 is now commented out because the new code clips at 1.0 instead.
- The old Gemma3AudioFeatureExtractor.__init__(feat_stride=DEFAULT_FEAT_STRIDE, sampling_rate=DEFAULT_SAMPLING_RATE, n_fft=DEFAULT_N_FFT, win_length=None, hop_length=None, n_mels=DEFAULT_N_MELS, f_min=0.0, f_max=None, padding_value=0.0, **kwargs), which called super().__init__(feature_size=n_mels, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs), stored n_fft/win_length/hop_length/n_mels/f_min/f_max, warned when win_length exceeded n_fft, and built a single np.hamming(win_length) window plus create_mel_filterbank(...).T.
- The old __call__(audios, sampling_rate=None, return_tensors=TensorType.PYTORCH), which accepted raw arrays, lists of floats, or (array, sampling_rate) tuples, collected processed_mels / actual_mel_lengths / sizes_for_downstream_calc / frames_scaled_for_downstream_calc, and returned output_data with "audio_values_sizes" = torch.stack(sizes_for_downstream_calc).
- The old waveform normalization: integer PCM scaled by the dtype's maximum value, float64 cast to float32, resampling with scipy.signal.resample_poly using math.gcd-derived up/down factors, and peak normalization whenever max(|wav|) > 1e-7.
- The old _compute_log_mel_spectrogram, which padded short input to win_length, framed with num_frames = 1 + (len(wav) - win_length) // hop_length, multiplied the frames by the window, and returned np.log(mel_spectrogram) as float32.
- The old embed-size computation: math.ceil(frame_count / self.compression_rate) followed by math.ceil(compressed / self.qformer_rate).
- In Gemma3OmniProcessor: the bare self.audio_token_id = -1 fallback plus two truncated self.* assignments (apparently replaced by the prompt_audio_compression_rate / prompt_audio_qformer_rate attributes below), and a number of lines in _merge_kwargs, __call__, and model_input_names that are re-added with inline comments or reworked as shown in the updated version.
Updated file — content of the changed hunks (regions between hunks are unchanged and not shown in the diff):

import re
from typing import List, Optional, Union, Dict, Any, Tuple  # Added Tuple

import math
import numpy as np
import scipy.signal
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers.audio_utils import AudioInput  # type: ignore
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import make_nested_list_of_images  # Used when image inputs are processed
from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs
from transformers.utils import TensorType, to_py_obj, logging

# Constants
DEFAULT_SAMPLING_RATE = 16000
DEFAULT_N_FFT = 512
DEFAULT_WIN_LENGTH = 400  # Matches Phi4M's 16 kHz win_length, kept for reference
DEFAULT_HOP_LENGTH = 160  # Matches Phi4M's 16 kHz hop_length, kept for reference
DEFAULT_N_MELS = 80
DEFAULT_COMPRESSION_RATE = 4  # Generic default
DEFAULT_QFORMER_RATE = 2  # Generic default
DEFAULT_FEAT_STRIDE = 4  # Generic default
IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
DEFAULT_MAX_LENGTH = 16384
# LOG_MEL_CLIP_EPSILON = 1e-5  # Constant from the previous implementation; the new code clips at 1.0 instead

logger = logging.get_logger(__name__)

# This create_mel_filterbank helper is carried over from the previous implementation
# and is used by Gemma3AudioFeatureExtractor below.
def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
                          fmax: Optional[float] = None) -> np.ndarray:
    """Create a Mel filterbank for audio processing."""
    fmax = fmax or sampling_rate / 2.0

    def hz_to_mel(f: float) -> float:  # HTK-style mel scale: 1127 * ln(1 + f / 700)
        return 1127.0 * math.log(1 + f / 700.0)

    if fmin >= fmax:
        raise ValueError(f"fmin ({fmin}) must be smaller than fmax ({fmax}).")

    mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
    freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1)  # Inverse of hz_to_mel

    freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
    bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)
    bins = np.clip(bins, 0, n_fft // 2)  # Max index for the rfft output is n_fft // 2

    filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
    for m_idx in range(n_mels):
        left, center, right = bins[m_idx], bins[m_idx + 1], bins[m_idx + 2]

        if center > left:  # Rising slope
            filterbank[m_idx, left:center + 1] = (np.arange(left, center + 1) - left) / (center - left)
        if right > center:  # Falling slope
            filterbank[m_idx, center:right + 1] = (right - np.arange(center, right + 1)) / (right - center)

        # Ensure the peak at 'center' is 1.0 if it is a valid point. This logic is kept from the
        # previous implementation; Phi4M's speechlib_mel may normalize differently.
        if left <= center <= right:  # center lies within the filter bounds
            if filterbank.shape[1] > center:  # center index is within the filterbank columns
                if (center > left and filterbank[m_idx, center] < 1.0 and center < right) or \
                        (left == center and center < right) or \
                        (right == center and left < center):  # Triangular filter with at least one slope
                    filterbank[m_idx, center] = 1.0
                elif left == center and right == center:  # Degenerate (zero-width) filter when bins coincide
                    filterbank[m_idx, center] = 1.0

    return filterbank

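Aside — an illustrative sketch, not part of processing_gemma3_omni.py: the filterbank returned for the parameters used by the extractor below has shape (n_mels, n_fft // 2 + 1), and the extractor stores its transpose so it can be right-multiplied onto power spectra.

# Illustrative only: check the filterbank shape for the 16 kHz parameters used below.
fb = create_mel_filterbank(sampling_rate=16000, n_fft=512, n_mels=80, fmin=0.0, fmax=7690.0)
print(fb.shape)    # (80, 257) -> n_mels x (n_fft // 2 + 1)
print(fb.T.shape)  # (257, 80), the orientation stored as self._mel in the extractor
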
# --- Start of the refactored audio feature extractor (aligned with the Phi4M reference) ---
class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
    model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]

    def __init__(self, audio_compression_rate, audio_downsample_rate, audio_feat_stride, **kwargs):
        feature_size = 80  # From Phi4M
        sampling_rate = 16000  # From Phi4M (target sampling rate)
        padding_value = 0.0  # From Phi4M
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)

        self.compression_rate = audio_compression_rate
        self.qformer_compression_rate = audio_downsample_rate  # In Phi4M, audio_downsample_rate is the Q-Former compression rate
        self.feat_stride = audio_feat_stride

        self._eightk_method = kwargs.get("eightk_method", "fillzero")  # 'fillzero' or 'resample'

        # Uses create_mel_filterbank above; the parameters match Phi4M's speechlib_mel call.
        self._mel = create_mel_filterbank(
            sampling_rate=16000,  # Target sampling rate
            n_fft=512,  # n_fft for 16 kHz audio in Phi4M
            n_mels=80,  # feature_size
            fmin=0.0,  # Phi4M passes fmin=None, which effectively defaults to 0
            fmax=7690.0  # Specific fmax from Phi4M
        ).T
        self._hamming400 = np.hamming(400)  # For 16 kHz audio, from Phi4M
        self._hamming200 = np.hamming(200)  # For 8 kHz audio, from Phi4M

    def __call__(
            self,
            audios: List[Union[AudioInput, Tuple[np.ndarray, int]]],  # More specific type hint
            return_tensors: Optional[Union[str, TensorType]] = None,
    ):
        returned_input_audio_embeds = []
        returned_audio_embed_sizes = []
        audio_frames_list = []  # Stores num_mel_frames * feat_stride for each audio item

        for audio_input_item in audios:
            if not isinstance(audio_input_item, tuple) or len(audio_input_item) != 2:
                raise ValueError(
                    "Each item in 'audios' must be a tuple (waveform: np.ndarray, sample_rate: int)."
                )
            audio_data, sample_rate = audio_input_item

            if isinstance(audio_data, list):  # Convert list to ndarray
                audio_data = np.array(audio_data, dtype=np.float32)
            if not isinstance(audio_data, np.ndarray):
                raise TypeError(f"Waveform data must be a numpy array, got {type(audio_data)}")

            audio_embeds_np = self._extract_features(audio_data, sample_rate)  # log-Mel filterbank features

            num_mel_frames = audio_embeds_np.shape[0]
            current_audio_frames = num_mel_frames * self.feat_stride  # Phi4M logic

            audio_embed_size = self._compute_audio_embed_size(current_audio_frames)

            returned_input_audio_embeds.append(torch.from_numpy(audio_embeds_np))
            returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long())
            audio_frames_list.append(current_audio_frames)

        padded_input_audio_embeds = pad_sequence(
            returned_input_audio_embeds, batch_first=True, padding_value=self.padding_value
        )
        stacked_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)

        tensor_audio_frames_list = torch.tensor(audio_frames_list, dtype=torch.long)

        max_audio_frames = 0
        if len(audios) > 0 and tensor_audio_frames_list.numel() > 0:
            max_audio_frames = tensor_audio_frames_list.max().item()

        returned_audio_attention_mask = None
        if max_audio_frames > 0:  # Create a mask only if there are frames
            if len(audios) > 1:
                returned_audio_attention_mask = torch.arange(0, max_audio_frames,
                                                             device=tensor_audio_frames_list.device).unsqueeze(
                    0) < tensor_audio_frames_list.unsqueeze(1)
            elif len(audios) == 1:  # For batch size 1
                returned_audio_attention_mask = torch.ones(1, max_audio_frames, dtype=torch.bool,
                                                           device=tensor_audio_frames_list.device)

        data = {
            "input_audio_embeds": padded_input_audio_embeds,
            "audio_embed_sizes": stacked_audio_embed_sizes,
        }
        if returned_audio_attention_mask is not None:
            data["audio_attention_mask"] = returned_audio_attention_mask

        return BatchFeature(data=data, tensor_type=return_tensors)

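Aside — an illustrative sketch, not part of processing_gemma3_omni.py: the attention mask above comes from broadcasting a row of frame indices against the per-item frame counts, so each row is True for that item's valid frames and False for padding.

import torch

# Illustrative only: two audio items with 5 and 3 valid frames after feat_stride scaling.
frame_counts = torch.tensor([5, 3], dtype=torch.long)
max_frames = int(frame_counts.max().item())
mask = torch.arange(0, max_frames).unsqueeze(0) < frame_counts.unsqueeze(1)
# mask:
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False]])
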
    def _extract_spectrogram(self, wav: np.ndarray, fs: int) -> np.ndarray:
        if wav.ndim > 1:
            wav = np.squeeze(wav)
        if len(wav.shape) == 2:  # Stereo to mono
            wav = wav.mean(axis=1).astype(np.float32)  # Ensure float32 after the mean

        wav = wav.astype(np.float32)  # Ensure wav is float32

        # Phi4M resampling logic
        if fs > self.sampling_rate:  # self.sampling_rate is 16000
            wav = scipy.signal.resample_poly(wav, self.sampling_rate, fs)
            fs = self.sampling_rate
        elif 8000 < fs < self.sampling_rate:
            wav = scipy.signal.resample_poly(wav, 8000, fs)  # Resample to 8000 first
            fs = 8000
        elif fs < 8000 and fs > 0:
            logger.warning(f"Sample rate {fs} is less than 8000Hz. Resampling to 8000Hz.")
            wav = scipy.signal.resample_poly(wav, 8000, fs)
            fs = 8000
        elif fs <= 0:
            raise RuntimeError(f"Unsupported sample rate {fs}")

        if fs == 8000:
            if self._eightk_method == "resample":
                wav = scipy.signal.resample_poly(wav, self.sampling_rate, 8000)  # Resample 8 kHz to 16 kHz
                fs = self.sampling_rate
            # If "fillzero", the 8 kHz parameters are used and the spectrum is padded later.
        elif fs != self.sampling_rate:  # Should be 16000 if not 8000 and _eightk_method != "resample"
            raise RuntimeError(
                f"Audio sample rate {fs} not supported after initial processing. Expected {self.sampling_rate} or 8000.")

        preemphasis_coeff = 0.97

        if fs == 8000:
            n_fft, win_length, hop_length, fft_window = 256, 200, 80, self._hamming200
        elif fs == 16000:
            n_fft, win_length, hop_length, fft_window = 512, 400, 160, self._hamming400
        else:
            raise RuntimeError(f"Inconsistent fs {fs} for parameter selection.")

        if len(wav) < win_length:
            wav = np.pad(wav, (0, win_length - len(wav)), 'constant', constant_values=(0.0,))

        num_frames = (wav.shape[0] - win_length) // hop_length + 1
        if num_frames <= 0:
            return np.zeros((0, n_fft // 2 + 1), dtype=np.float32)

        y_frames = np.array(
            [wav[i * hop_length: i * hop_length + win_length] for i in range(num_frames)],
            dtype=np.float32,
        )

        _y_frames_rolled = np.roll(y_frames, 1, axis=1)
        _y_frames_rolled[:, 0] = _y_frames_rolled[:, 1]  # Phi4M-specific handling of the first sample
        y_frames_preemphasized = (y_frames - preemphasis_coeff * _y_frames_rolled) * 32768.0

        S = np.fft.rfft(fft_window * y_frames_preemphasized, n=n_fft, axis=1).astype(np.complex64)

        if fs == 8000 and self._eightk_method == "fillzero":
            # Pad the spectrum to match the 16 kHz feature dimension (n_fft=512 -> 257 bins).
            # At this point S has (256 // 2) + 1 = 129 bins.
            target_bins = (512 // 2) + 1
            pad_width = target_bins - S.shape[1]
            # Phi4M: S = np.concatenate((S[:, 0:-1], padarray), axis=1), i.e. drop the Nyquist bin
            # of the 8 kHz spectrum and zero-pad up to the 16 kHz bin count.
            S_core = S[:, :-1]
            padarray = np.zeros((S_core.shape[0], target_bins - S_core.shape[1]), dtype=S.dtype)
            S = np.concatenate((S_core, padarray), axis=1)

        spec = np.abs(S).astype(np.float32)
        return spec

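Aside — a worked example, not part of processing_gemma3_omni.py: with the 16 kHz parameters above, one second of audio yields 98 frames of 257 spectral bins each, and the 8 kHz 'fillzero' path is padded to the same 257 bins.

# Illustrative only: frame and bin counts for 1 s of 16 kHz audio under the parameters above.
num_frames = (16000 - 400) // 160 + 1  # = 98 spectrogram (and Mel) frames
num_bins = (512 // 2) + 1              # = 257 rfft bins per frame
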
    def _extract_features(self, wav: np.ndarray, fs: int) -> np.ndarray:
        spec = self._extract_spectrogram(wav, fs)
        if spec.shape[0] == 0:
            return np.zeros((0, self.feature_size), dtype=np.float32)

        spec_power = spec ** 2
        fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)  # Clip at 1.0 before the log (Phi4M)
        log_fbank = np.log(fbank_power).astype(np.float32)
        return log_fbank

    def _compute_audio_embed_size(self, audio_frames: int) -> int:
        # Phi4M's logic for the compressed size: two successive ceiling divisions.
        integer = audio_frames // self.compression_rate
        remainder = audio_frames % self.compression_rate
        result = integer if remainder == 0 else integer + 1

        integer = result // self.qformer_compression_rate
        remainder = result % self.qformer_compression_rate
        result = integer if remainder == 0 else integer + 1
        return result


# --- End of the refactored audio feature extractor ---

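Aside — a minimal usage sketch, not part of processing_gemma3_omni.py: it exercises the extractor above with one second of silence; the constructor values are this module's generic DEFAULT_* constants, not confirmed Phi4M settings.

# Illustrative only: run the extractor on 1 s of silence at 16 kHz.
feature_extractor = Gemma3AudioFeatureExtractor(
    audio_compression_rate=DEFAULT_COMPRESSION_RATE,  # 4 (generic default in this module)
    audio_downsample_rate=DEFAULT_QFORMER_RATE,       # 2 (generic default in this module)
    audio_feat_stride=DEFAULT_FEAT_STRIDE,            # 4 (generic default in this module)
)
waveform = np.zeros(16000, dtype=np.float32)  # 1 s at 16 kHz
features = feature_extractor(audios=[(waveform, 16000)], return_tensors="pt")
# input_audio_embeds: (1, 98, 80) log-Mel frames; audio_embed_sizes follows the two-stage
# ceiling division: 98 frames * feat_stride 4 = 392 -> ceil(392 / 4) = 98 -> ceil(98 / 2) = 49.
print(features["input_audio_embeds"].shape, features["audio_embed_sizes"])
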
class Gemma3ImagesKwargs(ImagesKwargs):

# ... (lines 264-289, including the Gemma3OmniProcessor(ProcessorMixin) class header, are unchanged and not shown in the diff) ...

    def __init__(
            self,
            image_processor=None,
            audio_processor=None,  # An instance of the Gemma3AudioFeatureExtractor above can be passed here
            tokenizer=None,
            chat_template=None,
            image_seq_length: int = 256,

# ... (unchanged lines not shown in the diff) ...

        self.image_token = getattr(self.tokenizer, "image_token", "<image>")
        self.eoi_token = getattr(self.tokenizer, "eoi_token", "")

        self.audio_token_str_from_user_code = "<audio_soft_token>"  # Example soft-token string
        # Ensure this token is actually in the tokenizer vocabulary as a special token.
        self.audio_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_token_str_from_user_code)
        if hasattr(self.tokenizer, "unk_token_id") and self.audio_token_id == self.tokenizer.unk_token_id:
            logger.warning(

# ... (unchanged lines not shown in the diff) ...

        self.image_token = "<image>"
        self.eoi_token = ""
        self.audio_token_str_from_user_code = "<audio_soft_token>"
        self.audio_token_id = -1  # Placeholder if the tokenizer is missing
        self.full_image_sequence = ""

        # These attributes are specific to Gemma3OmniProcessor's internal _compute_audio_embed_size.
        self.prompt_audio_compression_rate = kwargs.pop("prompt_audio_compression_rate", DEFAULT_COMPRESSION_RATE)
        self.prompt_audio_qformer_rate = kwargs.pop("prompt_audio_qformer_rate", DEFAULT_QFORMER_RATE)
        # self.prompt_audio_feat_stride = kwargs.pop("prompt_audio_feat_stride", DEFAULT_FEAT_STRIDE)  # Not used by _compute_audio_embed_size

        self.audio_placeholder_token = kwargs.pop("audio_placeholder_token", "<|audio_placeholder|>")

    def _merge_kwargs(self, KwargsClassWithDefaults, tokenizer_init_kwargs, **kwargs_from_call):

# ... (unchanged lines not shown in the diff) ...

            if modality_key_in_call in final_kwargs:
                if isinstance(modality_kwargs_in_call, dict):
                    final_kwargs[modality_key_in_call].update(modality_kwargs_in_call)
            elif isinstance(modality_kwargs_in_call, dict):  # A new modality that is not in the defaults
                final_kwargs[modality_key_in_call] = modality_kwargs_in_call.copy()

        if self.tokenizer:  # Ensure the tokenizer exists before accessing its attributes
            for modality_key in final_kwargs:
                modality_dict = final_kwargs[modality_key]
                if isinstance(modality_dict, dict):
                    for key_in_mod_dict in list(modality_dict.keys()):
                        if key_in_mod_dict in tokenizer_init_kwargs:
                            value = (
                                getattr(self.tokenizer, key_in_mod_dict)

# ... (unchanged lines not shown in the diff) ...

                            )
                            modality_dict[key_in_mod_dict] = value

        if "text_kwargs" not in final_kwargs: final_kwargs["text_kwargs"] = {}  # Ensure text_kwargs exists
        final_kwargs["text_kwargs"]["truncation"] = final_kwargs["text_kwargs"].get("truncation", False)
        final_kwargs["text_kwargs"]["max_length"] = final_kwargs["text_kwargs"].get("max_length", DEFAULT_MAX_LENGTH)

        return final_kwargs

    def _compute_audio_embed_size(self, audio_mel_frames: int) -> int:
        # Gemma3OmniProcessor's own soft-token count, computed from its prompt_audio_* rates.
        # When the refactored audio processor is used, audio_attention_mask.sum() already equals
        # num_mel_frames * feat_stride, so passing that value here compresses it the same way the
        # feature extractor's _compute_audio_embed_size does.

        # First compression
        integer = audio_mel_frames // self.prompt_audio_compression_rate
        remainder = audio_mel_frames % self.prompt_audio_compression_rate
        result = integer if remainder == 0 else integer + 1

        # Second compression
        integer = result // self.prompt_audio_qformer_rate
        remainder = result % self.prompt_audio_qformer_rate
        result = integer if remainder == 0 else integer + 1
        return result

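Aside — an illustrative check, not part of processing_gemma3_omni.py: the floor-division-plus-remainder pattern above reproduces the ceiling divisions that the removed implementation expressed with math.ceil.

import math

# Illustrative only: the three formulations agree for positive integers.
frames, rate = 392, 4
assert -(-frames // rate) == math.ceil(frames / rate) == (frames // rate) + (1 if frames % rate else 0)
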
    def __call__(
            self,
            text: Union[str, List[str]] = None,
            images: Optional[Any] = None,
            audios: Optional[Union[AudioInput, List[AudioInput]]] = None,
            sampling_rate: Optional[int] = None,  # Sampling rate for raw audio arrays
            return_tensors: Optional[Union[str, TensorType]] = None,
            **kwargs: Any
    ) -> BatchFeature:
        if text is None and images is None and audios is None:
            raise ValueError("Provide at least one of `text`, `images`, or `audios`.")

        final_rt = return_tensors  # Remember the return_tensors requested in the call
        # Merge kwargs for text, images, and audio
        merged_call_kwargs = self._merge_kwargs(
            Gemma3ProcessorKwargs,  # The class defining _defaults
            self.tokenizer.init_kwargs if hasattr(self.tokenizer, 'init_kwargs') else {},  # Tokenizer defaults
            **kwargs  # User-provided kwargs from the call
        )

        # Determine the final return_tensors, prioritizing the call argument over text_kwargs over the default.
        if final_rt is None:  # Not specified in the call
            final_rt = merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", TensorType.PYTORCH)
        else:  # Specified in the call, so remove it from text_kwargs to avoid a conflict
            merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", None)

        if text is None:  # If no text, create empty strings based on the other inputs
            num_samples = 0
            if images is not None:
                _images_list = images if isinstance(images, list) and (
                    not images or not isinstance(images[0], (int, float))) else [images]
                num_samples = len(_images_list)
            elif audios is not None:
                _audios_list = audios if isinstance(audios, list) and not (
                        isinstance(audios[0], tuple) and isinstance(audios[0][0], (int, float))) else [
                    audios]  # Distinguish a list of items from a single (waveform, sampling_rate) tuple
                num_samples = len(_audios_list)
            text = [""] * num_samples if num_samples > 0 else [""]  # Default to one empty string if there are no inputs

        if isinstance(text, str): text = [text]  # Ensure text is a list
        if not (isinstance(text, list) and all(isinstance(t, str) for t in text)):
            raise ValueError("Input `text` must be a string or a list of strings.")

        image_features_dict = {}
        if images is not None:
            if self.image_processor is None: raise ValueError("Images provided but self.image_processor is None.")
            # Ensure images are correctly batched; make_nested_list_of_images handles the various input types.
            batched_images = make_nested_list_of_images(images)

            _img_kwargs = merged_call_kwargs.get("images_kwargs", {})
            _img_proc_output = self.image_processor(batched_images, return_tensors=None,
                                                    **_img_kwargs)  # Tensor conversion happens later
            image_features_dict = _img_proc_output.data if isinstance(_img_proc_output,
                                                                      BatchFeature) else _img_proc_output

            if len(text) == 1 and text[0] == "" and len(
                    batched_images) > 0:  # Text is the default empty string and images exist
                text = [" ".join([self.boi_token] * len(img_batch)) for img_batch in batched_images]
            elif len(batched_images) != len(text):  # Text was provided, so enforce consistency
                raise ValueError(
                    f"Inconsistent batch: {len(batched_images)} image groups, {len(text)} texts. Ensure one text prompt per image group."
                )

            num_crops_popped = image_features_dict.pop("num_crops", None)
            if num_crops_popped is not None:
                num_crops_all = to_py_obj(num_crops_popped)
                temp_text_img, current_crop_idx_offset = [], 0
                for batch_idx, (prompt, current_imgs_in_batch) in enumerate(zip(text, batched_images)):
                    crops_for_this_batch_sample = []  # Number of *additional* crops for each original image
                    if num_crops_all:
                        for _ in current_imgs_in_batch:  # For each original image in the current batch sample
                            if current_crop_idx_offset < len(num_crops_all):
                                # num_crops_all holds the total item count (original + crops) per image,
                                # so the number of *additional* crops is that value minus one.
                                crops_for_this_batch_sample.append(max(0, num_crops_all[current_crop_idx_offset] - 1))
                                current_crop_idx_offset += 1
                            else:
                                crops_for_this_batch_sample.append(0)  # Should not happen if num_crops_all is correct

                    image_placeholders_in_prompt = [m.start() for m in re.finditer(re.escape(self.boi_token), prompt)]
                    processed_prompt = prompt

                    # Iterate backwards so earlier indices stay valid while replacing
                    iter_count = min(len(crops_for_this_batch_sample), len(image_placeholders_in_prompt))
                    for i_placeholder_idx in range(iter_count - 1, -1, -1):
                        num_additional_crops_for_this_image = crops_for_this_batch_sample[i_placeholder_idx]
                        original_token_idx_in_prompt = image_placeholders_in_prompt[i_placeholder_idx]

                        if num_additional_crops_for_this_image > 0:
                            # Replacement text: the original image placeholder plus placeholders for the additional crops
                            replacement_text = self.boi_token + "".join(
                                [self.boi_token] * num_additional_crops_for_this_image)
                            # Replace the single original boi_token with the new sequence
                            processed_prompt = (
                                    processed_prompt[:original_token_idx_in_prompt] +
                                    replacement_text +
                                    processed_prompt[original_token_idx_in_prompt + len(self.boi_token):]
                            )
                    temp_text_img.append(processed_prompt)
                text = temp_text_img
            # Replace every BOI token with the full image sequence (BOI + IMAGE * N + EOI).
            # This assumes that, after crop handling, self.boi_token still marks each image.
            text = [p.replace(self.boi_token, self.full_image_sequence) for p in text]

        audio_features_dict = {}
        if audios is not None:
            if self.audio_processor is None: raise ValueError("Audios provided but self.audio_processor is None.")

            audio_call_kwargs = merged_call_kwargs.get("audio_kwargs", {})
            # Pass the sampling_rate from __call__ through to the audio_processor (needed for raw arrays).
            if sampling_rate is not None: audio_call_kwargs["sampling_rate"] = sampling_rate

            # The audio_processor returns its model_input_names, e.g.
            # {"input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"}.
            _audio_proc_output = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
            audio_features_dict = _audio_proc_output.data

            new_text_with_audio = []

            # The number of audio items actually processed must match the number of text prompts.
            num_audio_samples_processed = audio_features_dict[self.audio_processor.model_input_names[0]].shape[0]

            if num_audio_samples_processed != len(text):
                raise ValueError(
                    f"Inconsistent batch for audio/text: {num_audio_samples_processed} audio samples processed, {len(text)} text prompts."
                )

            # With Gemma3AudioFeatureExtractor, "audio_embed_sizes" already holds the number of compressed
            # tokens, and its audio_attention_mask covers num_mel_frames * feat_stride positions. Summing
            # that mask therefore gives the input expected by this processor's _compute_audio_embed_size,
            # which decides how many audio soft tokens to insert into the text (ideally the same count).
            frames_for_embed_size_calc = to_py_obj(audio_features_dict[self.audio_processor.model_input_names[2]].sum(
                axis=-1))  # Sum of audio_attention_mask per sample

            for i, prompt in enumerate(text):
                # num_soft_tokens is the final number of audio tokens to insert into this prompt.
                num_soft_tokens = self._compute_audio_embed_size(frames_for_embed_size_calc[i])

                audio_token_sequence_str = self.audio_token_str_from_user_code * num_soft_tokens

                if self.audio_placeholder_token in prompt:
                    prompt = prompt.replace(self.audio_placeholder_token, audio_token_sequence_str,
                                            1)  # Replace only the first occurrence
                else:
                    prompt += audio_token_sequence_str  # Append if there is no placeholder
                new_text_with_audio.append(prompt)
            text = new_text_with_audio

        text_tokenizer_kwargs = merged_call_kwargs.get("text_kwargs", {})
        text_features_dict = self.tokenizer(text=text, return_tensors=None,
                                            **text_tokenizer_kwargs)  # Tensor conversion happens later

        # Create token_type_ids; first ensure input_ids is a list of lists.
        input_ids_list_of_lists = text_features_dict["input_ids"]
        if not isinstance(input_ids_list_of_lists, list) or not (
                input_ids_list_of_lists and isinstance(input_ids_list_of_lists[0], list)):
            if isinstance(input_ids_list_of_lists, (torch.Tensor, np.ndarray)):
                input_ids_list_of_lists = to_py_obj(input_ids_list_of_lists)  # To nested Python lists
            elif isinstance(input_ids_list_of_lists, list) and (
                    not input_ids_list_of_lists or isinstance(input_ids_list_of_lists[0], int)):
                input_ids_list_of_lists = [input_ids_list_of_lists]  # Wrap a single list

        token_type_ids_list = []
        for ids_sample in input_ids_list_of_lists:
            types = [0] * len(ids_sample)  # 0 for text
            for j, token_id_val in enumerate(ids_sample):
                if self.image_token_id is not None and token_id_val == self.image_token_id:
                    types[j] = 1  # 1 for image
                elif self.audio_token_id != -1 and token_id_val == self.audio_token_id:  # Only if audio_token_id is valid
                    types[j] = 2  # 2 for audio
            token_type_ids_list.append(types)
        text_features_dict["token_type_ids"] = token_type_ids_list

# ... (unchanged line not shown in the diff; final_batch_data is assembled here) ...

        if image_features_dict: final_batch_data.update(image_features_dict)
        if audio_features_dict: final_batch_data.update(audio_features_dict)

        # Everything is converted to tensors here if final_rt is set.
        return BatchFeature(data=final_batch_data, tensor_type=final_rt)

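Aside — a rough end-to-end sketch, not part of processing_gemma3_omni.py: it assumes a Gemma3-compatible tokenizer `tok` that already has "<audio_soft_token>" registered as a special token, and that the processor can be instantiated directly; neither assumption is shown in this diff.

# Illustrative only: wire the pieces together and run one audio+text prompt.
audio_fe = Gemma3AudioFeatureExtractor(
    audio_compression_rate=DEFAULT_COMPRESSION_RATE,
    audio_downsample_rate=DEFAULT_QFORMER_RATE,
    audio_feat_stride=DEFAULT_FEAT_STRIDE,
)
processor = Gemma3OmniProcessor(image_processor=None, audio_processor=audio_fe, tokenizer=tok)

waveform = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz
batch = processor(
    text="Describe this clip: <|audio_placeholder|>",
    audios=[(waveform, 16000)],
    return_tensors="pt",
)
# The placeholder is replaced by num_soft_tokens copies of "<audio_soft_token>", and the batch
# carries input_ids, token_type_ids (2 marks audio positions), plus the audio tensors above.
print(batch.keys())
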
    def batch_decode(self, *args, **kwargs):

# ... (unchanged lines not shown in the diff) ...

    def model_input_names(self) -> List[str]:
        input_names = set()
        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
            # model_input_names may be a list/set or a single string
            tokenizer_inputs = self.tokenizer.model_input_names
            if isinstance(tokenizer_inputs, (list, set)):
                input_names.update(tokenizer_inputs)
            else:  # Fallback for a single string
                input_names.add(str(tokenizer_inputs))
            input_names.add("token_type_ids")

        if hasattr(self, 'image_processor') and self.image_processor is not None:
            # Same check for the image processor
            image_inputs = self.image_processor.model_input_names
            if isinstance(image_inputs, (list, set)):
                input_names.update(image_inputs)
            else:
                input_names.add(str(image_inputs))

        if hasattr(self, 'audio_processor') and self.audio_processor is not None:
            # Use model_input_names from the instantiated audio_processor, so the names from
            # Gemma3AudioFeatureExtractor are reflected when it is used.
            audio_inputs = self.audio_processor.model_input_names
            if isinstance(audio_inputs, (list, set)):
                input_names.update(audio_inputs)
            else:
                input_names.add(str(audio_inputs))

        return list(input_names)