# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
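"""Dataset registry and builder for VILA training.

Dataset definitions live in YAML files under llava/data/registry/datasets and
are grouped into mixtures in llava/data/registry/mixtures.yaml. build_dataset()
expands a mixture string such as "a+b*2+c#50" into a single ConcatDataset.
"""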

import os
import os.path as osp
from itertools import chain
from typing import Any, List, Optional

import torch
from hydra.utils import instantiate
from torch.utils.data import ConcatDataset, Dataset
from transformers import PreTrainedTokenizer

from llava.data.datasets_mixture import DATASETS_LEGACY
from llava.train.args import DataArguments, TrainingArguments
from llava.utils import io
from llava.utils.logging import logger

__all__ = ["DATASETS", "MIXTURES", "register_datasets", "register_mixtures", "parse_mixture", "build_dataset"]


def load_dataset_yaml(name):
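    """Resolve a dataset name to a YAML file path.

    The name is looked up under llava/data/registry/datasets first and is then
    treated as a filesystem path (with "~" expansion). Raises FileNotFoundError
    if neither location exists.
    """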
    fname = f"{name}.yaml" if not name.endswith(".yaml") else name

    # yaml under llava/data/registry/datasets
    repo_path = osp.join(osp.dirname(__file__), "registry", "datasets", fname)
    if osp.exists(repo_path):
        return repo_path

    # YAML given directly as a filesystem path (absolute or "~"-relative)
    abs_path = osp.expanduser(fname)
    if osp.exists(abs_path):
        return abs_path

    raise FileNotFoundError(f"Dataset '{name}' was not found at {repo_path} or {abs_path}.")


def register_datasets(name: Optional[str] = None):
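    """Load dataset definitions from one or more YAML registries into a dict.

    `name` may be a comma-separated list of registry names or YAML paths; when
    omitted, it is read from the VILA_DATASETS environment variable
    (defaulting to "default").
    """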
    if name is None:
        name = os.environ.get("VILA_DATASETS", "default")
        logger.info(f"Registering datasets from environment: '{name}'.")
    dataset_meta = {}
    for _name in name.split(","):
        yamlpath = load_dataset_yaml(_name)
        logger.info(f"Registering datasets from: '{yamlpath}'.")
        meta = io.load(yamlpath)
        dataset_meta.update(meta)
    return dataset_meta


def register_mixtures():
    return io.load(osp.join(osp.dirname(__file__), "registry", "mixtures.yaml"))


DATASETS = register_datasets()
MIXTURES = register_mixtures()


def parse_mixture(mixture: str) -> List[str]:
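    """Expand a "+"-separated mixture string into a sorted list of dataset names.

    Names that are themselves mixtures (keys of MIXTURES) are expanded
    recursively until only concrete dataset names remain.
    """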
    names = mixture.split("+")
    while any(name in MIXTURES for name in names):
        names = list(chain(*[MIXTURES.get(name, [name]) for name in names]))
    return sorted(names)


class SubsetDataset(Dataset):
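    """Expose only the leading `limit` fraction (0 < limit <= 1) of a wrapped dataset."""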
    def __init__(self, dataset: Dataset, limit: int) -> None:
        super().__init__()
        self.dataset = dataset
        self.limit = limit

    def __len__(self) -> int:
        return int(len(self.dataset) * self.limit)

    def __getitem__(self, index: int) -> Any:
        return self.dataset[index % len(self.dataset)]

class RepeatedDataset(Dataset):
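    """Repeat a wrapped dataset `times` times by cycling over its indices."""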
    def __init__(self, dataset: Dataset, times: int) -> None:
        super().__init__()
        self.dataset = dataset
        self.times = times

    def __len__(self) -> int:
        return len(self.dataset) * self.times

    def __getitem__(self, index: int) -> Any:
        return self.dataset[index % len(self.dataset)]


def get_world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def build_dataset(
    mixture: str,
    data_args: DataArguments,
    training_args: TrainingArguments,
    tokenizer: PreTrainedTokenizer,
) -> Dataset:
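    """Build the training dataset for a mixture string.

    Each name in the expanded mixture may carry a "*<n>" suffix (repeat the
    dataset n times) and/or a "#<p>" suffix (keep only the first p percent of
    its samples). Names are resolved against the YAML registry (DATASETS) first
    and fall back to DATASETS_LEGACY; the resulting datasets are concatenated.

    Example (dataset/mixture names are hypothetical and must be registered)::

        build_dataset("sharegpt4v_sft+coyo_25m*2+mmc4_core#50", data_args, training_args, tokenizer)
    """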
    logger.warning(f"Training VILA with mixture '{mixture}'.")
    datasets = []
    for name in parse_mixture(mixture):
        # A "*<n>" suffix repeats the dataset n times.
        if "*" in name:
            name, times = name.split("*")
            times = int(times)
        else:
            times = 1

        # A "#<p>" suffix limits the dataset to its first p percent of samples.
        limit_dataset = False
        if "#" in name:
            name, max_length_percent = name.split("#")
            limit_dataset = True
        if DATASETS is not None and name in DATASETS:
            if name in DATASETS_LEGACY:
                logger.warning(f"Dataset '{name}' exists in both new and legacy registries. Using the new one.")
            dataset = instantiate(DATASETS[name], _partial_=True)(
                tokenizer=tokenizer,
                data_args=data_args,
                global_batch_size=(
                    training_args.per_device_train_batch_size
                    * get_world_size()
                    * training_args.gradient_accumulation_steps
                ),
            )
        elif name in DATASETS_LEGACY:
            logger.warning(f"Dataset '{name}' is from the legacy registry. Please consider migrating it.")
            dataset = build_dataset_legacy(
                name,
                data_args=data_args,
                training_args=training_args,
                tokenizer=tokenizer,
            )
        else:
            raise ValueError(f"Dataset '{name}' is not found in the registries.")

        
        if limit_dataset:
            # Keep only the first `max_length_percent` percent of the samples.
            dataset = SubsetDataset(dataset, int(max_length_percent) / 100.0)

        if times > 1:
            dataset = RepeatedDataset(dataset, times)
        datasets.append(dataset)
    return ConcatDataset(datasets)


def build_dataset_legacy(
    name: str,
    data_args: DataArguments,
    training_args: TrainingArguments,
    tokenizer: PreTrainedTokenizer,
) -> Dataset:
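    """Instantiate a dataset described in the legacy DATASETS_LEGACY registry.

    Uses LazySupervisedDataset for "torch" entries and LazyWDSDataset for "wds"
    entries, copying legacy metadata fields onto `data_args`.
    """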
    from llava.data.dataset import (
        LazySupervisedDataset,
        LazyWDSDataset,
    )

    dataset = DATASETS_LEGACY[name]
    dataset_type = dataset.dataset_type
    if dataset_type == "torch":
        dataset_cls = LazySupervisedDataset
    elif dataset_type == "wds":
        dataset_cls = LazyWDSDataset
    else:
        raise NotImplementedError(f"{dataset_type} is not supported.")

    data_args.meta_path = getattr(dataset, "meta_path", None)
    data_args.caption_choice = getattr(dataset, "caption_choice", None)
    data_args.caption_choice_2 = getattr(dataset, "caption_choice_2", None)
    data_args.start_idx = getattr(dataset, "start_idx", None)
    data_args.end_idx = getattr(dataset, "end_idx", None)

    return dataset_cls(
        tokenizer=tokenizer,
        data_path=dataset.data_path,
        image_folder=dataset.image_path,
        data_args=data_args,
        training_args=training_args,
    )