from typing import Optional

import torchdata.datapipes.iter
import webdataset as wds
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule
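
# `sdata` lives in the stable-datasets submodule; bail out early with
# installation instructions if it has not been checked out and installed.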
try:
    from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError:
    print("#" * 100)
    print("Datasets not yet available.")
    print("To enable them, add stable-datasets as a submodule:")
    print("run ``git submodule update --init --recursive``")
    print("and then ``pip install -e stable-datasets/`` from the root of this repo.")
    print("#" * 100)
    exit(1)


class StableDataModuleFromConfig(LightningDataModule):
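    """LightningDataModule that builds train/val/test datapipelines from
    OmegaConf configs.

    Each split config must provide two fields: `datapipeline` (kwargs for
    sdata.create_dataset) and `loader` (kwargs for sdata.create_loader).
    """
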
    def __init__(
        self,
        train: DictConfig,
        validation: Optional[DictConfig] = None,
        test: Optional[DictConfig] = None,
        skip_val_loader: bool = False,
        dummy: bool = False,
    ):
        super().__init__()
        self.train_config = train
        assert (
            "datapipeline" in self.train_config and "loader" in self.train_config
        ), "train config requires the fields `datapipeline` and `loader`"
        self.val_config = validation
        if not skip_val_loader:
            if self.val_config is not None:
                assert (
                    "datapipeline" in self.val_config and "loader" in self.val_config
                ), "validation config requires the fields `datapipeline` and `loader`"
            else:
                print(
                    "Warning: no validation datapipeline defined, "
                    "falling back to the training datapipeline"
                )
                self.val_config = train

        self.test_config = test
        if self.test_config is not None:
            assert (
                "datapipeline" in self.test_config and "loader" in self.test_config
            ), "test config requires the fields `datapipeline` and `loader`"
        self.dummy = dummy
        if self.dummy:
            print("#" * 100)
            print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
            print("#" * 100)
    def setup(self, stage: str) -> None:
        print("Preparing datasets")
        if self.dummy:
            data_fn = create_dummy_dataset
        else:
            data_fn = create_dataset

        self.train_datapipeline = data_fn(**self.train_config.datapipeline)
        if self.val_config:
            self.val_datapipeline = data_fn(**self.val_config.datapipeline)
        if self.test_config:
            self.test_datapipeline = data_fn(**self.test_config.datapipeline)

    def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
        return create_loader(self.train_datapipeline, **self.train_config.loader)

    def val_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.val_datapipeline, **self.val_config.loader)

    def test_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.test_datapipeline, **self.test_config.loader)
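

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The
    # `datapipeline` and `loader` payloads below are hypothetical
    # placeholders: their real keys are whatever sdata.create_dummy_dataset
    # and sdata.create_loader accept in your checkout of stable-datasets.
    from omegaconf import OmegaConf

    train_cfg = OmegaConf.create(
        {
            "datapipeline": {},  # hypothetical: kwargs forwarded to create_dummy_dataset
            "loader": {},  # hypothetical: kwargs forwarded to create_loader
        }
    )
    # dummy=True routes dataset creation through create_dummy_dataset.
    datamodule = StableDataModuleFromConfig(train=train_cfg, dummy=True)
    datamodule.setup("fit")
    train_loader = datamodule.train_dataloader()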