"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
"""
import os.path as op
import torch
import logging
import code
from custom_mesh_graphormer.utils.comm import get_world_size
from custom_mesh_graphormer.datasets.human_mesh_tsv import (MeshTSVDataset, MeshTSVYamlDataset)
from custom_mesh_graphormer.datasets.hand_mesh_tsv import (HandMeshTSVDataset, HandMeshTSVYamlDataset)


def build_dataset(yaml_file, args, is_train=True, scale_factor=1):
    """Build a human-mesh dataset defined by a TSV YAML config."""
    print(yaml_file)
    # If the YAML path is not absolute, resolve it against args.data_dir.
    if not op.isfile(yaml_file):
        yaml_file = op.join(args.data_dir, yaml_file)
        # code.interact(local=locals())
        assert op.isfile(yaml_file)
    return MeshTSVYamlDataset(yaml_file, is_train, False, scale_factor)


class IterationBasedBatchSampler(torch.utils.data.sampler.BatchSampler):
    """
    Wraps a BatchSampler, resampling from it until a specified
    number of iterations have been sampled.
    """
def __init__(self, batch_sampler, num_iterations, start_iter=0):
self.batch_sampler = batch_sampler
self.num_iterations = num_iterations
self.start_iter = start_iter

    def __iter__(self):
iteration = self.start_iter
while iteration <= self.num_iterations:
# if the underlying sampler has a set_epoch method, like
# DistributedSampler, used for making each process see
# a different split of the dataset, then set it
if hasattr(self.batch_sampler.sampler, "set_epoch"):
self.batch_sampler.sampler.set_epoch(iteration)
for batch in self.batch_sampler:
iteration += 1
if iteration > self.num_iterations:
break
yield batch

    def __len__(self):
        return self.num_iterations


def make_batch_data_sampler(sampler, images_per_gpu, num_iters=None, start_iter=0):
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler, images_per_gpu, drop_last=False
    )
    # During training, wrap the batch sampler so it keeps yielding batches
    # (re-iterating over the dataset) until num_iters iterations are reached.
    if num_iters is not None and num_iters >= 0:
        batch_sampler = IterationBasedBatchSampler(
            batch_sampler, num_iters, start_iter
        )
    return batch_sampler


def make_data_sampler(dataset, shuffle, distributed):
    # In distributed training, each process samples its own shard of the dataset.
    if distributed:
        return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        sampler = torch.utils.data.sampler.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.sampler.SequentialSampler(dataset)
    return sampler


def make_data_loader(args, yaml_file, is_distributed=True,
                     is_train=True, start_iter=0, scale_factor=1):
    dataset = build_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
    logger = logging.getLogger(__name__)

    if is_train:
        shuffle = True
        images_per_gpu = args.per_gpu_train_batch_size
        images_per_batch = images_per_gpu * get_world_size()
        iters_per_epoch = len(dataset) // images_per_batch
        num_iters = iters_per_epoch * args.num_train_epochs
        logger.info("Train with {} images per GPU.".format(images_per_gpu))
        logger.info("Total batch size {}".format(images_per_batch))
        logger.info("Total training steps {}".format(num_iters))
    else:
        shuffle = False
        images_per_gpu = args.per_gpu_eval_batch_size
        num_iters = None
        start_iter = 0
sampler = make_data_sampler(dataset, shuffle, is_distributed)
batch_sampler = make_batch_data_sampler(
sampler, images_per_gpu, num_iters, start_iter
)
data_loader = torch.utils.data.DataLoader(
dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
pin_memory=True,
)
return data_loader


#==============================================================================================
def build_hand_dataset(yaml_file, args, is_train=True, scale_factor=1):
    """Build a hand-mesh dataset defined by a TSV YAML config."""
    print(yaml_file)
    # If the YAML path is not absolute, resolve it against args.data_dir.
    if not op.isfile(yaml_file):
        yaml_file = op.join(args.data_dir, yaml_file)
        # code.interact(local=locals())
        assert op.isfile(yaml_file)
    return HandMeshTSVYamlDataset(args, yaml_file, is_train, False, scale_factor)


def make_hand_data_loader(args, yaml_file, is_distributed=True,
                          is_train=True, start_iter=0, scale_factor=1):
    dataset = build_hand_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
    logger = logging.getLogger(__name__)

    if is_train:
        shuffle = True
        images_per_gpu = args.per_gpu_train_batch_size
        images_per_batch = images_per_gpu * get_world_size()
        iters_per_epoch = len(dataset) // images_per_batch
        num_iters = iters_per_epoch * args.num_train_epochs
        logger.info("Train with {} images per GPU.".format(images_per_gpu))
        logger.info("Total batch size {}".format(images_per_batch))
        logger.info("Total training steps {}".format(num_iters))
    else:
        shuffle = False
        images_per_gpu = args.per_gpu_eval_batch_size
        num_iters = None
        start_iter = 0
sampler = make_data_sampler(dataset, shuffle, is_distributed)
batch_sampler = make_batch_data_sampler(
sampler, images_per_gpu, num_iters, start_iter
)
data_loader = torch.utils.data.DataLoader(
dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
pin_memory=True,
)
return data_loader
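

if __name__ == "__main__":
    # Minimal usage sketch of the sampler utilities above (illustrative only;
    # it substitutes a toy index range for a real TSV dataset). It shows how
    # IterationBasedBatchSampler keeps resampling across epochs until the
    # requested number of iterations has been yielded.
    toy_indices = list(range(10))

    # Non-distributed, shuffled sampler over the toy "dataset".
    toy_sampler = make_data_sampler(toy_indices, shuffle=True, distributed=False)

    # Batches of 4 indices, resampled until 7 iterations have been produced
    # (a plain BatchSampler over 10 indices would stop after 3 batches).
    toy_batch_sampler = make_batch_data_sampler(toy_sampler, images_per_gpu=4, num_iters=7)

    batches = list(toy_batch_sampler)
    print(len(toy_batch_sampler), len(batches))  # expected: 7 7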