from typing import Dict, Optional

import numpy as np
import pytorch_lightning as pl
import torch
import webdataset as wds
from yacs.config import CfgNode

from .dataset import Dataset
from .image_dataset import ImageDataset
from .mocap_dataset import MoCapDataset

def to_lower(x: Dict) -> Dict:
    """
    Convert all dictionary keys to lowercase.
    Args:
        x (dict): Input dictionary.
    Returns:
        dict: Output dictionary with all keys converted to lowercase.

    Example:
        >>> to_lower({'TYPE': 'ImageDataset', 'URLS': 'data/*.tar'})
        {'type': 'ImageDataset', 'urls': 'data/*.tar'}
    """
    return {k.lower(): v for k, v in x.items()}

def create_dataset(cfg: CfgNode, dataset_cfg: CfgNode, train: bool = True, **kwargs) -> Dataset:
    """
    Instantiate a dataset from a config file.
    Args:
        cfg (CfgNode): Model configuration file.
        dataset_cfg (CfgNode): Dataset configuration info.
        train (bool): Flag to select between the train and val datasets.
    Returns:
        Dataset: The instantiated dataset.
    """

    dataset_type = Dataset.registry[dataset_cfg.TYPE]
    return dataset_type(cfg, **to_lower(dataset_cfg), train=train, **kwargs)

def create_webdataset(cfg: CfgNode, dataset_cfg: CfgNode, train: bool = True) -> Dataset:
    """
    Like `create_dataset`, but loads the data from WebDataset tar archives.
    """
    dataset_type = Dataset.registry[dataset_cfg.TYPE]
    return dataset_type.load_tars_as_webdataset(cfg, **to_lower(dataset_cfg), train=train)


class MixedWebDataset(wds.WebDataset):
    """Mixture of several WebDatasets, sampled with normalized per-dataset weights."""

    def __init__(self, cfg: CfgNode, dataset_cfg: CfgNode, train: bool = True) -> None:
        # Skip wds.WebDataset.__init__ and initialize the pipeline base class directly.
        super(wds.WebDataset, self).__init__()
        dataset_list = cfg.DATASETS.TRAIN if train else cfg.DATASETS.VAL
        datasets = [create_webdataset(cfg, dataset_cfg[dataset], train=train) for dataset in dataset_list.keys()]
        weights = np.array([v.WEIGHT for v in dataset_list.values()])
        weights = weights / weights.sum()  # normalize weights to sum to 1
        self.append(wds.RandomMix(datasets, weights))

class HMR2DataModule(pl.LightningDataModule):

    def __init__(self, cfg: CfgNode, dataset_cfg: CfgNode) -> None:
        """
        Initialize LightningDataModule for HMR2 training
        Args:
            cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info.
            dataset_cfg (CfgNode): Dataset configuration file
        """
        super().__init__()
        self.cfg = cfg
        self.dataset_cfg = dataset_cfg
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.mocap_dataset = None

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Load the datasets necessary for training.
        Args:
            stage (Optional[str]): Lightning stage ('fit', 'validate', 'test'); unused here.
        """
        if self.train_dataset is None:
            self.train_dataset = MixedWebDataset(self.cfg, self.dataset_cfg, train=True).with_epoch(100_000).shuffle(4000)
            # self.val_dataset = MixedWebDataset(self.cfg, self.dataset_cfg, train=False).shuffle(4000)
            self.mocap_dataset = MoCapDataset(**to_lower(self.dataset_cfg[self.cfg.DATASETS.MOCAP]))

    def train_dataloader(self) -> Dict:
        """
        Set up the training data loaders.
        Returns:
            Dict: Dictionary containing the image and mocap dataloaders.
        """
        train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            self.cfg.TRAIN.BATCH_SIZE,
            drop_last=True,
            num_workers=self.cfg.GENERAL.NUM_WORKERS,
            prefetch_factor=self.cfg.GENERAL.PREFETCH_FACTOR,
        )
        mocap_dataloader = torch.utils.data.DataLoader(
            self.mocap_dataset,
            self.cfg.TRAIN.NUM_TRAIN_SAMPLES * self.cfg.TRAIN.BATCH_SIZE,
            shuffle=True,
            drop_last=True,
            num_workers=1,
        )
        return {'img': train_dataloader, 'mocap': mocap_dataloader}
        # return {'img': train_dataloader}

    # def val_dataloader(self) -> torch.utils.data.DataLoader:
    #     """
    #     Setup val data loader.
    #     Returns:
    #         torch.utils.data.DataLoader: Validation dataloader
    #     """
    #     val_dataloader = torch.utils.data.DataLoader(self.val_dataset, self.cfg.TRAIN.BATCH_SIZE, drop_last=True, num_workers=self.cfg.GENERAL.NUM_WORKERS)
    #     return val_dataloader