Update processing_gemma3_omni.py

processing_gemma3_omni.py CHANGED (+9 -41)
@@ -1,7 +1,6 @@
 import re
 from typing import List, Optional, Union, Dict, Any, Tuple # Added Tuple
 
-import math
 import numpy as np
 import scipy.signal
 import torch
@@ -19,15 +18,16 @@ DEFAULT_N_FFT = 512
 DEFAULT_WIN_LENGTH = 400
 DEFAULT_HOP_LENGTH = 160
 DEFAULT_N_MELS = 80
-DEFAULT_COMPRESSION_RATE = 4
-DEFAULT_QFORMER_RATE = 2
-DEFAULT_FEAT_STRIDE = 4
+DEFAULT_COMPRESSION_RATE = 4 # Used for default in __init__
+DEFAULT_QFORMER_RATE = 2 # Used for default in __init__ (as audio_downsample_rate)
+DEFAULT_FEAT_STRIDE = 4 # Used for default in __init__
 IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
 AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
 DEFAULT_MAX_LENGTH = 16384
 
 logger = logging.get_logger(__name__)
 
+
 def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
     """Create a Mel filter-bank the same as SpeechLib FbankFC.
     Args:
@@ -283,6 +283,7 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor): # MODIFIED CLASS N
         return log_fbank
 
     def _compute_audio_embed_size(self, audio_frames: int) -> int:
+        print("self.compression_rate", self.compression_rate)
         integer = audio_frames // self.compression_rate
         remainder = audio_frames % self.compression_rate
         result = integer if remainder == 0 else integer + 1
@@ -293,14 +294,6 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor): # MODIFIED CLASS N
         return result
 
 
-# The rest of your script (Gemma3ImagesKwargs, Gemma3ProcessorKwargs, Gemma3OmniProcessor) follows...
-# Make sure this Gemma3AudioFeatureExtractor class replaces the old one or
-# is correctly registered/named if your AutoProcessor setup relies on a specific name.
-
-
-# --- End of Refactored Audio Feature Extractor ---
-
-
 class Gemma3ImagesKwargs(ImagesKwargs):
     do_pan_and_scan: Optional[bool]
     pan_and_scan_min_crop_size: Optional[int]
@@ -416,23 +409,7 @@ class Gemma3OmniProcessor(ProcessorMixin):
         return final_kwargs
 
     def _compute_audio_embed_size(self, audio_mel_frames: int) -> int:
-
-        # It calculates a number of soft tokens based on its own compression rates.
-        # Note: `audio_mel_frames` here is the number of raw Mel frames from the feature extractor's perspective
-        # if the attention mask sum is directly used before feat_stride scaling by the processor.
-        # However, if using the Refactored processor, audio_attention_mask.sum() will yield
-        # num_mel_frames * feat_stride. This method should then correctly compress that value.
-
-        # Using prompt_audio_compression_rate and prompt_audio_qformer_rate
-        # which are attributes of this Gemma3OmniProcessor class.
-
-        # First compression
-        # audio_mel_frames here should ideally be num_actual_mel_frames * feat_stride_of_the_audio_processor
-        # if trying to match the number of tokens from a Phi4M-style processor.
-        # The refactored audio processor does this scaling internally before its own _compute_audio_embed_size.
-        # If actual_mel_frames_per_sample (from sum of attention_mask) *is* already scaled by feat_stride
-        # (as it would be if using the refactored processor's attention_mask), then this calculation is correct.
-
+        print("prompt_audio_compression_rate", self.prompt_audio_compression_rate)
         integer = audio_mel_frames // self.prompt_audio_compression_rate
         remainder = audio_mel_frames % self.prompt_audio_compression_rate
         result = integer if remainder == 0 else integer + 1
@@ -473,11 +450,11 @@
         num_samples = 0
         if images is not None:
             _images_list = images if isinstance(images, list) and (
-
+                not images or not isinstance(images[0], (int, float))) else [images]
             num_samples = len(_images_list)
         elif audios is not None:
             _audios_list = audios if isinstance(audios, list) and not (
-
+                isinstance(audios[0], tuple) and isinstance(audios[0][0], (int, float))) else [
                 audios] # check if audios is list of items or list of (wave,sr)
             num_samples = len(_audios_list)
         text = [""] * num_samples if num_samples > 0 else [""] # Default to one empty string if no inputs
@@ -571,15 +548,6 @@
             raise ValueError(
                 f"Inconsistent batch for audio/text: {num_audio_samples_processed} audio samples processed, {len(text)} text prompts."
             )
-
-        # If using Gemma3AudioFeatureExtractor,
-        # "audio_embed_sizes" is already computed correctly (num compressed tokens).
-        # The processor's own _compute_audio_embed_size is called to determine how many
-        # self.audio_token_str_from_user_code to insert. Ideally, this matches.
-
-        # Get the number of frames that the processor's _compute_audio_embed_size expects.
-        # If the audio_processor is RefactoredGemma3..., its attention_mask is over (num_mel_frames * feat_stride).
-        # So, sum of that mask gives the input for this processor's _compute_audio_embed_size.
         frames_for_embed_size_calc = to_py_obj(audio_features_dict[self.audio_processor.model_input_names[2]].sum(
             axis=-1)) # sum of audio_attention_mask
 
@@ -666,4 +634,4 @@
         else:
            input_names.add(str(audio_inputs))
 
-        return list(input_names)
+        return list(input_names)
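Note on the two _compute_audio_embed_size methods touched above: both are plain ceiling division of a frame count by a compression rate, which is why the removed "import math" is no longer needed. A minimal standalone sketch of that logic; the rates mirror the module defaults, but the frame counts and the feat_stride scaling step are illustrative values taken from the removed comments, not outputs of the actual model:

# Sketch: ceiling-division logic shared by both _compute_audio_embed_size methods.
def compute_audio_embed_size(audio_frames: int, compression_rate: int = 4) -> int:
    # Equivalent to math.ceil(audio_frames / compression_rate), written with // and %.
    integer = audio_frames // compression_rate
    remainder = audio_frames % compression_rate
    return integer if remainder == 0 else integer + 1

mel_frames = 100                 # hypothetical number of valid Mel frames for one clip
scaled_frames = mel_frames * 4   # feat_stride scaling described in the removed comments
print(compute_audio_embed_size(scaled_frames))       # 400 // 4 -> 100
print(compute_audio_embed_size(scaled_frames + 3))   # 403 -> rounded up -> 101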
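The frames_for_embed_size_calc line kept in the second-to-last hunk feeds that method by summing the audio attention mask over its last axis. A small sketch of what that sum produces for a padded batch; the mask values and shapes are dummies, the real mask comes from the audio feature extractor:

import numpy as np

# 1 marks a valid frame, 0 marks padding; summing over the last axis recovers the
# per-clip frame count that _compute_audio_embed_size then compresses.
audio_attention_mask = np.array([
    [1, 1, 1, 1, 1, 0, 0, 0],   # clip 0: 5 valid frames
    [1, 1, 1, 1, 1, 1, 1, 1],   # clip 1: 8 valid frames
])
frames_per_clip = audio_attention_mask.sum(axis=-1)
print(frames_per_clip.tolist())   # [5, 8]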
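The two rewritten continuation lines in the batching hunk normalize a single input into a one-element list before the batch size is counted. A standalone sketch of the audio branch; the helper name and dummy waveforms are illustrative and not part of the processor API:

import numpy as np

def normalize_audios(audios):
    # Mirrors the check from the diff: a bare (waveform, sample_rate) tuple is wrapped
    # into a one-element list, while a list of such tuples passes through unchanged.
    return audios if isinstance(audios, list) and not (
        isinstance(audios[0], tuple) and isinstance(audios[0][0], (int, float))
    ) else [audios]

single = (np.zeros(16000), 16000)                               # one (wave, sr) pair
batch = [(np.zeros(16000), 16000), (np.zeros(8000), 16000)]     # already a batch

print(len(normalize_audios(single)))   # 1
print(len(normalize_audios(batch)))    # 2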