import re
from typing import List, Optional, Union, Dict, Any, Tuple

import numpy as np
import scipy.signal
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers.audio_utils import AudioInput  # type: ignore
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import make_nested_list_of_images  # If image processing is used
from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs
from transformers.utils import TensorType, to_py_obj, logging

# Constants
DEFAULT_SAMPLING_RATE = 16000
DEFAULT_N_FFT = 512
DEFAULT_WIN_LENGTH = 400
DEFAULT_HOP_LENGTH = 160
DEFAULT_N_MELS = 80
DEFAULT_COMPRESSION_RATE = 4
DEFAULT_QFORMER_RATE = 4  # Used for default in __init__ (as audio_downsample_rate)
DEFAULT_FEAT_STRIDE = 4  # Used for default in __init__
IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
DEFAULT_MAX_LENGTH = 16384

logger = logging.get_logger(__name__)


def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
    """Create a Mel filter bank matching SpeechLib FbankFC.

    Args:
        sample_rate (int): Sample rate in Hz. number > 0 [scalar]
        n_fft (int): FFT size. int > 0 [scalar]
        n_mels (int): Number of Mel filters. int > 0 [scalar]
        fmin (float): Lowest frequency (in Hz). If None, use 0.0.
            float >= 0 [scalar]
        fmax (float): Highest frequency (in Hz). If None, use sample_rate / 2.
            float >= 0 [scalar]

    Returns:
        out (numpy.ndarray): Mel transform matrix
            [shape=(n_mels, 1 + n_fft/2)]
    """

    bank_width = int(n_fft // 2 + 1)
    if fmax is None:
        fmax = sample_rate / 2
    if fmin is None:
        fmin = 0
    assert fmin >= 0, "fmin cannot be negative"
    assert fmin < fmax <= sample_rate / 2, "fmax must be in (fmin, sample_rate / 2]"

    def mel(f):
        return 1127.0 * np.log(1.0 + f / 700.0)

    def bin2mel(fft_bin):
        return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))

    def f2bin(f):
        return int((f * n_fft / sample_rate) + 0.5)

    # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
    klo = f2bin(fmin) + 1
    khi = f2bin(fmax)

    khi = max(khi, klo)

    # Spec 2: SpeechLib uses triangles in Mel space
    mlo = mel(fmin)
    mhi = mel(fmax)
    m_centers = np.linspace(mlo, mhi, n_mels + 2)
    ms = (mhi - mlo) / (n_mels + 1)

    matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
    for m in range(0, n_mels):
        left = m_centers[m]
        center = m_centers[m + 1]
        right = m_centers[m + 2]
        for fft_bin in range(klo, khi):
            mbin = bin2mel(fft_bin)
            if left < mbin < right:
                matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms

    return matrix
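
# Shape sanity check (illustrative): with the parameters used below,
#     mel = speechlib_mel(16000, 512, 80, fmax=7690)   # -> shape (80, 257)
# Transposed (as in Gemma3AudioFeatureExtractor.__init__), it projects a (T, 257)
# power spectrogram to (T, 80) log-Mel features via spec_power.dot(mel.T).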


# --- Audio feature extractor (log-Mel front end, kept consistent with Phi4M) ---
class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
    model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]

    def __init__(self,
                 audio_compression_rate: int = DEFAULT_COMPRESSION_RATE,
                 audio_downsample_rate: int = DEFAULT_QFORMER_RATE,  # maps to the qformer rate
                 audio_feat_stride: int = DEFAULT_FEAT_STRIDE,
                 feature_size: int = DEFAULT_N_MELS,
                 sampling_rate: int = DEFAULT_SAMPLING_RATE,
                 padding_value: float = 0.0,
                 eightk_method: str = "fillzero",  # how to handle 8 kHz input: "fillzero" or "resample"
                 **kwargs):

        # feature_size, sampling_rate, and padding_value are named parameters, so they
        # cannot also appear in **kwargs; pass them straight through to the superclass.
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value,
                         **kwargs)

        self.compression_rate = audio_compression_rate
        self.qformer_compression_rate = audio_downsample_rate
        self.feat_stride = audio_feat_stride

        # How 8 kHz input is handled: "fillzero" pads the spectrum, "resample" upsamples to 16 kHz.
        self._eightk_method = eightk_method

        # Phi4M hardcodes its Mel parameters for 16000 Hz; warn if the target rate differs.
        if self.sampling_rate != 16000:
            logger.warning(
                f"The feature extractor's target sampling rate is {self.sampling_rate}, "
                "but Phi4M-consistent Mel parameters are based on 16000 Hz. "
                "This might lead to inconsistencies if the input audio is not resampled to 16000 Hz by this extractor."
            )

        self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T
        self._hamming400 = np.hamming(400)
        self._hamming200 = np.hamming(200)

    def __call__(
            self,
            audios: List[Tuple[np.ndarray, int]],
            return_tensors: Optional[Union[str, TensorType]] = None,
    ):
        returned_input_audio_embeds = []
        returned_audio_embed_sizes = []
        audio_frames_list = []

        for audio_input_item in audios:
            if not isinstance(audio_input_item, tuple) or len(audio_input_item) != 2:
                raise ValueError(
                    "Each item in 'audios' must be a tuple (waveform: np.ndarray, sample_rate: int)."
                )
            audio_data, sample_rate = audio_input_item  # sample_rate is from the input audio

            if isinstance(audio_data, list):
                audio_data = np.array(audio_data, dtype=np.float32)
            if not isinstance(audio_data, np.ndarray):
                raise TypeError(f"Waveform data must be a numpy array, got {type(audio_data)}")

            # _extract_features will handle resampling to self.sampling_rate (16000 Hz)
            audio_embeds_np = self._extract_features(audio_data, sample_rate)

            num_mel_frames = audio_embeds_np.shape[0]
            current_audio_frames = num_mel_frames * self.feat_stride

            audio_embed_size = self._compute_audio_embed_size(current_audio_frames)

            returned_input_audio_embeds.append(torch.from_numpy(audio_embeds_np))
            returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long())
            audio_frames_list.append(current_audio_frames)

        padded_input_audio_embeds = pad_sequence(
            returned_input_audio_embeds, batch_first=True, padding_value=self.padding_value
        )
        stacked_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)

        tensor_audio_frames_list = torch.tensor(audio_frames_list, dtype=torch.long)

        max_audio_frames = 0
        if len(audios) > 0 and tensor_audio_frames_list.numel() > 0:
            max_audio_frames = tensor_audio_frames_list.max().item()

        returned_audio_attention_mask = None
        if max_audio_frames > 0:
            if len(audios) > 1:
                returned_audio_attention_mask = (
                    torch.arange(0, max_audio_frames, device=tensor_audio_frames_list.device).unsqueeze(0)
                    < tensor_audio_frames_list.unsqueeze(1)
                )
            elif len(audios) == 1:
                returned_audio_attention_mask = torch.ones(
                    1, max_audio_frames, dtype=torch.bool, device=tensor_audio_frames_list.device
                )

        data = {
            "input_audio_embeds": padded_input_audio_embeds,
            "audio_embed_sizes": stacked_audio_embed_sizes,
        }
        if returned_audio_attention_mask is not None:
            data["audio_attention_mask"] = returned_audio_attention_mask

        return BatchFeature(data=data, tensor_type=return_tensors)
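
    # Usage sketch (illustrative; the silent input is a placeholder): the extractor
    # consumes (waveform, sample_rate) tuples and pads the batch itself, e.g.
    #     fe = Gemma3AudioFeatureExtractor()
    #     wav = np.zeros(16000, dtype=np.float32)       # one second of 16 kHz silence
    #     out = fe([(wav, 16000)], return_tensors="pt")
    #     out["input_audio_embeds"].shape               # (1, 98, 80) log-Mel frames
    #     out["audio_embed_sizes"]                      # tensor([25]) soft tokens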

    def _extract_spectrogram(self, wav: np.ndarray, fs: int) -> np.ndarray:
        # This method expects fs to be the original sampling rate of wav.
        # It will resample to self.sampling_rate (16000Hz) or 8000Hz as needed.
        if wav.ndim > 1:
            wav = np.squeeze(wav)
        if wav.ndim == 2:
            # Still 2-D after squeeze: genuinely multi-channel audio; downmix to mono.
            wav = wav.mean(axis=1).astype(np.float32)

        wav = wav.astype(np.float32)

        current_fs = fs
        if current_fs > self.sampling_rate:  # self.sampling_rate is 16000
            wav = scipy.signal.resample_poly(wav, self.sampling_rate, current_fs)
            current_fs = self.sampling_rate
        elif 8000 < current_fs < self.sampling_rate:
            wav = scipy.signal.resample_poly(wav, 8000, current_fs)
            current_fs = 8000
        elif current_fs < 8000 and current_fs > 0:
            logger.warning(f"Sample rate {current_fs} is less than 8000Hz. Resampling to 8000Hz.")
            wav = scipy.signal.resample_poly(wav, 8000, current_fs)
            current_fs = 8000
        elif current_fs <= 0:
            raise RuntimeError(f"Unsupported sample rate {current_fs}")

        # After this block, current_fs is either 16000 Hz or 8000 Hz (or an error was raised).

        if current_fs == 8000:
            if self._eightk_method == "resample":
                wav = scipy.signal.resample_poly(wav, self.sampling_rate, 8000)
                current_fs = self.sampling_rate
        elif current_fs != self.sampling_rate:
            # This case should ideally not be hit if logic above is correct and self.sampling_rate is 16000
            raise RuntimeError(
                f"Audio sample rate {current_fs} not supported. Expected {self.sampling_rate} or 8000 for 8k methods.")

        preemphasis_coeff = 0.97

        # current_fs is now the rate for STFT parameters (either 16000 or 8000 if fillzero)
        if current_fs == 8000:  # This implies _eightk_method == "fillzero"
            n_fft, win_length, hop_length, fft_window = 256, 200, 80, self._hamming200
        elif current_fs == 16000:  # This is the standard path
            n_fft, win_length, hop_length, fft_window = 512, 400, 160, self._hamming400
        else:
            raise RuntimeError(f"Inconsistent fs {current_fs} for parameter selection. Should be 16000 or 8000.")

        if len(wav) < win_length:
            wav = np.pad(wav, (0, win_length - len(wav)), 'constant', constant_values=(0.0,))

        num_frames = (wav.shape[0] - win_length) // hop_length + 1
        if num_frames <= 0:
            # For n_fft=512 (16k), output bins = 257. For n_fft=256 (8k), output bins = 129
            # If fillzero for 8k, it will be padded to 257 later.
            # So, the number of freq bins depends on n_fft here.
            return np.zeros((0, n_fft // 2 + 1), dtype=np.float32)

        y_frames = np.array(
            [wav[i * hop_length: i * hop_length + win_length] for i in range(num_frames)],
            dtype=np.float32,
        )

        _y_frames_rolled = np.roll(y_frames, 1, axis=1)
        _y_frames_rolled[:, 0] = _y_frames_rolled[:, 1]
        y_frames_preemphasized = (y_frames - preemphasis_coeff * _y_frames_rolled) * 32768.0

        S = np.fft.rfft(fft_window * y_frames_preemphasized, n=n_fft, axis=1).astype(np.complex64)

        if current_fs == 8000 and self._eightk_method == "fillzero":
            # S has 129 bins for n_fft=256; pad up to the 257 bins a 512-point FFT would give.
            target_bins = (512 // 2) + 1
            S_core = S[:, :-1]  # Drop the 4 kHz Nyquist bin (1 bin)
            padarray = np.zeros((S_core.shape[0], target_bins - S_core.shape[1]), dtype=S.dtype)
            S = np.concatenate((S_core, padarray), axis=1)

        spec = np.abs(S).astype(np.float32)
        return spec
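
    # Framing example (16 kHz path): for a 1 s clip (16000 samples),
    # num_frames = (16000 - 400) // 160 + 1 = 98, and each frame yields
    # 512 // 2 + 1 = 257 rFFT magnitude bins, so spec has shape (98, 257).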

    def _extract_features(self, wav: np.ndarray, fs: int) -> np.ndarray:
        spec = self._extract_spectrogram(wav, fs)
        if spec.shape[0] == 0:
            # self.feature_size is n_mels (e.g. 80)
            return np.zeros((0, self.feature_size), dtype=np.float32)

        spec_power = spec ** 2
        fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
        log_fbank = np.log(fbank_power).astype(np.float32)
        return log_fbank

    def _compute_audio_embed_size(self, audio_frames: int) -> int:
        integer = audio_frames // self.compression_rate
        remainder = audio_frames % self.compression_rate
        result = integer if remainder == 0 else integer + 1

        integer = result // self.qformer_compression_rate
        remainder = result % self.qformer_compression_rate
        result = integer if remainder == 0 else integer + 1
        return result
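
    # Worked example (defaults compression_rate = qformer_compression_rate = 4):
    # 998 stride-adjusted frames -> ceil(998 / 4) = 250 -> ceil(250 / 4) = 63 soft tokens,
    # i.e. two successive rounds of ceiling division.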


class Gemma3ImagesKwargs(ImagesKwargs):
    do_pan_and_scan: Optional[bool]
    pan_and_scan_min_crop_size: Optional[int]
    pan_and_scan_max_num_crops: Optional[int]
    pan_and_scan_min_ratio_to_activate: Optional[float]
    do_convert_rgb: Optional[bool]


class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: Optional[Gemma3ImagesKwargs]
    audio_kwargs: Optional[Dict[str, Any]]
    text_kwargs: Optional[Dict[str, Any]]
    _defaults = {
        "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
        "images_kwargs": {},
        "audio_kwargs": {},
    }


class Gemma3OmniProcessor(ProcessorMixin):
    attributes = ["image_processor", "audio_processor", "tokenizer"]
    valid_kwargs = ["chat_template", "image_seq_length"]

    image_processor_class = "AutoImageProcessor"
    audio_processor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
            self,
            image_processor=None,
            audio_processor=None,  # e.g. an instance of Gemma3AudioFeatureExtractor
            tokenizer=None,
            chat_template=None,
            image_seq_length: int = 256,
            **kwargs
    ):
        # Pop processor-specific kwargs before calling super().__init__, which rejects
        # unexpected keyword arguments. These drive the processor's own
        # _compute_audio_embed_size.
        self.prompt_audio_compression_rate = kwargs.pop("prompt_audio_compression_rate", DEFAULT_COMPRESSION_RATE)
        self.prompt_audio_qformer_rate = kwargs.pop("prompt_audio_qformer_rate", DEFAULT_QFORMER_RATE)
        self.audio_placeholder_token = kwargs.pop("audio_placeholder_token", "<|audio_placeholder|>")

        super().__init__(
            image_processor=image_processor,
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            chat_template=chat_template,
            **kwargs
        )

        self.image_seq_length = image_seq_length
        if self.tokenizer is not None:
            self.image_token_id = getattr(
                self.tokenizer, "image_token_id", getattr(self.tokenizer, "unk_token_id", None)
            )
            self.boi_token = getattr(self.tokenizer, "boi_token", "<image>")
            self.image_token = getattr(self.tokenizer, "image_token", "<image>")
            self.eoi_token = getattr(self.tokenizer, "eoi_token", "")

            # The soft audio token string must be registered in the tokenizer vocabulary as a
            # special token; otherwise convert_tokens_to_ids falls back to UNK (warned below).
            self.audio_token_str_from_user_code = "<audio_soft_token>"
            self.audio_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_token_str_from_user_code)
            if hasattr(self.tokenizer, "unk_token_id") and self.audio_token_id == self.tokenizer.unk_token_id:
                logger.warning(
                    f"The audio token string '{self.audio_token_str_from_user_code}' maps to the UNK token. "
                    "Please ensure it is added to the tokenizer's vocabulary as a special token."
                )
            self.full_image_sequence = f"\n\n{self.boi_token}{''.join([self.image_token] * image_seq_length)}{self.eoi_token}\n\n"
        else:
            logger.error(
                "Gemma3OmniProcessor initialized, but self.tokenizer is None. "
                "Token-dependent attributes will use placeholders or defaults."
            )
            self.image_token_id = None
            self.boi_token = "<image>"
            self.image_token = "<image>"
            self.eoi_token = ""
            self.audio_token_str_from_user_code = "<audio_soft_token>"
            self.audio_token_id = -1  # Placeholder if tokenizer is missing
            self.full_image_sequence = ""

    def _merge_kwargs(self, KwargsClassWithDefaults, tokenizer_init_kwargs, **kwargs_from_call):
        final_kwargs = {}
        _defaults = getattr(KwargsClassWithDefaults, "_defaults", {})
        if not isinstance(_defaults, dict):
            _defaults = {}

        for modality_key, default_modality_kwargs in _defaults.items():
            final_kwargs[modality_key] = default_modality_kwargs.copy()

        for modality_key_in_call, modality_kwargs_in_call in kwargs_from_call.items():
            if modality_key_in_call in final_kwargs:
                if isinstance(modality_kwargs_in_call, dict):
                    final_kwargs[modality_key_in_call].update(modality_kwargs_in_call)
            elif isinstance(modality_kwargs_in_call, dict):  # New modality not in defaults
                final_kwargs[modality_key_in_call] = modality_kwargs_in_call.copy()

        if self.tokenizer:  # Ensure tokenizer exists before accessing its attributes
            for modality_key in final_kwargs:
                modality_dict = final_kwargs[modality_key]
                if isinstance(modality_dict, dict):  # Check if it's a dictionary
                    for key_in_mod_dict in list(modality_dict.keys()):  # Iterate over keys
                        if key_in_mod_dict in tokenizer_init_kwargs:
                            value = (
                                getattr(self.tokenizer, key_in_mod_dict)
                                if hasattr(self.tokenizer, key_in_mod_dict)
                                else tokenizer_init_kwargs[key_in_mod_dict]
                            )
                            modality_dict[key_in_mod_dict] = value

        if "text_kwargs" not in final_kwargs: final_kwargs["text_kwargs"] = {}  # Ensure text_kwargs exists
        final_kwargs["text_kwargs"]["truncation"] = final_kwargs["text_kwargs"].get("truncation", False)
        final_kwargs["text_kwargs"]["max_length"] = final_kwargs["text_kwargs"].get("max_length", DEFAULT_MAX_LENGTH)

        return final_kwargs
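
    # Example (assuming empty tokenizer_init_kwargs):
    #     _merge_kwargs(Gemma3ProcessorKwargs, {}, text_kwargs={"padding": True})
    # starts from _defaults, overlays the call-site dicts, and yields
    #     {"text_kwargs": {"padding": True, "truncation": False, "max_length": 16384},
    #      "images_kwargs": {}, "audio_kwargs": {}}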

    def _compute_audio_embed_size(self, audio_mel_frames: int) -> int:
        integer = audio_mel_frames // self.prompt_audio_compression_rate
        remainder = audio_mel_frames % self.prompt_audio_compression_rate
        result = integer if remainder == 0 else integer + 1

        # Second compression
        integer = result // self.prompt_audio_qformer_rate
        remainder = result % self.prompt_audio_qformer_rate
        result = integer if remainder == 0 else integer + 1
        return result
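
    # This mirrors Gemma3AudioFeatureExtractor._compute_audio_embed_size so the number of
    # soft tokens inserted into the prompt matches audio_embed_sizes, e.g. a mask sum of
    # 392 frames -> ceil(392 / 4) = 98 -> ceil(98 / 4) = 25 audio tokens.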

    def __call__(
            self,
            text: Union[str, List[str]] = None,
            images: Optional[Any] = None,
            audios: Optional[Union[AudioInput, List[AudioInput]]] = None,
            sampling_rate: Optional[int] = None,  # sampling_rate for raw audio arrays
            return_tensors: Optional[Union[str, TensorType]] = None,
            **kwargs: Any
    ) -> BatchFeature:
        if text is None and images is None and audios is None:
            raise ValueError("Provide at least one of `text`, `images`, or `audios`.")

        final_rt = return_tensors  # Store original return_tensors
        # Properly merge kwargs for text, images, audio
        merged_call_kwargs = self._merge_kwargs(
            Gemma3ProcessorKwargs,  # The class defining _defaults
            self.tokenizer.init_kwargs if hasattr(self.tokenizer, 'init_kwargs') else {},  # Tokenizer defaults
            **kwargs  # User-provided kwargs from the call
        )

        # Determine final return_tensors, prioritizing call > text_kwargs > default
        if final_rt is None:  # If not specified in call
            final_rt = merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", TensorType.PYTORCH)
        else:  # If specified in call, remove from text_kwargs to avoid conflict
            merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", None)

        if text is None:  # If no text, create empty strings based on other inputs
            num_samples = 0
            if images is not None:
                # Treat `images` as a batch unless it looks like a single raw array
                # (a flat list of numbers).
                if isinstance(images, list) and (not images or not isinstance(images[0], (int, float))):
                    _images_list = images
                else:
                    _images_list = [images]
                num_samples = len(_images_list)
            elif audios is not None:
                # Treat `audios` as a batch unless it looks like a single (waveform, sample_rate) pair.
                if isinstance(audios, list) and not (
                        isinstance(audios[0], tuple) and isinstance(audios[0][0], (int, float))):
                    _audios_list = audios
                else:
                    _audios_list = [audios]
                num_samples = len(_audios_list)
            text = [""] * num_samples if num_samples > 0 else [""]  # Default to one empty string

        if isinstance(text, str):
            text = [text]
        if not (isinstance(text, list) and all(isinstance(t, str) for t in text)):
            raise ValueError("Input `text` must be a string or a list of strings.")

        image_features_dict = {}
        if images is not None:
            if self.image_processor is None:
                raise ValueError("Images provided but self.image_processor is None.")
            # Ensure images are correctly batched
            batched_images = make_nested_list_of_images(images)  # handles various image input types

            _img_kwargs = merged_call_kwargs.get("images_kwargs", {})
            _img_proc_output = self.image_processor(batched_images, return_tensors=None,
                                                    **_img_kwargs)  # Pass None to handle tensors later
            image_features_dict = _img_proc_output.data if isinstance(_img_proc_output,
                                                                      BatchFeature) else _img_proc_output

            if len(text) == 1 and text[0] == "" and len(
                    batched_images) > 0:  # If text is default empty and images exist
                text = [" ".join([self.boi_token] * len(img_batch)) for img_batch in batched_images]
            elif len(batched_images) != len(text):  # If text was provided, ensure consistency
                raise ValueError(
                    f"Inconsistent batch: {len(batched_images)} image groups, {len(text)} texts. Ensure one text prompt per image group."
                )

            num_crops_popped = image_features_dict.pop("num_crops", None)
            if num_crops_popped is not None:
                num_crops_all = to_py_obj(num_crops_popped)
                temp_text_img, current_crop_idx_offset = [], 0
                for batch_idx, (prompt, current_imgs_in_batch) in enumerate(zip(text, batched_images)):
                    crops_for_this_batch_sample = []  # Number of *additional* crops for each original image
                    if num_crops_all:  # If num_crops_all is not None or empty
                        for _ in current_imgs_in_batch:  # For each original image in the current batch sample
                            if current_crop_idx_offset < len(num_crops_all):
                                # num_crops_all contains total items (original + crops) for each image
                                # We need number of *additional* crops. Assuming num_crops_all[i] >= 1
                                crops_for_this_batch_sample.append(max(0, num_crops_all[current_crop_idx_offset] - 1))
                                current_crop_idx_offset += 1
                            else:
                                crops_for_this_batch_sample.append(0)  # Should not happen if num_crops_all is correct

                    image_placeholders_in_prompt = [m.start() for m in re.finditer(re.escape(self.boi_token), prompt)]
                    processed_prompt = prompt

                    # Iterate backwards to preserve indices for replacement
                    iter_count = min(len(crops_for_this_batch_sample), len(image_placeholders_in_prompt))
                    for i_placeholder_idx in range(iter_count - 1, -1, -1):
                        num_additional_crops_for_this_image = crops_for_this_batch_sample[i_placeholder_idx]
                        original_token_idx_in_prompt = image_placeholders_in_prompt[i_placeholder_idx]

                        if num_additional_crops_for_this_image > 0:
                            # Create replacement text: original image placeholder + placeholders for additional crops
                            replacement_text = self.boi_token + "".join(
                                [self.boi_token] * num_additional_crops_for_this_image)
                            # Replace the single original boi_token with the new sequence
                            processed_prompt = (
                                    processed_prompt[:original_token_idx_in_prompt] +
                                    replacement_text +
                                    processed_prompt[original_token_idx_in_prompt + len(self.boi_token):]
                            )
                    temp_text_img.append(processed_prompt)
                text = temp_text_img
            # Replace all BOI tokens with the full image sequence (BOI + IMAGE*N + EOI)
            # This step assumes that if additional crops were handled, self.boi_token still marks each image.
            text = [p.replace(self.boi_token, self.full_image_sequence) for p in text]

        audio_features_dict = {}
        if audios is not None:
            if self.audio_processor is None:
                raise ValueError("Audios provided but self.audio_processor is None.")

            audio_call_kwargs = merged_call_kwargs.get("audio_kwargs", {})
            # Pass sampling_rate from __call__ to audio_processor if provided (for raw arrays)
            if sampling_rate is not None:
                audio_call_kwargs["sampling_rate"] = sampling_rate

            # The audio_processor (e.g., Gemma3AudioFeatureExtractor) returns its model_input_names,
            # e.g. {"input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"}.
            _audio_proc_output = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
            audio_features_dict = _audio_proc_output.data

            new_text_with_audio = []

            # Determine the number of actual audio items processed by the audio_processor
            # This should match len(text) if batching is consistent.
            # The 'audio_attention_mask' or 'input_audio_embeds' can indicate this.
            num_audio_samples_processed = audio_features_dict[self.audio_processor.model_input_names[0]].shape[0]

            if num_audio_samples_processed != len(text):
                raise ValueError(
                    f"Inconsistent batch for audio/text: {num_audio_samples_processed} audio samples processed, {len(text)} text prompts."
                )
            # Recover per-sample stride-adjusted frame counts by summing the attention mask.
            frames_for_embed_size_calc = to_py_obj(
                audio_features_dict[self.audio_processor.model_input_names[2]].sum(axis=-1)
            )

            for i, prompt in enumerate(text):
                # num_soft_tokens should be the final number of audio tokens to insert in the text.
                # This is calculated by the Gemma3OmniProcessor's own method.
                num_soft_tokens = self._compute_audio_embed_size(frames_for_embed_size_calc[i])

                audio_token_sequence_str = self.audio_token_str_from_user_code * num_soft_tokens

                if self.audio_placeholder_token in prompt:
                    prompt = prompt.replace(self.audio_placeholder_token, audio_token_sequence_str,
                                            1)  # Replace only first
                else:
                    prompt += audio_token_sequence_str  # Append if no placeholder
                new_text_with_audio.append(prompt)
            text = new_text_with_audio

        text_tokenizer_kwargs = merged_call_kwargs.get("text_kwargs", {})
        text_features_dict = self.tokenizer(text=text, return_tensors=None,
                                            **text_tokenizer_kwargs)  # Pass None for tensors

        # Create token_type_ids
        input_ids_list_of_lists = text_features_dict["input_ids"]
        # Ensure it's a list of lists
        if not isinstance(input_ids_list_of_lists, list) or not (
                input_ids_list_of_lists and isinstance(input_ids_list_of_lists[0], list)):
            if isinstance(input_ids_list_of_lists, (torch.Tensor, np.ndarray)):
                input_ids_list_of_lists = to_py_obj(input_ids_list_of_lists)  # to nested python lists
            elif isinstance(input_ids_list_of_lists, list) and (
                    not input_ids_list_of_lists or isinstance(input_ids_list_of_lists[0], int)):
                input_ids_list_of_lists = [input_ids_list_of_lists]  # wrap single list

        token_type_ids_list = []
        for ids_sample in input_ids_list_of_lists:
            types = [0] * len(ids_sample)  # 0 for text
            for j, token_id_val in enumerate(ids_sample):
                if self.image_token_id is not None and token_id_val == self.image_token_id:
                    types[j] = 1  # 1 for image
                elif self.audio_token_id != -1 and token_id_val == self.audio_token_id:  # Check if audio_token_id is valid
                    types[j] = 2  # 2 for audio
            token_type_ids_list.append(types)
        text_features_dict["token_type_ids"] = token_type_ids_list

        final_batch_data = {**text_features_dict}
        if image_features_dict:
            final_batch_data.update(image_features_dict)
        if audio_features_dict:
            final_batch_data.update(audio_features_dict)

        # Convert all data to tensors if final_rt is specified
        return BatchFeature(data=final_batch_data, tensor_type=final_rt)

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self) -> List[str]:
        input_names = set()
        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
            # Make sure model_input_names is a list/set before +
            tokenizer_inputs = self.tokenizer.model_input_names
            if isinstance(tokenizer_inputs, (list, set)):
                input_names.update(tokenizer_inputs)
            else:  # Fallback if it's a single string
                input_names.add(str(tokenizer_inputs))
            input_names.add("token_type_ids")

        if hasattr(self, 'image_processor') and self.image_processor is not None:
            # Similar check for image_processor
            image_inputs = self.image_processor.model_input_names
            if isinstance(image_inputs, (list, set)):
                input_names.update(image_inputs)
            else:
                input_names.add(str(image_inputs))

        if hasattr(self, 'audio_processor') and self.audio_processor is not None:
            # Use model_input_names from the instantiated audio_processor so the names
            # reflect whichever feature extractor (e.g., Gemma3AudioFeatureExtractor) is attached.
            audio_inputs = self.audio_processor.model_input_names
            if isinstance(audio_inputs, (list, set)):
                input_names.update(audio_inputs)
            else:
                input_names.add(str(audio_inputs))

        return list(input_names)
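
# Minimal end-to-end sketch (illustrative only; the checkpoint path, the prompt, and the
# availability of compatible sub-processors are assumptions, not shipped defaults):
#
#     from transformers import AutoImageProcessor, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("path/to/gemma3-omni")            # hypothetical path
#     image_processor = AutoImageProcessor.from_pretrained("path/to/gemma3-omni")
#     processor = Gemma3OmniProcessor(
#         image_processor=image_processor,
#         audio_processor=Gemma3AudioFeatureExtractor(),
#         tokenizer=tokenizer,
#     )
#
#     wav = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
#     batch = processor(
#         text="Transcribe: <|audio_placeholder|>",
#         audios=[(wav, 16000)],
#         return_tensors="pt",
#     )
#     # batch holds input_ids, token_type_ids, input_audio_embeds, audio_embed_sizes, ...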