# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Name-keyed registry of dataset descriptions (``DATASETS_LEGACY``)."""

import warnings
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class Dataset:
    """Description of a single dataset that can be registered by name.

    Only ``dataset_name`` is required. Every path-like field defaults to
    ``None``, meaning "not provided". ``dataset_type`` defaults to "torch".
    """

    dataset_name: str
    dataset_type: str = field(default="torch")
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
    meta_path: Optional[str] = field(default=None, metadata={"help": "Path to the meta data for webdataset."})
    image_path: Optional[str] = field(default=None, metadata={"help": "Path to the training image data."})
    speech_path: Optional[str] = field(default=None, metadata={"help": "Path to the training speech data."})
    # Declared exactly once: a repeated annotation in a dataclass body silently
    # overrides the earlier one.
    caption_choice: Optional[str] = field(default=None, metadata={"help": "Path to the captions for webdataset."})
    description: Optional[str] = field(
        default=None,
        metadata={
            "help": "Detailed description of where the data is from, how it is labelled, intended use case and the size of the dataset."
        },
    )
    test_script: Optional[str] = None
    maintainer: Optional[str] = None
    ##############
    caption_choice_2: Optional[str] = field(default=None, metadata={"help": "Path to the captions for webdataset."})
    # NOTE(review): -1 presumably means "no start/end bound"; confirm with the data loader.
    start_idx: float = field(default=-1, metadata={"help": "Start index of the dataset."})
    end_idx: float = field(default=-1, metadata={"help": "End index of the dataset."})


# Global registry: dataset_name -> Dataset, filled by add_dataset().
DATASETS_LEGACY: Dict[str, Dataset] = {}


def add_dataset(dataset: Dataset) -> None:
    """Register ``dataset`` in ``DATASETS_LEGACY`` under its ``dataset_name``.

    Registering a name that already exists emits a ``UserWarning`` and
    replaces the previous entry.

    Args:
        dataset: The dataset description to register.

    Raises:
        ValueError: If ``dataset.dataset_name`` contains the symbol '+'.
    """
    name = dataset.dataset_name
    # '+' is reserved; presumably it joins several dataset names into one
    # mixture spec elsewhere (confirm with the loader). Use a real exception
    # rather than `assert` so the check survives `python -O`.
    if "+" in name:
        raise ValueError("Dataset name cannot include symbol '+'.")
    if name in DATASETS_LEGACY:  # make sure the data_name is unique
        warnings.warn(
            f"{name} already existed in DATASETS. Make sure the name is unique.",
            stacklevel=2,
        )
    DATASETS_LEGACY[name] = dataset


def register_datasets_mixtures() -> None:
    """Register the built-in dataset entries in ``DATASETS_LEGACY``.

    The entries below use placeholder ``data_path`` values that must be
    replaced with real locations. Calling this more than once re-registers
    the same names, so ``add_dataset`` warns about each duplicate.
    """
    ##############
    # Audio Datasets
    ##############

    data_mixture_1 = Dataset(
        dataset_name="data_mixture_1",
        dataset_type="torch",
        data_path="/path/to/your/data_mixture_1/train.json",
    )
    add_dataset(data_mixture_1)

    data_mixture_2 = Dataset(
        dataset_name="data_mixture_2",
        dataset_type="torch",
        data_path="/path/to/your/data_mixture_2/train.json",
    )
    add_dataset(data_mixture_2)

    # Add more data mixtures below