# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

import os
import os.path as osp
import time
from itertools import chain
from typing import Any, List, Optional

import numpy as np
import torch
import torch.distributed as dist
from hydra.utils import instantiate
from torch.utils.data import ConcatDataset, Dataset
from transformers import PreTrainedTokenizer

from llava.data.datasets_mixture import DATASETS_LEGACY
from llava.train.args import DataArguments, TrainingArguments
from llava.utils import io
from llava.utils.logging import logger

__all__ = ["DATASETS", "MIXTURES", "register_datasets", "register_mixtures", "parse_mixture", "build_dataset"]


def load_dataset_yaml(name):
    fname = f"{name}.yaml" if not name.endswith(".yaml") else name

    # YAML shipped with the repo under llava/data/registry/datasets
    repo_path = osp.join(osp.dirname(__file__), "registry", "datasets", fname)
    if osp.exists(repo_path):
        return repo_path

    # YAML given as a filesystem path
    abs_path = osp.expanduser(fname)
    if osp.exists(abs_path):
        return abs_path

    raise FileNotFoundError(f"Dataset '{name}' is not found in {repo_path} or {abs_path}.")
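
# Resolution sketch (hypothetical name): load_dataset_yaml("coyo") first looks for
# llava/data/registry/datasets/coyo.yaml inside the repo; if that file is absent, the
# name is treated as a filesystem path with "~" expanded, e.g. "~/datasets/coyo.yaml".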


def register_datasets(name: Optional[str] = None):
    if name is None:
        name = os.environ.get("VILA_DATASETS", "default")
        logger.info(f"Registering datasets from environment: '{name}'.")
    # return io.load(osp.join(osp.dirname(__file__), "registry", "datasets", f"{name}.yaml"))

    dataset_meta = {}
    for _name in name.split(","):
        yamlpath = load_dataset_yaml(_name)
        logger.info(f"Registering datasets from: '{yamlpath}'.")
        meta = io.load(yamlpath)
        dataset_meta.update(meta)
    return dataset_meta
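
# Usage sketch (hypothetical value): with VILA_DATASETS unset, the registry is read from
# llava/data/registry/datasets/default.yaml; setting e.g. VILA_DATASETS="default,~/my_datasets.yaml"
# merges the repo registry with a user-provided YAML, later files overriding earlier keys
# via dict.update().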


def register_mixtures():
    return io.load(os.path.join(os.path.dirname(__file__), "registry", "mixtures.yaml"))


DATASETS = register_datasets()
MIXTURES = register_mixtures()


def parse_mixture(mixture: str) -> List[str]:
    names = mixture.split("+") if "+" in mixture else [mixture]
    while any(name in MIXTURES for name in names):
        names = list(chain(*[MIXTURES.get(name, [name]) for name in names]))
    return sorted(names)
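
# Expansion sketch (hypothetical MIXTURES entry): with MIXTURES = {"pretrain": ["coyo", "mmc4"]},
# parse_mixture("pretrain+sharegpt") splits on "+", recursively replaces mixture names with
# their members, and returns the sorted leaf list ["coyo", "mmc4", "sharegpt"].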


class SubsetDataset(Dataset):
    def __init__(self, dataset: Dataset, limit: float) -> None:
        super().__init__()
        self.dataset = dataset
        # `limit` is a fraction in (0, 1]: only this share of the dataset is exposed.
        self.limit = limit

    def __len__(self) -> int:
        return int(len(self.dataset) * self.limit)

    def __getitem__(self, index: int) -> Any:
        return self.dataset[index % len(self.dataset)]
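
# Example (assumed usage): wrapping a 1000-sample dataset with limit=0.25 reports a length
# of 250 and serves only indices 0..249 of the underlying dataset.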


class RepeatedDataset(Dataset):
    def __init__(self, dataset: Dataset, times: int) -> None:
        super().__init__()
        self.dataset = dataset
        self.times = times

    def __len__(self) -> int:
        return len(self.dataset) * self.times

    def __getitem__(self, index: int) -> Any:
        return self.dataset[index % len(self.dataset)]
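
# Example (assumed usage): wrapping a 100-sample dataset with times=3 reports a length of
# 300 and cycles through the underlying samples via index % len(dataset).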


def get_world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def build_dataset(
    mixture: str,
    data_args: DataArguments,
    training_args: TrainingArguments,
    tokenizer: PreTrainedTokenizer,
) -> Dataset:
    logger.warning(f"Training VILA with mixture '{mixture}'.")
    datasets = []
    dataset_rng = np.random.default_rng(1234)
    for name in parse_mixture(mixture):
        # "name*N" repeats the dataset N times.
        if "*" in name:
            name, times = name.split("*")
            times = int(times)
        else:
            times = 1

        # "name#P" keeps only the leading P percent of the dataset.
        limit_dataset = False
        if "#" in name:
            name, max_length_percent = name.split("#")
            limit_dataset = True

        if DATASETS is not None and name in DATASETS:
            if name in DATASETS_LEGACY:
                logger.warning(f"Dataset '{name}' exists in both new and legacy registries. Using the new one.")
            dataset = instantiate(DATASETS[name], _partial_=True)(
                tokenizer=tokenizer,
                data_args=data_args,
                global_batch_size=(
                    training_args.per_device_train_batch_size
                    * get_world_size()
                    * training_args.gradient_accumulation_steps
                ),
            )
        elif name in DATASETS_LEGACY:
            logger.warning(f"Dataset '{name}' is from the legacy registry. Please consider migrating it.")
            dataset = build_dataset_legacy(
                name,
                data_args=data_args,
                training_args=training_args,
                tokenizer=tokenizer,
            )
        else:
            raise ValueError(f"Dataset '{name}' is not found in the registries.")

        if limit_dataset:
            dataset = SubsetDataset(dataset, int(max_length_percent) / 100.0)

        if times > 1:
            dataset = RepeatedDataset(dataset, times)
        datasets.append(dataset)
    return ConcatDataset(datasets)
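
# Mixture-string grammar handled by build_dataset (sketch, hypothetical dataset names):
#   "coyo+sharegpt"  -> concatenation of the two registered datasets
#   "sharegpt*3"     -> the dataset repeated 3 times via RepeatedDataset
#   "coyo#25"        -> the leading 25% of the dataset via SubsetDataset
# For example, build_dataset("coyo#25+sharegpt*3", data_args, training_args, tokenizer)
# would return a ConcatDataset over a quarter of "coyo" and three passes over "sharegpt".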


def build_dataset_legacy(
    name: str,
    data_args: DataArguments,
    training_args: TrainingArguments,
    tokenizer: PreTrainedTokenizer,
) -> Dataset:
    from llava.data.dataset import (
        LazySupervisedDataset,
        LazyWDSDataset,
    )

    dataset = DATASETS_LEGACY[name]
    dataset_type = dataset.dataset_type
    if dataset_type == "torch":
        dataset_cls = LazySupervisedDataset
    elif dataset_type == "wds":
        dataset_cls = LazyWDSDataset
    else:
        raise NotImplementedError(f"{dataset_type} is not supported.")

    data_args.meta_path = getattr(dataset, "meta_path", None)
    data_args.caption_choice = getattr(dataset, "caption_choice", None)
    data_args.caption_choice_2 = getattr(dataset, "caption_choice_2", None)
    data_args.start_idx = getattr(dataset, "start_idx", None)
    data_args.end_idx = getattr(dataset, "end_idx", None)

    return dataset_cls(
        tokenizer=tokenizer,
        data_path=dataset.data_path,
        image_folder=getattr(dataset, "image_path"),
        data_args=data_args,
        training_args=training_args,
    )