import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import numpy as np

try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
    from apex import amp, optimizers
    from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")

def fast_collate(batch, memory_format):
    imgs = [img[0] for img in batch]
    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8).contiguous(memory_format=memory_format)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        if nump_array.ndim < 3:
            nump_array = np.expand_dims(nump_array, axis=-1)
        nump_array = np.rollaxis(nump_array, 2)
        tensor[i] += torch.from_numpy(nump_array)
    return tensor, targets
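# fast_collate replaces torchvision's ToTensor()/Normalize() in the collate step: it
# packs the PIL images into a single uint8 NCHW tensor on the CPU (in the requested
# memory format) and leaves float conversion and mean/std normalization to
# data_prefetcher below, which performs them on the GPU. Grayscale images get a
# channel axis added before the HWC->CHW rollaxis, and broadcast across all 3 channels.
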
def parse():
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))

    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('data', metavar='DIR',
                        help='path to dataset')
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                        choices=model_names,
                        help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=256, type=int,
                        metavar='N', help='mini-batch size per process (default: 256)')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR',
                        help='Initial learning rate. Will be scaled by <global batch size>/256: '
                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256. '
                             'A warmup schedule will also be applied over the first 5 epochs.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--print-freq', '-p', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')
    parser.add_argument('--prof', default=-1, type=int,
                        help='If >= 0, profile 10 iterations starting at this iteration, then exit.')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument('--sync_bn', action='store_true',
                        help='enabling apex sync BN.')
    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    parser.add_argument('--channels-last', type=bool, default=False)
    args = parser.parse_args()
    return args
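# Example launch (single node, one process per GPU). torch.distributed.launch sets
# WORLD_SIZE in the environment and passes --local_rank to each process, which is what
# enables the distributed branch in main() below. The script name and dataset path here
# are placeholders, not values defined by this file:
#
#   python -m torch.distributed.launch --nproc_per_node=8 main_amp.py \
#       -a resnet50 --opt-level O1 /path/to/imagenet
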
def main():
    global best_prec1, args

    args = parse()
    print("opt_level = {}".format(args.opt_level))
    print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
    print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
    print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format
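    # channels_last (NHWC) layout generally improves convolution throughput on Tensor
    # Cores under mixed precision; contiguous_format (NCHW) is the PyTorch default.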
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda().to(memory_format=memory_format)

    # Scale learning rate based on global batch size
    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
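    # Worked example: with 8 processes and --batch-size 256 per process, the global
    # batch is 2048, so the default base LR becomes 0.1 * 2048 / 256 = 0.8.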
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale)
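    # opt_level selects the mixed-precision recipe Amp applies: "O0" is pure FP32,
    # "O1" patches functions to cast to FP16 where it is numerically safe, "O2" casts
    # the model to FP16 while keeping FP32 master weights, and "O3" is pure FP16.
    # keep_batchnorm_fp32 and loss_scale (a number or "dynamic") override the defaults
    # implied by the chosen opt_level.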
    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                global best_prec1
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))
        resume()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if args.arch == "inception_v3":
        raise RuntimeError("Currently, inception_v3 is not supported by this example.")
        # crop_size = 299
        # val_size = 320  # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            # transforms.ToTensor(), Too slow
            # normalize,
        ]))
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(val_size),
        transforms.CenterCrop(crop_size),
    ]))

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    collate_fn = lambda b: fast_collate(b, memory_format)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=collate_fn)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=val_sampler,
        collate_fn=collate_fn)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1, 3, 1, 1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        # if record_stream() doesn't work, another option is to make sure device inputs are created
        # on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use by the main stream
        # at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # more code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu

            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        if input is not None:
            input.record_stream(torch.cuda.current_stream())
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target

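# Typical usage (mirrored by train() and validate() below): construct the prefetcher
# from a DataLoader, then call next() in a loop until it returns None. Each next()
# hands back a batch whose host->device copy and normalization were already issued on
# a side CUDA stream while the previous batch was being processed.
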
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()

    prefetcher = data_prefetcher(train_loader)
    input, target = prefetcher.next()
    i = 0
    while input is not None:
        i += 1
        if args.prof >= 0 and i == args.prof:
            print("Profiling begun at iteration {}".format(i))
            torch.cuda.cudart().cudaProfilerStart()

        if args.prof >= 0: torch.cuda.nvtx.range_push("Body of iteration {}".format(i))

        adjust_learning_rate(optimizer, epoch, i, len(train_loader))

        # compute output
        if args.prof >= 0: torch.cuda.nvtx.range_push("forward")
        output = model(input)
        if args.prof >= 0: torch.cuda.nvtx.range_pop()
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()

        if args.prof >= 0: torch.cuda.nvtx.range_push("backward")
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
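        # amp.scale_loss multiplies the loss by the current loss scale so that small
        # FP16 gradients do not underflow during backward; the gradients are unscaled
        # again when the context exits, before optimizer.step() runs.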
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        # for param in model.parameters():
        #     print(param.data.double().sum().item(), param.grad.data.double().sum().item())

        if args.prof >= 0: torch.cuda.nvtx.range_push("optimizer.step()")
        optimizer.step()
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.

            # Measure accuracy
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data)
                prec1 = reduce_tensor(prec1)
                prec5 = reduce_tensor(prec5)
            else:
                reduced_loss = loss.data

            # to_python_float incurs a host<->device sync
            losses.update(to_python_float(reduced_loss), input.size(0))
            top1.update(to_python_float(prec1), input.size(0))
            top5.update(to_python_float(prec5), input.size(0))

            torch.cuda.synchronize()
            batch_time.update((time.time() - end)/args.print_freq)
            end = time.time()

            if args.local_rank == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          epoch, i, len(train_loader),
                          args.world_size*args.batch_size/batch_time.val,
                          args.world_size*args.batch_size/batch_time.avg,
                          batch_time=batch_time,
                          loss=losses, top1=top1, top5=top5))

        if args.prof >= 0: torch.cuda.nvtx.range_push("prefetcher.next()")
        input, target = prefetcher.next()
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        # Pop range "Body of iteration {}".format(i)
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        if args.prof >= 0 and i == args.prof + 10:
            print("Profiling ended at iteration {}".format(i))
            torch.cuda.cudart().cudaProfilerStop()
            quit()

def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    end = time.time()

    prefetcher = data_prefetcher(val_loader)
    input, target = prefetcher.next()
    i = 0
    while input is not None:
        i += 1

        # compute output
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if args.distributed:
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        losses.update(to_python_float(reduced_loss), input.size(0))
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # TODO: Change timings to mirror train().
        if args.local_rank == 0 and i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      i, len(val_loader),
                      args.world_size * args.batch_size / batch_time.val,
                      args.world_size * args.batch_size / batch_time.avg,
                      batch_time=batch_time, loss=losses,
                      top1=top1, top5=top5))

        input, target = prefetcher.next()

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """LR schedule that should yield 76% converged accuracy with batch size 256"""
    factor = epoch // 30
    if epoch >= 80:
        factor = factor + 1

    lr = args.lr*(0.1**factor)

    # Warmup: ramp the LR linearly over the first 5 epochs.
    if epoch < 5:
        lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)

    # if args.local_rank == 0:
    #     print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
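# Worked example with a scaled base LR of 0.8 (8 x 256 global batch): epochs 0-4 warm
# up linearly toward 0.8, epochs 5-29 use 0.8, epochs 30-59 use 0.08, epochs 60-79 use
# 0.008, and epochs 80+ use 0.0008 (the extra factor of 10 from the epoch >= 80 branch).
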
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def reduce_tensor(tensor):
    rt = tensor.clone()
    # dist.reduce_op is deprecated; use dist.ReduceOp instead.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt
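# Note: this all-reduce is only for logging; it averages the loss/accuracy values
# across processes. Gradient synchronization is handled separately by the
# DistributedDataParallel wrapper during backward.
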
if __name__ == '__main__':
    main()