# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
import os
import os.path as osp
from itertools import chain
from typing import Any, List, Optional

import numpy as np
import torch
import torch.distributed as dist
from hydra.utils import instantiate
from torch.utils.data import ConcatDataset, Dataset
from transformers import PreTrainedTokenizer

from llava.data.datasets_mixture import DATASETS_LEGACY
from llava.train.args import DataArguments, TrainingArguments
from llava.utils import io
from llava.utils.logging import logger

__all__ = ["DATASETS", "MIXTURES", "register_datasets", "register_mixtures", "parse_mixture", "build_dataset"]


def load_dataset_yaml(name):
    fname = f"{name}.yaml" if not name.endswith(".yaml") else name

    # YAML bundled under llava/data/registry/datasets
    repo_path = osp.join(osp.dirname(__file__), "registry", "datasets", fname)
    if osp.exists(repo_path):
        return repo_path

    # YAML at an absolute or user-relative filesystem path
    abs_path = osp.expanduser(fname)
    if osp.exists(abs_path):
        return abs_path

    raise FileNotFoundError(f"Dataset '{name}' was not found at {repo_path} or {abs_path}.")
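

# Example (hypothetical inputs): load_dataset_yaml("default") resolves to the
# bundled llava/data/registry/datasets/default.yaml, while
# load_dataset_yaml("~/extra.yaml") resolves to the expanded user path, if it exists.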


def register_datasets(name: Optional[str] = None):
    if name is None:
        name = os.environ.get("VILA_DATASETS", "default")
        logger.info(f"Registering datasets from environment: '{name}'.")

    # Merge the metadata from every comma-separated registry file, in order.
    dataset_meta = {}
    for _name in name.split(","):
        yamlpath = load_dataset_yaml(_name)
        logger.info(f"Registering datasets from: '{yamlpath}'.")
        meta = io.load(yamlpath)
        dataset_meta.update(meta)
    return dataset_meta
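

# Example (hypothetical file name): registries can be combined via the
# environment; each name is resolved by load_dataset_yaml() and the resulting
# dicts are merged in order, so
#
#   export VILA_DATASETS="default,~/my_datasets.yaml"
#
# is equivalent to calling register_datasets("default,~/my_datasets.yaml").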


def register_mixtures():
    return io.load(osp.join(osp.dirname(__file__), "registry", "mixtures.yaml"))

DATASETS = register_datasets()
MIXTURES = register_mixtures()


def parse_mixture(mixture: str) -> List[str]:
    # A mixture is a "+"-separated list of dataset or mixture names; expand
    # nested mixture names recursively until only dataset names remain.
    names = mixture.split("+") if "+" in mixture else [mixture]
    while any(name in MIXTURES for name in names):
        names = list(chain(*[MIXTURES.get(name, [name]) for name in names]))
    return sorted(names)
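

# Example (hypothetical names): if mixtures.yaml defines
#   pretrain_mix: [coyo, mmc4]
# then parse_mixture("pretrain_mix+sharegpt") expands the mixture entry and
# returns the sorted dataset names ["coyo", "mmc4", "sharegpt"].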


class SubsetDataset(Dataset):
    """Wraps a dataset and exposes only the leading `limit` fraction of it."""

    def __init__(self, dataset: Dataset, limit: float) -> None:
        super().__init__()
        self.dataset = dataset
        self.limit = limit

    def __len__(self) -> int:
        return int(len(self.dataset) * self.limit)

    def __getitem__(self, index: int) -> Any:
        return self.dataset[index % len(self.dataset)]


class RepeatedDataset(Dataset):
    """Wraps a dataset and cycles through it `times` times."""

    def __init__(self, dataset: Dataset, times: int) -> None:
        super().__init__()
        self.dataset = dataset
        self.times = times

    def __len__(self) -> int:
        return len(self.dataset) * self.times

    def __getitem__(self, index: int) -> Any:
        return self.dataset[index % len(self.dataset)]
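

# Example (a minimal sketch with a hypothetical toy dataset): SubsetDataset
# keeps a fraction of the samples, RepeatedDataset cycles through all of them
# a fixed number of times.
#
#   class ToyDataset(Dataset):  # hypothetical stand-in
#       def __len__(self):
#           return 10
#       def __getitem__(self, index):
#           return index
#
#   assert len(SubsetDataset(ToyDataset(), 0.5)) == 5
#   assert len(RepeatedDataset(ToyDataset(), 3)) == 30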


def get_world_size():
    # Fall back to a single process when torch.distributed is not initialized.
    if dist.is_initialized():
        return dist.get_world_size()
    return 1


def build_dataset(
    mixture: str,
    data_args: DataArguments,
    training_args: TrainingArguments,
    tokenizer: PreTrainedTokenizer,
) -> Dataset:
    logger.warning(f"Training VILA with mixture '{mixture}'.")
    datasets = []
    dataset_rng = np.random.default_rng(1234)
    for name in parse_mixture(mixture):
        # "name*3" repeats the dataset three times.
        if "*" in name:
            name, times = name.split("*")
            times = int(times)
        else:
            times = 1

        # "name#50" limits the dataset to the first 50% of its samples.
        limit_dataset = False
        if "#" in name:
            name, max_length_percent = name.split("#")
            limit_dataset = True

        if DATASETS is not None and name in DATASETS:
            if name in DATASETS_LEGACY:
                logger.warning(f"Dataset '{name}' exists in both new and legacy registries. Using the new one.")
            dataset = instantiate(DATASETS[name], _partial_=True)(
                tokenizer=tokenizer,
                data_args=data_args,
                global_batch_size=(
                    training_args.per_device_train_batch_size
                    * get_world_size()
                    * training_args.gradient_accumulation_steps
                ),
            )
        elif name in DATASETS_LEGACY:
            logger.warning(f"Dataset '{name}' is from the legacy registry. Please consider migrating it.")
            dataset = build_dataset_legacy(
                name,
                data_args=data_args,
                training_args=training_args,
                tokenizer=tokenizer,
            )
        else:
            raise ValueError(f"Dataset '{name}' is not found in the registries.")

        if limit_dataset:
            dataset = SubsetDataset(dataset, int(max_length_percent) / 100.0)
        if times > 1:
            dataset = RepeatedDataset(dataset, times)
        datasets.append(dataset)
    return ConcatDataset(datasets)
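

# Example (hypothetical dataset names): a mixture string can combine the
# "+", "*", and "#" operators, e.g.
#
#   build_dataset("coyo#25+sharegpt*2", data_args, training_args, tokenizer)
#
# concatenates the first 25% of "coyo" with "sharegpt" repeated twice.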


def build_dataset_legacy(
    name: str,
    data_args: DataArguments,
    training_args: TrainingArguments,
    tokenizer: PreTrainedTokenizer,
) -> Dataset:
    from llava.data.dataset import (
        LazySupervisedDataset,
        LazyWDSDataset,
    )

    dataset = DATASETS_LEGACY[name]
    dataset_type = dataset.dataset_type
    if dataset_type == "torch":
        dataset_cls = LazySupervisedDataset
    elif dataset_type == "wds":
        dataset_cls = LazyWDSDataset
    else:
        raise NotImplementedError(f"{dataset_type} is not supported.")

    data_args.meta_path = getattr(dataset, "meta_path", None)
    data_args.caption_choice = getattr(dataset, "caption_choice", None)
    data_args.caption_choice_2 = getattr(dataset, "caption_choice_2", None)
    data_args.start_idx = getattr(dataset, "start_idx", None)
    data_args.end_idx = getattr(dataset, "end_idx", None)

    return dataset_cls(
        tokenizer=tokenizer,
        data_path=dataset.data_path,
        image_folder=getattr(dataset, "image_path"),
        data_args=data_args,
        training_args=training_args,
    )