# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

from typing import Any, Dict, Sequence

import torch
from transformers import PreTrainedTokenizer

from llava.constants import IGNORE_INDEX
from llava.utils.logging import logger

__all__ = ["DataCollator"]


class DataCollator:
    tokenizer: PreTrainedTokenizer

    def __init__(self, tokenizer: PreTrainedTokenizer):
        super().__init__()
        self.tokenizer = tokenizer

    def __call__(self, instances: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
        # Gather everything from the batch
        input_ids, labels, block_sizes = [], [], []
        media = {name: [] for name in self.tokenizer.media_tokens}
        media_meta = {
            "sound_feature_masks": [],
            "sound_embed_masks": [],
            "frame_times": [],
        }

        for instance in instances:
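            # Two layouts are supported: a tensor `input_ids` holds a single
            # sample, while a list holds several pre-packed samples that are
            # flattened into the batch one by one.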
            if isinstance(instance["input_ids"], torch.Tensor):
                input_ids.append(instance["input_ids"])
                labels.append(instance["labels"])
                for name in media:
                    objs = instance.get(name)
                    objs = objs if objs is not None else []
                    media[name].append(list(objs))
                if instance.get("sound") is not None:
                    for name_k in media_meta:
                        if "sound" in name_k:
                            media_meta[name_k].append(list(instance.get(name_k)))
                if instance.get("video") is not None or instance.get("image") is not None:
                    for name_k in media_meta:
                        if "frame" in name_k:
                            media_meta[name_k].append(list(instance.get(name_k)))
                if "block_sizes" in instance:
                    block_sizes.append(instance["block_sizes"])
                else:
                    block_sizes.append(
                        [None] * len(instance["image"]) if instance.get("image") is not None else []
                    )
            else:
                input_ids.extend(instance["input_ids"])
                labels.extend(instance["labels"])
                for name in media:
                    objs = instance.get(name)
                    objs = objs if objs is not None else [[] for _ in range(len(instance["input_ids"]))]
                    media[name].extend(objs)
                if instance.get("sound") is not None:
                    for name_k in media_meta:
                        if "sound" in name_k:
                            media_meta[name_k].extend(instance.get(name_k))
                if instance.get("video") is not None or instance.get("image") is not None:
                    for name_k in media_meta:
                        if "frame" in name_k:
                            # Extend (not append) so frame metadata stays
                            # aligned one-to-one with the packed samples, as in
                            # the sound branch above.
                            media_meta[name_k].extend(instance.get(name_k))
                if "block_sizes" in instance:
                    block_sizes.extend(instance["block_sizes"])
                else:
                    block_sizes.extend(
                        [[None] * len(objs) for objs in instance["image"]]
                        if instance.get("image") is not None
                        else [[] for _ in range(len(instance["input_ids"]))]
                    )
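
        # At this point input_ids, labels, and block_sizes each hold one entry
        # per sample, and media[name][k] lists the media objects of sample k.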
        batch_size = len(input_ids)

        # Check if the number of media objects (or the number of block sizes)
        # matches the number of media tokens
        for name in media:
            for k in range(batch_size):
                if name == "image" and any(bs is not None for bs in block_sizes[k]):
                    actual = len(block_sizes[k])
                else:
                    actual = len(media[name][k])
                expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
                if actual != expected:
                    raise ValueError(
                        f"Number mismatch between {name} objects and {name} tokens. "
                        f"There are {expected} {name} tokens but {actual} {name} objects."
                    )

        # Batchify the inputs
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=IGNORE_INDEX,
        )
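        # Truncate to the model's context window; media objects whose tokens
        # fall past the cut are dropped in the truncation pass below.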
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        labels = labels[:, : self.tokenizer.model_max_length]
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)

        # Truncate media objects if necessary
        for name in media:
            objects = []
            for k in range(batch_size):
                if name == "image" and any(bs is not None for bs in block_sizes[k]):
                    actual = len(media[name][k])
                    num_large_scale_blocks = sum(x * y for x, y in block_sizes[k])
                    num_small_scale_blocks = actual - num_large_scale_blocks
                    num_small_scale_blocks_each_img = num_small_scale_blocks // len(block_sizes[k])
                    expected_full_image = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
                    expected = (
                        sum(x * y for x, y in block_sizes[k][:expected_full_image])
                        + num_small_scale_blocks_each_img * expected_full_image
                    )
                    if actual > expected:
                        logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}")
                        media[name][k] = media[name][k][:expected]
                    objects.extend(media[name][k])
                    block_sizes[k] = block_sizes[k][:expected_full_image]
                else:
                    actual = len(media[name][k])
                    expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
                    if actual > expected:
                        logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}")
                        media[name][k] = media[name][k][:expected]
                    objects.extend(media[name][k])
                    if name == "image":
                        block_sizes[k] = block_sizes[k][:expected]
            media[name] = objects
        for name in media_meta:
            # Flatten the per-sample metadata lists; samples without this kind
            # of metadata simply contributed no entry.
            media_meta[name] = [obj for objs in media_meta[name] for obj in objs]

        # Flatten block sizes from [[bls_im1_instance1, bls_im2_instance1], [bls_im1_instance2, bls_im2_instance2], ...]
        # to [bls_im1_instance1, bls_im2_instance1, bls_im1_instance2, bls_im2_instance2, ...]
        block_sizes = sum(block_sizes, [])
        return {
            "input_ids": input_ids,
            "media": media,
            "media_config": {"image": {"block_sizes": block_sizes}, "video": {}, "speech": {}, "sound": {}},
            "labels": labels,
            "attention_mask": attention_mask,
            "media_meta": media_meta,
        }
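

if __name__ == "__main__":
    # Minimal smoke test. This is an illustrative sketch only: the stub below
    # imitates the attributes the collator reads from the real llava tokenizer
    # (`media_tokens`, `media_token_ids`, `pad_token_id`, `model_max_length`);
    # the token ids and tensor shapes are made up for the example.
    class _StubTokenizer:
        pad_token_id = 0
        model_max_length = 16
        media_tokens = {"image": "<image>", "video": "<video>", "sound": "<sound>"}
        media_token_ids = {"image": 32000, "video": 32001, "sound": 32002}

    collator = DataCollator(_StubTokenizer())
    batch = collator(
        [
            {
                # One <image> token (id 32000) paired with one image object.
                "input_ids": torch.tensor([1, 32000, 5, 6]),
                "labels": torch.tensor([IGNORE_INDEX, IGNORE_INDEX, 5, 6]),
                "image": [torch.zeros(3, 4, 4)],
                "frame_times": [0.0],
            },
            {
                # Text-only sample: no media objects, no media tokens.
                "input_ids": torch.tensor([1, 7]),
                "labels": torch.tensor([IGNORE_INDEX, 7]),
            },
        ]
    )
    print(batch["input_ids"].shape, batch["attention_mask"].shape)  # (2, 4) twice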