voidful committed on
Commit 9e58a2b · verified · 1 Parent(s): 701891b

Update processing_gemma3_omni.py

Files changed (1):
  1. processing_gemma3_omni.py +9 -41
processing_gemma3_omni.py CHANGED
@@ -1,7 +1,6 @@
 import re
 from typing import List, Optional, Union, Dict, Any, Tuple  # Added Tuple
 
-import math
 import numpy as np
 import scipy.signal
 import torch
@@ -19,15 +18,16 @@ DEFAULT_N_FFT = 512
 DEFAULT_WIN_LENGTH = 400
 DEFAULT_HOP_LENGTH = 160
 DEFAULT_N_MELS = 80
-DEFAULT_COMPRESSION_RATE = 4  # Used for default in __init__
-DEFAULT_QFORMER_RATE = 2  # Used for default in __init__ (as audio_downsample_rate)
-DEFAULT_FEAT_STRIDE = 4  # Used for default in __init__
+DEFAULT_COMPRESSION_RATE = 4  # Used for default in __init__
+DEFAULT_QFORMER_RATE = 2  # Used for default in __init__ (as audio_downsample_rate)
+DEFAULT_FEAT_STRIDE = 4  # Used for default in __init__
 IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
 AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
 DEFAULT_MAX_LENGTH = 16384
 
 logger = logging.get_logger(__name__)
 
+
 def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
     """Create a Mel filter-bank the same as SpeechLib FbankFC.
     Args:
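For orientation, a minimal sketch of what these STFT defaults imply for feature shapes. `SAMPLE_RATE` and the dummy waveform are assumptions for illustration; the real extractor builds its filter bank with `speechlib_mel` rather than the generic projection sketched in the comments.

```python
import numpy as np
import scipy.signal

SAMPLE_RATE = 16000                  # assumed; not fixed in this hunk
wave = np.random.randn(SAMPLE_RATE)  # 1 s of dummy audio

# 400-sample windows every 160 samples, zero-padded to a 512-point FFT,
# matching DEFAULT_WIN_LENGTH / DEFAULT_HOP_LENGTH / DEFAULT_N_FFT above.
_, _, stft = scipy.signal.stft(
    wave,
    fs=SAMPLE_RATE,
    nperseg=400,         # DEFAULT_WIN_LENGTH
    noverlap=400 - 160,  # hop of DEFAULT_HOP_LENGTH samples
    nfft=512,            # DEFAULT_N_FFT
)
power = np.abs(stft) ** 2  # shape: (512 // 2 + 1, num_frames) = (257, ...)

# speechlib_mel(SAMPLE_RATE, 512, 80) presumably yields an (80, 257) bank;
# projecting would give (80, num_frames) log-mel features, e.g.:
# log_fbank = np.log(mel_bank @ power + 1e-6)
print(power.shape)
```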
@@ -283,6 +283,7 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor): # MODIFIED CLASS N
         return log_fbank
 
     def _compute_audio_embed_size(self, audio_frames: int) -> int:
+        print("self.compression_rate", self.compression_rate)
         integer = audio_frames // self.compression_rate
         remainder = audio_frames % self.compression_rate
         result = integer if remainder == 0 else integer + 1
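The integer/remainder arithmetic above is plain ceiling division. A minimal equivalent, assuming the default compression rate of 4 (only this first compression stage is visible in the hunk):

```python
import math

def embed_size(audio_frames: int, compression_rate: int = 4) -> int:
    # integer if remainder == 0 else integer + 1  ==  ceil(frames / rate)
    return math.ceil(audio_frames / compression_rate)

assert embed_size(100) == 25  # 100 % 4 == 0 -> 100 // 4
assert embed_size(101) == 26  # 101 % 4 != 0 -> 101 // 4 + 1
```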
@@ -293,14 +294,6 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor): # MODIFIED CLASS N
         return result
 
 
-# The rest of your script (Gemma3ImagesKwargs, Gemma3ProcessorKwargs, Gemma3OmniProcessor) follows...
-# Make sure this Gemma3AudioFeatureExtractor class replaces the old one or
-# is correctly registered/named if your AutoProcessor setup relies on a specific name.
-
-
-# --- End of Refactored Audio Feature Extractor ---
-
-
 class Gemma3ImagesKwargs(ImagesKwargs):
     do_pan_and_scan: Optional[bool]
     pan_and_scan_min_crop_size: Optional[int]
@@ -416,23 +409,7 @@ class Gemma3OmniProcessor(ProcessorMixin):
         return final_kwargs
 
     def _compute_audio_embed_size(self, audio_mel_frames: int) -> int:
-        # This method is part of Gemma3OmniProcessor.
-        # It calculates a number of soft tokens based on its own compression rates.
-        # Note: `audio_mel_frames` here is the number of raw Mel frames from the feature extractor's perspective
-        # if the attention mask sum is directly used before feat_stride scaling by the processor.
-        # However, if using the Refactored processor, audio_attention_mask.sum() will yield
-        # num_mel_frames * feat_stride. This method should then correctly compress that value.
-
-        # Using prompt_audio_compression_rate and prompt_audio_qformer_rate
-        # which are attributes of this Gemma3OmniProcessor class.
-
-        # First compression
-        # audio_mel_frames here should ideally be num_actual_mel_frames * feat_stride_of_the_audio_processor
-        # if trying to match the number of tokens from a Phi4M-style processor.
-        # The refactored audio processor does this scaling internally before its own _compute_audio_embed_size.
-        # If actual_mel_frames_per_sample (from sum of attention_mask) *is* already scaled by feat_stride
-        # (as it would be if using the refactored processor's attention_mask), then this calculation is correct.
-
+        print("prompt_audio_compression_rate", self.prompt_audio_compression_rate)
         integer = audio_mel_frames // self.prompt_audio_compression_rate
         remainder = audio_mel_frames % self.prompt_audio_compression_rate
         result = integer if remainder == 0 else integer + 1
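The deleted comments describe the intended input: mel frames already scaled by the feature stride. A hedged walk-through with the module defaults (feat_stride=4, compression=4, qformer=2); the second division by the qformer rate is an assumption inferred from the `prompt_audio_qformer_rate` attribute the comments mention, since that stage is not visible in this hunk:

```python
import math

def soft_tokens(num_mel_frames: int, feat_stride: int = 4,
                compression_rate: int = 4, qformer_rate: int = 2) -> int:
    stretched = num_mel_frames * feat_stride          # what attention_mask.sum() yields
    stage1 = math.ceil(stretched / compression_rate)  # the division shown above
    return math.ceil(stage1 / qformer_rate)           # assumed second stage

print(soft_tokens(100))  # 100 mel frames -> 400 -> 100 -> 50 soft tokens
```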
@@ -473,11 +450,11 @@ class Gemma3OmniProcessor(ProcessorMixin):
         num_samples = 0
         if images is not None:
             _images_list = images if isinstance(images, list) and (
-                not images or not isinstance(images[0], (int, float))) else [images]
+                not images or not isinstance(images[0], (int, float))) else [images]
             num_samples = len(_images_list)
         elif audios is not None:
             _audios_list = audios if isinstance(audios, list) and not (
-                isinstance(audios[0], tuple) and isinstance(audios[0][0], (int, float))) else [
+                isinstance(audios[0], tuple) and isinstance(audios[0][0], (int, float))) else [
                 audios]  # check if audios is list of items or list of (wave,sr)
             num_samples = len(_audios_list)
         text = [""] * num_samples if num_samples > 0 else [""]  # Default to one empty string if no inputs
@@ -571,15 +548,6 @@ class Gemma3OmniProcessor(ProcessorMixin):
             raise ValueError(
                 f"Inconsistent batch for audio/text: {num_audio_samples_processed} audio samples processed, {len(text)} text prompts."
             )
-
-        # If using Gemma3AudioFeatureExtractor,
-        # "audio_embed_sizes" is already computed correctly (num compressed tokens).
-        # The processor's own _compute_audio_embed_size is called to determine how many
-        # self.audio_token_str_from_user_code to insert. Ideally, this matches.
-
-        # Get the number of frames that the processor's _compute_audio_embed_size expects.
-        # If the audio_processor is RefactoredGemma3..., its attention_mask is over (num_mel_frames * feat_stride).
-        # So, sum of that mask gives the input for this processor's _compute_audio_embed_size.
         frames_for_embed_size_calc = to_py_obj(audio_features_dict[self.audio_processor.model_input_names[2]].sum(
             axis=-1))  # sum of audio_attention_mask
 
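The surviving line sums the audio attention mask per sample to recover valid frame counts from a padded batch. A minimal illustration (the mask layout is assumed):

```python
import torch

audio_attention_mask = torch.tensor([
    [1, 1, 1, 1, 0, 0],  # sample 0: 4 valid frames, 2 padding
    [1, 1, 1, 1, 1, 1],  # sample 1: 6 valid frames
])
frames_for_embed_size_calc = audio_attention_mask.sum(dim=-1).tolist()
print(frames_for_embed_size_calc)  # [4, 6] -> fed to _compute_audio_embed_size
```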
@@ -666,4 +634,4 @@ class Gemma3OmniProcessor(ProcessorMixin):
         else:
             input_names.add(str(audio_inputs))
 
-        return list(input_names)
+        return list(input_names)