# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
from dataclasses import dataclass
from typing import Any, Dict, Sequence
import torch
from transformers import PreTrainedTokenizer
from llava.constants import IGNORE_INDEX
from llava.utils.logging import logger
__all__ = ["DataCollator"]
@dataclass
class DataCollator:
    """Collate llava dataset instances into a padded batch with aligned media objects."""

    tokenizer: PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
        # Gather everything from the batch
        input_ids, labels, media, block_sizes = [], [], {name: [] for name in self.tokenizer.media_tokens}, []
        media_meta = {
            "sound_feature_masks": [],
            "sound_embed_masks": [],
            "frame_times": [],
        }
        for instance in instances:
            if isinstance(instance["input_ids"], torch.Tensor):
                input_ids.append(instance["input_ids"])
                labels.append(instance["labels"])
                for name in media:
                    objs = instance.get(name)
                    objs = objs if objs is not None else []
                    media[name].append(list(objs))
                if instance.get("sound") is not None:
                    for name_k in media_meta:
                        if "sound" in name_k:
                            objs = instance.get(name_k)
                            media_meta[name_k].append(list(objs))
                if instance.get("video") is not None or instance.get("image") is not None:
                    for name_k in media_meta:
                        if "frame" in name_k:
                            objs = instance.get(name_k)
                            media_meta[name_k].append(list(objs))
                if "block_sizes" in instance:
                    block_sizes.append(instance["block_sizes"])
                else:
                    block_sizes.append(
                        [None for _ in range(len(instance["image"]))] if instance.get("image") is not None else []
                    )
            else:
                input_ids.extend(instance["input_ids"])
                labels.extend(instance["labels"])
                for name in media:
                    objs = instance.get(name)
                    objs = objs if objs is not None else [[] for _ in range(len(instance["input_ids"]))]
                    media[name].extend(objs)
                if instance.get("sound") is not None:
                    for name_k in media_meta:
                        if "sound" in name_k:
                            objs = instance.get(name_k)
                            media_meta[name_k].extend(objs)
                if instance.get("video") is not None or instance.get("image") is not None:
                    for name_k in media_meta:
                        if "frame" in name_k:
                            objs = instance.get(name_k)
                            # Extend, not append: keeps pre-batched frame metadata flat,
                            # consistent with the sound branch above.
                            media_meta[name_k].extend(objs)
                if "block_sizes" in instance:
                    block_sizes.extend(instance["block_sizes"])
                else:
                    block_sizes.extend(
                        [[None for _ in range(len(imgs))] for imgs in instance["image"]]
                        if instance.get("image") is not None
                        else [[] for _ in range(len(instance["input_ids"]))]
                    )
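        # block_sizes holds one entry per image: an (x, y) pair whose product is the
        # number of large-scale blocks for that image, or None when the image is not
        # split into blocks.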
        batch_size = len(input_ids)

        # Check if the number of media objects (or the number of block sizes) matches the number of media tokens
        for name in media:
            for k in range(batch_size):
                if name == "image" and not all(_ is None for _ in block_sizes[k]):
                    actual = len(block_sizes[k])
                else:
                    actual = len(media[name][k])
                expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
                if actual != expected:
                    raise ValueError(
                        f"Number mismatch between {name} objects and {name} tokens. "
                        f"There are {expected} {name} tokens but {actual} {name} objects."
                    )
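        # Example: three <image> tokens in input_ids[k] require exactly three image
        # entries (or three block-size entries) for instance k; anything else fails fast.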
        # Batchify the inputs
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=IGNORE_INDEX,
        )
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        labels = labels[:, : self.tokenizer.model_max_length]
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
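        # pad_sequence right-pads to the longest sequence in the batch, e.g. lengths
        # [5, 3] become a (2, 5) tensor; attention_mask is False wherever input_ids
        # holds the pad token.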
        # Truncate media objects if necessary
        for name in media:
            objects = []
            for k in range(batch_size):
                if name == "image" and not all(_ is None for _ in block_sizes[k]):
                    actual = len(media[name][k])
                    num_large_scale_blocks = sum(x * y for x, y in block_sizes[k])
                    num_small_scale_blocks = actual - num_large_scale_blocks
                    num_small_scale_blocks_each_img = num_small_scale_blocks // len(block_sizes[k])
                    expected_full_image = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
                    expected = (
                        sum(x * y for x, y in block_sizes[k][:expected_full_image])
                        + num_small_scale_blocks_each_img * expected_full_image
                    )
                    if actual > expected:
                        logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}")
                        media[name][k] = media[name][k][:expected]
                    objects.extend(media[name][k])
                    block_sizes[k] = block_sizes[k][:expected_full_image]
                else:
                    actual = len(media[name][k])
                    expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
                    if actual > expected:
                        logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}")
                        media[name][k] = media[name][k][:expected]
                    objects.extend(media[name][k])
                    if name == "image":
                        block_sizes[k] = block_sizes[k][:expected]
            media[name] = objects
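        # media[name] is now a flat, batch-level list whose order matches the media
        # tokens in input_ids row by row.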
        for name in media_meta:
            objects = []
            for k in range(batch_size):
                try:
                    objects.extend(media_meta[name][k])
                except (IndexError, TypeError):
                    # Instances without this modality contribute no entry (or a
                    # non-iterable one), so skip them instead of failing.
                    continue
            media_meta[name] = objects
        # Flatten block sizes from [[bls_im1_instance1, bls_im2_instance1], [bls_im1_instance2, bls_im2_instance2], ...]
        # to [bls_im1_instance1, bls_im2_instance1, bls_im1_instance2, bls_im2_instance2, ...]
        block_sizes = sum(block_sizes, [])
        return {
            "input_ids": input_ids,
            "media": media,
            "media_config": {"image": {"block_sizes": block_sizes}, "video": {}, "speech": {}, "sound": {}},
            "labels": labels,
            "attention_mask": attention_mask,
            "media_meta": media_meta,
        }
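

# Minimal usage sketch, not part of the collator itself. It assumes only what the
# code above reads from the tokenizer (media_tokens, media_token_ids,
# pad_token_id, model_max_length); the stub, token ids, and tensor shapes below
# are illustrative stand-ins, not real llava values.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical stub exposing just the attributes DataCollator actually uses.
    tokenizer = SimpleNamespace(
        media_tokens={"image": "<image>"},
        media_token_ids={"image": 32000},
        pad_token_id=0,
        model_max_length=4096,
    )
    collator = DataCollator(tokenizer=tokenizer)

    # One single-sample instance: two text tokens followed by one <image> token.
    # frame_times accompanies image/video instances, as the gather loop expects.
    instance = {
        "input_ids": torch.tensor([1, 2, 32000]),
        "labels": torch.tensor([IGNORE_INDEX, 2, IGNORE_INDEX]),
        "image": [torch.zeros(3, 336, 336)],
        "frame_times": [0.0],
    }
    batch = collator([instance])
    print(batch["input_ids"].shape)       # torch.Size([1, 3])
    print(batch["attention_mask"].shape)  # torch.Size([1, 3])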