# feat: CPU demo (upstream commit 5ac1897) — NOTE(review): these three lines were
# non-Python residue from a web-page scrape (author/commit header); kept as a comment.
from typing import Dict, Optional
import torch
import numpy as np
import pytorch_lightning as pl
from yacs.config import CfgNode
import webdataset as wds
from .dataset import Dataset
from .image_dataset import ImageDataset
from .mocap_dataset import MoCapDataset
def to_lower(x: Dict) -> Dict:
    """
    Lowercase every key of a dictionary.
    Args:
        x (dict): Input dictionary
    Returns:
        dict: New dictionary with the same values and all keys lowercased
    """
    lowered = {}
    for key, value in x.items():
        lowered[key.lower()] = value
    return lowered
def create_dataset(cfg: CfgNode, dataset_cfg: CfgNode, train: bool = True, **kwargs) -> Dataset:
    """
    Instantiate a dataset from a config file.
    Args:
        cfg (CfgNode): Model configuration file.
        dataset_cfg (CfgNode): Dataset configuration info.
        train (bool): Variable to select between train and val datasets.
    Returns:
        Dataset: The instantiated dataset of the registered type `dataset_cfg.TYPE`.
    """
    # Look up the concrete class registered under this config's TYPE key.
    dataset_class = Dataset.registry[dataset_cfg.TYPE]
    # The whole dataset config is forwarded as lowercase keyword arguments.
    config_kwargs = to_lower(dataset_cfg)
    return dataset_class(cfg, **config_kwargs, train=train, **kwargs)
def create_webdataset(cfg: CfgNode, dataset_cfg: CfgNode, train: bool = True) -> Dataset:
    """
    Like `create_dataset` but load data from tars.
    """
    dataset_class = Dataset.registry[dataset_cfg.TYPE]
    config_kwargs = to_lower(dataset_cfg)
    return dataset_class.load_tars_as_webdataset(cfg, **config_kwargs, train=train)
class MixedWebDataset(wds.WebDataset):
    """Weighted random mixture of several tar-backed webdatasets."""

    def __init__(self, cfg: CfgNode, dataset_cfg: CfgNode, train: bool = True) -> None:
        # Deliberately call the grandparent initializer, skipping
        # wds.WebDataset.__init__: the pipeline is populated below via append().
        super(wds.WebDataset, self).__init__()
        split_cfg = cfg.DATASETS.TRAIN if train else cfg.DATASETS.VAL
        entries = list(split_cfg.items())
        sources = [create_webdataset(cfg, dataset_cfg[name], train=train) for name, _ in entries]
        weights = np.array([node.WEIGHT for _, node in entries])
        # RandomMix samples from each source in proportion to its normalized weight.
        self.append(wds.RandomMix(sources, weights / weights.sum()))
class HMR2DataModule(pl.LightningDataModule):

    def __init__(self, cfg: CfgNode, dataset_cfg: CfgNode) -> None:
        """
        Initialize LightningDataModule for HMR2 training.
        Args:
            cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info.
            dataset_cfg (CfgNode): Dataset configuration file.
        """
        super().__init__()
        self.cfg = cfg
        self.dataset_cfg = dataset_cfg
        # Datasets are created lazily in setup(); None marks "not built yet".
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.mocap_dataset = None

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Load datasets necessary for training.
        Args:
            stage (Optional[str]): Lightning stage hint (unused here).
        """
        # Guard with `is None` (identity check, not `== None`) so repeated
        # setup() calls — Lightning may invoke it once per stage — do not
        # rebuild the datasets.
        if self.train_dataset is None:
            self.train_dataset = MixedWebDataset(self.cfg, self.dataset_cfg, train=True).with_epoch(100_000).shuffle(4000)
            # self.val_dataset = MixedWebDataset(self.cfg, self.dataset_cfg, train=False).shuffle(4000)
            self.mocap_dataset = MoCapDataset(**to_lower(self.dataset_cfg[self.cfg.DATASETS.MOCAP]))

    def train_dataloader(self) -> Dict:
        """
        Setup training data loader.
        Returns:
            Dict: Dictionary containing image ('img') and mocap ('mocap') dataloaders.
        """
        train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            self.cfg.TRAIN.BATCH_SIZE,
            drop_last=True,
            num_workers=self.cfg.GENERAL.NUM_WORKERS,
            prefetch_factor=self.cfg.GENERAL.PREFETCH_FACTOR,
        )
        # Mocap batch is NUM_TRAIN_SAMPLES samples per image in the batch
        # (presumably consumed by a discriminator — confirm against the trainer).
        mocap_dataloader = torch.utils.data.DataLoader(
            self.mocap_dataset,
            self.cfg.TRAIN.NUM_TRAIN_SAMPLES * self.cfg.TRAIN.BATCH_SIZE,
            shuffle=True,
            drop_last=True,
            num_workers=1,
        )
        return {'img': train_dataloader, 'mocap': mocap_dataloader}
        # return {'img': train_dataloader}

    # def val_dataloader(self) -> torch.utils.data.DataLoader:
    #     """
    #     Setup val data loader.
    #     Returns:
    #         torch.utils.data.DataLoader: Validation dataloader
    #     """
    #     val_dataloader = torch.utils.data.DataLoader(self.val_dataset, self.cfg.TRAIN.BATCH_SIZE, drop_last=True, num_workers=self.cfg.GENERAL.NUM_WORKERS)
    #     return val_dataloader