import re
from typing import List, Optional, Union, Dict, Any, Tuple

import numpy as np
import scipy.signal
import torch
from torch.nn.utils.rnn import pad_sequence

from transformers.audio_utils import AudioInput
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import make_nested_list_of_images
from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs
from transformers.utils import TensorType, to_py_obj, logging


DEFAULT_SAMPLING_RATE = 16000
DEFAULT_N_FFT = 512
DEFAULT_WIN_LENGTH = 400
DEFAULT_HOP_LENGTH = 160
DEFAULT_N_MELS = 80
DEFAULT_COMPRESSION_RATE = 4
DEFAULT_QFORMER_RATE = 4
DEFAULT_FEAT_STRIDE = 4
IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
DEFAULT_MAX_LENGTH = 16384

logger = logging.get_logger(__name__)


def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
    """Create a Mel filter-bank matching SpeechLib FbankFC.

    Args:
        sample_rate (int): Sample rate in Hz. number > 0 [scalar]
        n_fft (int): FFT size. int > 0 [scalar]
        n_mels (int): Number of Mel filters. int > 0 [scalar]
        fmin (float): Lowest frequency (in Hz). If None, use 0.0.
            float >= 0 [scalar]
        fmax (float): Highest frequency (in Hz). If None, use sample_rate / 2.
            float >= 0 [scalar]

    Returns:
        out (numpy.ndarray): Mel transform matrix [shape=(n_mels, 1 + n_fft/2)]
    """

    bank_width = int(n_fft // 2 + 1)
    if fmax is None:
        fmax = sample_rate / 2
    if fmin is None:
        fmin = 0
    assert fmin >= 0, "fmin cannot be negative"
    assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, sample_rate / 2]"

    def mel(f):
        # Hz -> Mel (HTK-style formula with the SpeechLib 1127 constant).
        return 1127.0 * np.log(1.0 + f / 700.0)

    def bin2mel(fft_bin):
        # FFT bin index -> Mel.
        return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))

    def f2bin(f):
        # Hz -> nearest FFT bin index.
        return int((f * n_fft / sample_rate) + 0.5)

    # FFT bins that are allowed to contribute to the filters.
    klo = f2bin(fmin) + 1
    khi = f2bin(fmax)
    khi = max(khi, klo)

    # Triangular filters are spaced uniformly in Mel space.
    mlo = mel(fmin)
    mhi = mel(fmax)
    m_centers = np.linspace(mlo, mhi, n_mels + 2)
    ms = (mhi - mlo) / (n_mels + 1)

    matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
    for m in range(0, n_mels):
        left = m_centers[m]
        center = m_centers[m + 1]
        right = m_centers[m + 2]
        for fft_bin in range(klo, khi):
            mbin = bin2mel(fft_bin)
            if left < mbin < right:
                matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms

    return matrix
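
# Illustrative sketch (comments only, not executed): with the defaults used by
# Gemma3AudioFeatureExtractor.__init__ below, the filter-bank is an (80, 257) matrix that
# is transposed before use, so a (frames, 257) power spectrogram can be projected with a
# single dot product.
#
#   mel_fb = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690)   # shape (80, 257)
#   # log_fbank = np.log(np.clip(power_spec @ mel_fb.T, 1.0, None))  # shape (frames, 80)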


class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
    model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]

    def __init__(
        self,
        audio_compression_rate: int = DEFAULT_COMPRESSION_RATE,
        audio_downsample_rate: int = DEFAULT_QFORMER_RATE,
        audio_feat_stride: int = DEFAULT_FEAT_STRIDE,
        feature_size: int = DEFAULT_N_MELS,
        sampling_rate: int = DEFAULT_SAMPLING_RATE,
        padding_value: float = 0.0,
        eightk_method: str = "fillzero",
        **kwargs,
    ):
        # Prefer values passed through kwargs (e.g. when loading from a saved config)
        # over the explicit defaults.
        _feature_size = kwargs.pop("feature_size", feature_size)
        _sampling_rate = kwargs.pop("sampling_rate", sampling_rate)
        _padding_value = kwargs.pop("padding_value", padding_value)

        super().__init__(
            feature_size=_feature_size, sampling_rate=_sampling_rate, padding_value=_padding_value, **kwargs
        )

        self.compression_rate = audio_compression_rate
        self.qformer_compression_rate = audio_downsample_rate
        self.feat_stride = audio_feat_stride

        # How 8 kHz audio is handled: "fillzero" keeps 8 kHz framing and zero-pads the
        # spectrum up to the 16 kHz bin count, "resample" upsamples the waveform to 16 kHz.
        self._eightk_method = eightk_method

        if self.sampling_rate != 16000:
            logger.warning(
                f"The feature extractor's target sampling rate is {self.sampling_rate}, "
                "but the Phi4M-consistent Mel parameters are based on 16000 Hz. "
                "This might lead to inconsistencies if the input audio is not resampled to 16000 Hz by this extractor."
            )

        # Mel filter-bank transposed to (n_freq_bins, n_mels) so spectrogram frames can be
        # projected with a single matrix multiplication.
        self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T
        self._hamming400 = np.hamming(400)  # 16 kHz analysis window
        self._hamming200 = np.hamming(200)  # 8 kHz analysis window

    def __call__(
        self,
        audios: List[Union[AudioInput, Tuple[np.ndarray, int]]],
        return_tensors: Optional[Union[str, TensorType]] = None,
        sampling_rate: Optional[int] = None,  # accepted for compatibility with the processor; per-item rates take precedence
    ):
        returned_input_audio_embeds = []
        returned_audio_embed_sizes = []
        audio_frames_list = []

        for audio_input_item in audios:
            if not isinstance(audio_input_item, tuple) or len(audio_input_item) != 2:
                raise ValueError(
                    "Each item in 'audios' must be a tuple (waveform: np.ndarray, sample_rate: int)."
                )
            audio_data, sample_rate = audio_input_item

            if isinstance(audio_data, list):
                audio_data = np.array(audio_data, dtype=np.float32)
            if not isinstance(audio_data, np.ndarray):
                raise TypeError(f"Waveform data must be a numpy array, got {type(audio_data)}")

            # Log-Mel features: (num_mel_frames, n_mels).
            audio_embeds_np = self._extract_features(audio_data, sample_rate)

            num_mel_frames = audio_embeds_np.shape[0]
            current_audio_frames = num_mel_frames * self.feat_stride

            audio_embed_size = self._compute_audio_embed_size(current_audio_frames)

            returned_input_audio_embeds.append(torch.from_numpy(audio_embeds_np))
            returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long())
            audio_frames_list.append(current_audio_frames)

        # Pad variable-length feature sequences to the longest one in the batch.
        padded_input_audio_embeds = pad_sequence(
            returned_input_audio_embeds, batch_first=True, padding_value=self.padding_value
        )
        stacked_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)

        tensor_audio_frames_list = torch.tensor(audio_frames_list, dtype=torch.long)

        max_audio_frames = 0
        if len(audios) > 0 and tensor_audio_frames_list.numel() > 0:
            max_audio_frames = tensor_audio_frames_list.max().item()

        # Attention mask over stride-expanded frames (True = real frame, False = padding).
        returned_audio_attention_mask = None
        if max_audio_frames > 0:
            if len(audios) > 1:
                returned_audio_attention_mask = (
                    torch.arange(0, max_audio_frames, device=tensor_audio_frames_list.device).unsqueeze(0)
                    < tensor_audio_frames_list.unsqueeze(1)
                )
            elif len(audios) == 1:
                returned_audio_attention_mask = torch.ones(
                    1, max_audio_frames, dtype=torch.bool, device=tensor_audio_frames_list.device
                )

        data = {
            "input_audio_embeds": padded_input_audio_embeds,
            "audio_embed_sizes": stacked_audio_embed_sizes,
        }
        if returned_audio_attention_mask is not None:
            data["audio_attention_mask"] = returned_audio_attention_mask

        return BatchFeature(data=data, tensor_type=return_tensors)

    def _extract_spectrogram(self, wav: np.ndarray, fs: int) -> np.ndarray:
        # Collapse to mono float32.
        if wav.ndim > 1:
            wav = np.squeeze(wav)
        if len(wav.shape) == 2:
            wav = wav.mean(axis=1).astype(np.float32)

        wav = wav.astype(np.float32)

        # Resample to 16 kHz when possible, otherwise to 8 kHz.
        current_fs = fs
        if current_fs > self.sampling_rate:
            wav = scipy.signal.resample_poly(wav, self.sampling_rate, current_fs)
            current_fs = self.sampling_rate
        elif 8000 < current_fs < self.sampling_rate:
            wav = scipy.signal.resample_poly(wav, 8000, current_fs)
            current_fs = 8000
        elif current_fs < 8000 and current_fs > 0:
            logger.warning(f"Sample rate {current_fs} is less than 8000 Hz. Resampling to 8000 Hz.")
            wav = scipy.signal.resample_poly(wav, 8000, current_fs)
            current_fs = 8000
        elif current_fs <= 0:
            raise RuntimeError(f"Unsupported sample rate {current_fs}")

        if current_fs == 8000:
            if self._eightk_method == "resample":
                wav = scipy.signal.resample_poly(wav, self.sampling_rate, 8000)
                current_fs = self.sampling_rate
        elif current_fs != self.sampling_rate:
            raise RuntimeError(
                f"Audio sample rate {current_fs} not supported. Expected {self.sampling_rate} or 8000 for 8k methods."
            )

        preemphasis_coeff = 0.97

        # STFT parameters depend on the working sample rate.
        if current_fs == 8000:
            n_fft, win_length, hop_length, fft_window = 256, 200, 80, self._hamming200
        elif current_fs == 16000:
            n_fft, win_length, hop_length, fft_window = 512, 400, 160, self._hamming400
        else:
            raise RuntimeError(f"Inconsistent fs {current_fs} for parameter selection. Should be 16000 or 8000.")

        if len(wav) < win_length:
            wav = np.pad(wav, (0, win_length - len(wav)), "constant", constant_values=(0.0,))

        num_frames = (wav.shape[0] - win_length) // hop_length + 1
        if num_frames <= 0:
            return np.zeros((0, n_fft // 2 + 1), dtype=np.float32)

        # Frame the waveform: (num_frames, win_length).
        y_frames = np.array(
            [wav[i * hop_length: i * hop_length + win_length] for i in range(num_frames)],
            dtype=np.float32,
        )

        # Pre-emphasis within each frame, scaled to the 16-bit integer range as in SpeechLib.
        _y_frames_rolled = np.roll(y_frames, 1, axis=1)
        _y_frames_rolled[:, 0] = _y_frames_rolled[:, 1]
        y_frames_preemphasized = (y_frames - preemphasis_coeff * _y_frames_rolled) * 32768.0

        S = np.fft.rfft(fft_window * y_frames_preemphasized, n=n_fft, axis=1).astype(np.complex64)

        if current_fs == 8000 and self._eightk_method == "fillzero":
            # Zero-fill the upper half of the spectrum so the bin count matches 16 kHz
            # (257 bins), dropping the 8 kHz Nyquist bin first.
            target_bins = (512 // 2) + 1
            S_core = S[:, :-1]
            padarray = np.zeros((S_core.shape[0], target_bins - S_core.shape[1]), dtype=S.dtype)
            S = np.concatenate((S_core, padarray), axis=1)

        spec = np.abs(S).astype(np.float32)
        return spec

    def _extract_features(self, wav: np.ndarray, fs: int) -> np.ndarray:
        spec = self._extract_spectrogram(wav, fs)
        if spec.shape[0] == 0:
            return np.zeros((0, self.feature_size), dtype=np.float32)

        # Power spectrum -> Mel filter-bank energies -> clamped log.
        spec_power = spec ** 2
        fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
        log_fbank = np.log(fbank_power).astype(np.float32)
        return log_fbank

    def _compute_audio_embed_size(self, audio_frames: int) -> int:
        # Two successive ceiling divisions: first by the compression rate, then by the
        # Q-Former downsampling rate.
        integer = audio_frames // self.compression_rate
        remainder = audio_frames % self.compression_rate
        result = integer if remainder == 0 else integer + 1

        integer = result // self.qformer_compression_rate
        remainder = result % self.qformer_compression_rate
        result = integer if remainder == 0 else integer + 1
        return result
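
# Illustrative usage sketch (comments only, not executed). For a 1-second, 16 kHz mono
# waveform the extractor yields 98 Mel frames (win=400, hop=160), 98 * feat_stride = 392
# stride-expanded frames, and ceil(ceil(392 / 4) / 4) = 25 audio soft tokens.
#
#   import numpy as np
#   extractor = Gemma3AudioFeatureExtractor()
#   waveform = np.zeros(16000, dtype=np.float32)            # 1 s of silence
#   feats = extractor([(waveform, 16000)], return_tensors="pt")
#   feats["input_audio_embeds"].shape    # (1, 98, 80) log-Mel features
#   feats["audio_embed_sizes"]           # tensor([25])
#   feats["audio_attention_mask"].shape  # (1, 392)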


class Gemma3ImagesKwargs(ImagesKwargs):
    do_pan_and_scan: Optional[bool]
    pan_and_scan_min_crop_size: Optional[int]
    pan_and_scan_max_num_crops: Optional[int]
    pan_and_scan_min_ratio_to_activate: Optional[float]
    do_convert_rgb: Optional[bool]


class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
    # TypedDict fields do not take default values; the per-modality defaults live in `_defaults`.
    images_kwargs: Optional[Dict[str, Any]]
    audio_kwargs: Optional[Dict[str, Any]]
    text_kwargs: Optional[Dict[str, Any]]
    _defaults = {
        "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
        "images_kwargs": {},
        "audio_kwargs": {},
    }
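
# Illustrative sketch (comments only): per-modality kwargs can override these defaults at
# call time, e.g. (assuming an instantiated `processor` and a loaded `waveform`):
#
#   processor(
#       text="Describe <|audio_placeholder|>",
#       audios=[(waveform, 16000)],
#       text_kwargs={"padding": "max_length", "max_length": 512},
#   )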


class Gemma3OmniProcessor(ProcessorMixin):
    attributes = ["image_processor", "audio_processor", "tokenizer"]
    valid_kwargs = ["chat_template", "image_seq_length"]

    image_processor_class = "AutoImageProcessor"
    audio_processor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        audio_processor=None,
        tokenizer=None,
        chat_template=None,
        image_seq_length: int = 256,
        **kwargs,
    ):
        # Pop processor-specific options before forwarding the remaining kwargs to
        # ProcessorMixin, which rejects unknown keyword arguments.
        self.prompt_audio_compression_rate = kwargs.pop("prompt_audio_compression_rate", DEFAULT_COMPRESSION_RATE)
        self.prompt_audio_qformer_rate = kwargs.pop("prompt_audio_qformer_rate", DEFAULT_QFORMER_RATE)
        self.audio_placeholder_token = kwargs.pop("audio_placeholder_token", "<|audio_placeholder|>")

        super().__init__(
            image_processor=image_processor,
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            chat_template=chat_template,
            **kwargs,
        )

        self.image_seq_length = image_seq_length
        if self.tokenizer is not None:
            self.image_token_id = getattr(
                self.tokenizer,
                "image_token_id",
                self.tokenizer.unk_token_id if hasattr(self.tokenizer, "unk_token_id") else None,
            )
            self.boi_token = getattr(self.tokenizer, "boi_token", "<image>")
            self.image_token = getattr(self.tokenizer, "image_token", "<image>")
            self.eoi_token = getattr(self.tokenizer, "eoi_token", "")

            # Token string that marks one audio soft-token position in the prompt.
            self.audio_token_str_from_user_code = "<audio_soft_token>"

            self.audio_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_token_str_from_user_code)
            if hasattr(self.tokenizer, "unk_token_id") and self.audio_token_id == self.tokenizer.unk_token_id:
                logger.warning(
                    f"The audio token string '{self.audio_token_str_from_user_code}' maps to the UNK token. "
                    "Please ensure it is added to the tokenizer's vocabulary as a special token."
                )
            self.full_image_sequence = (
                f"\n\n{self.boi_token}{''.join([self.image_token] * image_seq_length)}{self.eoi_token}\n\n"
            )
        else:
            logger.error(
                "Gemma3OmniProcessor initialized, but self.tokenizer is None. "
                "Token-dependent attributes will use placeholders or defaults."
            )
            self.image_token_id = None
            self.boi_token = "<image>"
            self.image_token = "<image>"
            self.eoi_token = ""
            self.audio_token_str_from_user_code = "<audio_soft_token>"
            self.audio_token_id = -1
            self.full_image_sequence = ""

    def _merge_kwargs(self, KwargsClassWithDefaults, tokenizer_init_kwargs, **kwargs_from_call):
        # Start from the class-level `_defaults`, then layer call-time kwargs on top.
        final_kwargs = {}
        _defaults = getattr(KwargsClassWithDefaults, "_defaults", {})
        if not isinstance(_defaults, dict):
            _defaults = {}

        for modality_key, default_modality_kwargs in _defaults.items():
            final_kwargs[modality_key] = default_modality_kwargs.copy()

        for modality_key_in_call, modality_kwargs_in_call in kwargs_from_call.items():
            if modality_key_in_call in final_kwargs:
                if isinstance(modality_kwargs_in_call, dict):
                    final_kwargs[modality_key_in_call].update(modality_kwargs_in_call)
            elif isinstance(modality_kwargs_in_call, dict):
                final_kwargs[modality_key_in_call] = modality_kwargs_in_call.copy()

        # Values fixed at tokenizer init (e.g. model_max_length) override the merged defaults.
        if self.tokenizer:
            for modality_key in final_kwargs:
                modality_dict = final_kwargs[modality_key]
                if isinstance(modality_dict, dict):
                    for key_in_mod_dict in list(modality_dict.keys()):
                        if key_in_mod_dict in tokenizer_init_kwargs:
                            value = (
                                getattr(self.tokenizer, key_in_mod_dict)
                                if hasattr(self.tokenizer, key_in_mod_dict)
                                else tokenizer_init_kwargs[key_in_mod_dict]
                            )
                            modality_dict[key_in_mod_dict] = value

        if "text_kwargs" not in final_kwargs:
            final_kwargs["text_kwargs"] = {}
        final_kwargs["text_kwargs"]["truncation"] = final_kwargs["text_kwargs"].get("truncation", False)
        final_kwargs["text_kwargs"]["max_length"] = final_kwargs["text_kwargs"].get("max_length", DEFAULT_MAX_LENGTH)

        return final_kwargs

    def _compute_audio_embed_size(self, audio_mel_frames: int) -> int:
        # Mirrors Gemma3AudioFeatureExtractor._compute_audio_embed_size: two successive
        # ceiling divisions by the compression and Q-Former rates.
        integer = audio_mel_frames // self.prompt_audio_compression_rate
        remainder = audio_mel_frames % self.prompt_audio_compression_rate
        result = integer if remainder == 0 else integer + 1

        integer = result // self.prompt_audio_qformer_rate
        remainder = result % self.prompt_audio_qformer_rate
        result = integer if remainder == 0 else integer + 1
        return result
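
    # Worked example (comment only; illustrative). __call__ below sums the audio attention
    # mask to recover the per-sample count of stride-expanded frames. A 2-second clip at
    # 16 kHz gives 198 Mel frames, i.e. 198 * 4 = 792 frames, and with the default rates of
    # 4 and 4 this becomes ceil(792 / 4) = 198 and ceil(198 / 4) = 50 audio soft tokens.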

    def __call__(
        self,
        text: Union[str, List[str]] = None,
        images: Optional[Any] = None,
        audios: Optional[Union[AudioInput, List[AudioInput]]] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs: Any,
    ) -> BatchFeature:
        if text is None and images is None and audios is None:
            raise ValueError("Provide at least one of `text`, `images`, or `audios`.")

        final_rt = return_tensors

        merged_call_kwargs = self._merge_kwargs(
            Gemma3ProcessorKwargs,
            self.tokenizer.init_kwargs if hasattr(self.tokenizer, "init_kwargs") else {},
            **kwargs,
        )

        # An explicitly passed `return_tensors` wins over the value in text_kwargs.
        if final_rt is None:
            final_rt = merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", TensorType.PYTORCH)
        else:
            merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", None)

        # If no text was given, create one empty prompt per image group / audio clip.
        if text is None:
            num_samples = 0
            if images is not None:
                _images_list = (
                    images
                    if isinstance(images, list) and (not images or not isinstance(images[0], (int, float)))
                    else [images]
                )
                num_samples = len(_images_list)
            elif audios is not None:
                _audios_list = (
                    audios
                    if isinstance(audios, list)
                    and not (isinstance(audios[0], tuple) and isinstance(audios[0][0], (int, float)))
                    else [audios]
                )
                num_samples = len(_audios_list)
            text = [""] * num_samples if num_samples > 0 else [""]

        if isinstance(text, str):
            text = [text]
        if not (isinstance(text, list) and all(isinstance(t, str) for t in text)):
            raise ValueError("Input `text` must be a string or a list of strings.")

        image_features_dict = {}
        if images is not None:
            if self.image_processor is None:
                raise ValueError("Images provided but self.image_processor is None.")

            # Normalize the input into a list of image groups (one group per prompt).
            batched_images = make_nested_list_of_images(images)

            _img_kwargs = merged_call_kwargs.get("images_kwargs", {})
            _img_proc_output = self.image_processor(batched_images, return_tensors=None, **_img_kwargs)
            image_features_dict = (
                _img_proc_output.data if isinstance(_img_proc_output, BatchFeature) else _img_proc_output
            )

            if len(text) == 1 and text[0] == "" and len(batched_images) > 0:
                text = [" ".join([self.boi_token] * len(img_batch)) for img_batch in batched_images]
            elif len(batched_images) != len(text):
                raise ValueError(
                    f"Inconsistent batch: {len(batched_images)} image groups, {len(text)} texts. "
                    "Ensure one text prompt per image group."
                )

            # Pan-and-scan: each additional crop gets its own image-start token in the prompt.
            num_crops_popped = image_features_dict.pop("num_crops", None)
            if num_crops_popped is not None:
                num_crops_all = to_py_obj(num_crops_popped)
                temp_text_img, current_crop_idx_offset = [], 0
                for batch_idx, (prompt, current_imgs_in_batch) in enumerate(zip(text, batched_images)):
                    crops_for_this_batch_sample = []
                    if num_crops_all:
                        for _ in current_imgs_in_batch:
                            if current_crop_idx_offset < len(num_crops_all):
                                # num_crops counts the original image as well, so the number of
                                # additional crops is num_crops - 1 (never negative).
                                crops_for_this_batch_sample.append(
                                    max(0, num_crops_all[current_crop_idx_offset] - 1)
                                )
                                current_crop_idx_offset += 1
                            else:
                                crops_for_this_batch_sample.append(0)

                    image_placeholders_in_prompt = [
                        m.start() for m in re.finditer(re.escape(self.boi_token), prompt)
                    ]
                    processed_prompt = prompt

                    # Walk placeholders right-to-left so earlier offsets stay valid after replacement.
                    iter_count = min(len(crops_for_this_batch_sample), len(image_placeholders_in_prompt))
                    for i_placeholder_idx in range(iter_count - 1, -1, -1):
                        num_additional_crops_for_this_image = crops_for_this_batch_sample[i_placeholder_idx]
                        original_token_idx_in_prompt = image_placeholders_in_prompt[i_placeholder_idx]

                        if num_additional_crops_for_this_image > 0:
                            # One token for the full image plus one per additional crop.
                            replacement_text = self.boi_token + "".join(
                                [self.boi_token] * num_additional_crops_for_this_image
                            )
                            processed_prompt = (
                                processed_prompt[:original_token_idx_in_prompt]
                                + replacement_text
                                + processed_prompt[original_token_idx_in_prompt + len(self.boi_token):]
                            )
                    temp_text_img.append(processed_prompt)
                text = temp_text_img

            # Expand each image-start token into the full image token sequence.
            text = [p.replace(self.boi_token, self.full_image_sequence) for p in text]

        audio_features_dict = {}
        if audios is not None:
            if self.audio_processor is None:
                raise ValueError("Audios provided but self.audio_processor is None.")

            audio_call_kwargs = merged_call_kwargs.get("audio_kwargs", {})
            if sampling_rate is not None:
                audio_call_kwargs["sampling_rate"] = sampling_rate

            _audio_proc_output = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
            audio_features_dict = _audio_proc_output.data

            new_text_with_audio = []

            # For Gemma3AudioFeatureExtractor, model_input_names[0] is "input_audio_embeds"
            # and model_input_names[2] is "audio_attention_mask".
            num_audio_samples_processed = audio_features_dict[self.audio_processor.model_input_names[0]].shape[0]

            if num_audio_samples_processed != len(text):
                raise ValueError(
                    f"Inconsistent batch for audio/text: {num_audio_samples_processed} audio samples processed, "
                    f"{len(text)} text prompts."
                )
            # Per-sample count of valid (stride-expanded) frames, recovered from the attention mask.
            frames_for_embed_size_calc = to_py_obj(
                audio_features_dict[self.audio_processor.model_input_names[2]].sum(axis=-1)
            )

            for i, prompt in enumerate(text):
                # Number of audio soft tokens the model will see for this clip.
                num_soft_tokens = self._compute_audio_embed_size(frames_for_embed_size_calc[i])

                audio_token_sequence_str = self.audio_token_str_from_user_code * num_soft_tokens

                if self.audio_placeholder_token in prompt:
                    # Replace only the first placeholder occurrence.
                    prompt = prompt.replace(self.audio_placeholder_token, audio_token_sequence_str, 1)
                else:
                    prompt += audio_token_sequence_str
                new_text_with_audio.append(prompt)
            text = new_text_with_audio

        text_tokenizer_kwargs = merged_call_kwargs.get("text_kwargs", {})
        text_features_dict = self.tokenizer(text=text, return_tensors=None, **text_tokenizer_kwargs)

        # Normalize input_ids to a list of lists so token_type_ids can be built per sample.
        input_ids_list_of_lists = text_features_dict["input_ids"]

        if not isinstance(input_ids_list_of_lists, list) or not (
            input_ids_list_of_lists and isinstance(input_ids_list_of_lists[0], list)
        ):
            if isinstance(input_ids_list_of_lists, (torch.Tensor, np.ndarray)):
                input_ids_list_of_lists = to_py_obj(input_ids_list_of_lists)
            elif isinstance(input_ids_list_of_lists, list) and (
                not input_ids_list_of_lists or isinstance(input_ids_list_of_lists[0], int)
            ):
                input_ids_list_of_lists = [input_ids_list_of_lists]

        # token_type_ids: 0 = text, 1 = image token, 2 = audio soft token.
        token_type_ids_list = []
        for ids_sample in input_ids_list_of_lists:
            types = [0] * len(ids_sample)
            for j, token_id_val in enumerate(ids_sample):
                if self.image_token_id is not None and token_id_val == self.image_token_id:
                    types[j] = 1
                elif self.audio_token_id != -1 and token_id_val == self.audio_token_id:
                    types[j] = 2
            token_type_ids_list.append(types)
        text_features_dict["token_type_ids"] = token_type_ids_list

        final_batch_data = {**text_features_dict}
        if image_features_dict:
            final_batch_data.update(image_features_dict)
        if audio_features_dict:
            final_batch_data.update(audio_features_dict)

        return BatchFeature(data=final_batch_data, tensor_type=final_rt)

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self) -> List[str]:
        # Union of the input names of the tokenizer, image processor and audio processor.
        input_names = set()
        if hasattr(self, "tokenizer") and self.tokenizer is not None:
            tokenizer_inputs = self.tokenizer.model_input_names
            if isinstance(tokenizer_inputs, (list, set)):
                input_names.update(tokenizer_inputs)
            else:
                input_names.add(str(tokenizer_inputs))
            input_names.add("token_type_ids")

        if hasattr(self, "image_processor") and self.image_processor is not None:
            image_inputs = self.image_processor.model_input_names
            if isinstance(image_inputs, (list, set)):
                input_names.update(image_inputs)
            else:
                input_names.add(str(image_inputs))

        if hasattr(self, "audio_processor") and self.audio_processor is not None:
            audio_inputs = self.audio_processor.model_input_names
            if isinstance(audio_inputs, (list, set)):
                input_names.update(audio_inputs)
            else:
                input_names.add(str(audio_inputs))

        return list(input_names)
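

# Illustrative end-to-end sketch (comments only, not executed). The checkpoint name and the
# availability of compatible Auto* classes for this processor are assumptions.
#
#   from transformers import AutoImageProcessor, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")            # hypothetical
#   image_processor = AutoImageProcessor.from_pretrained("google/gemma-3-4b-it")  # hypothetical
#   audio_processor = Gemma3AudioFeatureExtractor()
#   processor = Gemma3OmniProcessor(
#       image_processor=image_processor,
#       audio_processor=audio_processor,
#       tokenizer=tokenizer,
#   )
#   batch = processor(
#       text="Transcribe: <|audio_placeholder|>",
#       audios=[(waveform, 16000)],   # `waveform` is a 1-D float32 numpy array
#       return_tensors="pt",
#   )
#   # `batch` contains input_ids, attention_mask, token_type_ids, input_audio_embeds,
#   # audio_embed_sizes and audio_attention_mask.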