# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
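
# Fine-tunes a pretrained LaViLa video encoder for action classification on
# EPIC-KITCHENS-100 (ek100_cls) or EGTEA, and evaluates it with multi-clip /
# multi-crop testing.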
import argparse
from collections import OrderedDict
import json
import math
import os
import sys
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.cuda.amp as amp
from torch.distributed.optim import ZeroRedundancyOptimizer
import torch.nn.parallel
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from sklearn.metrics import confusion_matrix
import wandb

from lavila.data import datasets
from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop
from lavila.models import models
from lavila.models.tokenizer import (MyBertTokenizer, MyDistilBertTokenizer,
                                     MyGPT2Tokenizer, SimpleTokenizer)
from lavila.models.utils import inflate_positional_embeds
from lavila.utils import distributed as dist_utils
from lavila.utils.evaluation import accuracy, get_mean_accuracy
from lavila.utils.evaluation_ek100cls import get_marginal_indexes, marginalize
from lavila.utils.meter import AverageMeter, ProgressMeter
from lavila.utils.preprocess import generate_label_map
from lavila.utils.random import random_seed
from lavila.utils.scheduler import cosine_scheduler


def get_args_parser():
    parser = argparse.ArgumentParser(description='lavila finetune and evaluation', add_help=False)
    # Data
    parser.add_argument('--dataset', default='ek100_cls', type=str,
                        choices=['ek100_cls', 'egtea'])
    parser.add_argument('--root', default='datasets/EK100/video_ht256px/',
                        type=str, help='path to dataset root')
    parser.add_argument('--metadata-train',
                        default='datasets/EK100/epic-kitchens-100-annotations/EPIC_100_train.csv',
                        type=str, help='path to metadata file (train set)')
    parser.add_argument('--metadata-val',
                        default='datasets/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv',
                        type=str, help='path to metadata file (val set)')
    parser.add_argument('--relevancy-path',
                        default='datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl',
                        type=str, help='path to relevancy matrix (val set)')
    parser.add_argument('--output-dir', default='./', type=str, help='output dir')
    parser.add_argument('--num-crops', default=1, type=int, help='number of spatial crops at validation')
    parser.add_argument('--num-clips', default=1, type=int, help='number of temporal clips at validation')
    parser.add_argument('--clip-length', default=16, type=int, help='clip length (frames)')
    parser.add_argument('--clip-stride', default=2, type=int, help='clip stride (frames)')
    parser.add_argument('--sparse-sample', action='store_true', help='switch to sparse sampling')
    # Model
    parser.add_argument('--pretrain-model', default='', type=str, help='path to pretrained model')
    parser.add_argument('--resume', default='', type=str, help='path to checkpoint to resume from')
    parser.add_argument('--find-unused-parameters', action='store_true',
                        help='set this for DDP (useful for models with tied weights)')
    parser.add_argument('--drop-path-rate', default=0.1, type=float, help='drop path ratio')
    parser.add_argument('--dropout-ratio', default=0.5, type=float, help='dropout ratio for the last linear layer')
    parser.add_argument('--num-classes', default=[3806], nargs='+', type=int,
                        help='number of classes for the last linear layer')
    parser.add_argument('--use-vn-classifier', action='store_true')
    parser.add_argument('--use-half', action='store_true', help='use half precision at inference')
    # Training
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--warmup-epochs', default=1, type=int)
    parser.add_argument('--start-epoch', default=0, type=int)
    parser.add_argument('--batch-size', default=16, type=int,
                        help='number of samples per device/GPU')
    parser.add_argument('--use-sgd', action='store_true')
    parser.add_argument('--freeze-temperature', action='store_true', help='freeze temperature if set')
    parser.add_argument('--lr', default=3e-3, type=float)
    parser.add_argument('--fix-lr', action='store_true', help='disable cosine lr decay if set')
    parser.add_argument('--lr-start', default=1e-6, type=float, help='initial warmup lr')
    parser.add_argument('--lr-end', default=1e-5, type=float, help='minimum final lr')
    parser.add_argument('--lr-multiplier-on-backbone', default=0.1, type=float, help='lr multiplier for the backbone')
    parser.add_argument('--clip-grad-type', default='norm', choices=['norm', 'value'])
    parser.add_argument('--clip-grad-value', default=None, type=float, help='gradient clipping threshold')
    parser.add_argument('--update-freq', default=1, type=int,
                        help='optimizer update frequency (i.e. gradient accumulation steps)')
    parser.add_argument('--wd', default=0.01, type=float)
    parser.add_argument('--betas', default=(0.9, 0.999), nargs=2, type=float)
    parser.add_argument('--eps', default=1e-8, type=float)
    parser.add_argument('--label-smoothing', default=0.1, type=float, help='label smoothing')
    parser.add_argument('--eval-freq', default=5, type=int)
    parser.add_argument('--save-freq', default=5, type=int)
    parser.add_argument('--disable-amp', action='store_true',
                        help='disable mixed-precision training (requires more memory and compute)')
    parser.add_argument('--use-zero', action='store_true',
                        help='use ZeroRedundancyOptimizer to save memory')
    parser.add_argument('--use-checkpoint', action='store_true',
                        help='use gradient checkpointing during training for significantly less GPU memory')
    # System
    parser.add_argument('--print-freq', default=100, type=int, help='print frequency')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers per process')
    parser.add_argument('--evaluate', action='store_true', help='eval only')
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--local_rank', default=0, type=int)
    parser.add_argument('--dist-url', default='env://', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--gpu', default=None, type=int, help='GPU id to use')
    parser.add_argument('--wandb', action='store_true', help='enable WandB logging')
    return parser
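
# Example fine-tuning launch (hypothetical script name, paths, and GPU count;
# EK100 has 97 verb, 300 noun, and 3806 action classes):
#   torchrun --nproc_per_node=8 main_finetune_classification.py \
#       --dataset ek100_cls --pretrain-model lavila_pretrained.pt \
#       --use-vn-classifier --num-classes 97 300 3806 --use-checkpoint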


def main(args):
    dist_utils.init_distributed_mode(args)

    global best_acc1
    random_seed(args.seed, dist_utils.get_rank())

    if args.pretrain_model:
        ckpt_path = args.pretrain_model
    else:
        raise Exception('no checkpoint found')
    ckpt = torch.load(ckpt_path, map_location='cpu')

    if args.use_vn_classifier:
        assert args.dataset == 'ek100_cls' and len(args.num_classes) == 3

    state_dict = OrderedDict()
    for k, v in ckpt['state_dict'].items():
        state_dict[k.replace('module.', '')] = v

    old_args = ckpt['args']
    print("=> creating model: {}".format(old_args.model))
    model = getattr(models, old_args.model)(
        pretrained=old_args.load_visual_pretrained,
        pretrained2d=old_args.load_visual_pretrained is not None,
        text_use_cls_token=old_args.use_cls_token,
        project_embed_dim=old_args.project_embed_dim,
        timesformer_gated_xattn=False,
        timesformer_freeze_space=False,
        num_frames=args.clip_length,
        drop_path_rate=args.drop_path_rate,
    )
    if 'TIMESFORMER' in old_args.model or 'EGOVLP' in old_args.model:
        # inflate weight
        print('=> inflating PE in models due to different frame numbers')
        state_dict = inflate_positional_embeds(
            model.state_dict(), state_dict,
            num_frames=args.clip_length,
            load_temporal_fix='bilinear',
        )
    model.load_state_dict(state_dict, strict=True)
    print("=> loaded pretrained checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch']))

    if args.use_vn_classifier:
        model = models.VideoClassifierMultiHead(
            model.visual,
            dropout=args.dropout_ratio,
            num_classes_list=args.num_classes
        )
    else:
        assert len(args.num_classes) == 1
        model = models.VideoClassifier(
            model.visual,
            dropout=args.dropout_ratio,
            num_classes=args.num_classes[0]
        )

    model.cuda(args.gpu)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], bucket_cap_mb=200,
            find_unused_parameters=args.find_unused_parameters
        )
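
    # Split parameters into four optimizer groups: classifier head ('fc_cls')
    # vs. backbone, each with and without weight decay. Biases and norm-layer
    # weights are excluded from weight decay, and the backbone groups get a
    # reduced learning rate via --lr-multiplier-on-backbone.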
    p_wd, p_non_wd = [], []
    p_head_wd, p_head_non_wd = [], []
    for n, p in model.named_parameters():
        if 'fc_cls' in n:
            if 'bias' in n:
                p_head_non_wd.append(p)
            else:
                p_head_wd.append(p)
        elif not p.requires_grad:
            continue  # frozen weights
        elif p.ndim < 2 or 'bias' in n or 'ln' in n or 'bn' in n:
            p_non_wd.append(p)
        else:
            p_wd.append(p)

    optim_params = [
        {"params": p_wd, "weight_decay": args.wd, "lr": args.lr * args.lr_multiplier_on_backbone},
        {"params": p_non_wd, "weight_decay": 0, "lr": args.lr * args.lr_multiplier_on_backbone},
        {"params": p_head_wd, "weight_decay": args.wd},
        {"params": p_head_non_wd, "weight_decay": 0}
    ]

    if args.use_zero:
        optimizer = ZeroRedundancyOptimizer(
            optim_params, optimizer_class=torch.optim.SGD if args.use_sgd else torch.optim.AdamW,
            lr=args.lr, betas=args.betas, eps=args.eps, weight_decay=args.wd
        )
    else:
        if args.use_sgd:
            optimizer = torch.optim.SGD(optim_params, lr=args.lr, momentum=args.betas[0], weight_decay=args.wd)
        else:
            optimizer = torch.optim.AdamW(optim_params, lr=args.lr, betas=args.betas,
                                          eps=args.eps, weight_decay=args.wd)
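
    # note: when --use-sgd is set, betas[0] doubles as the SGD momentum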
    scaler = amp.GradScaler(enabled=not args.disable_amp)

    # auto-resume takes precedence: if a latest checkpoint already exists in
    # the output directory, ignore --resume and pick up that one below
    latest = os.path.join(args.output_dir, 'checkpoint.pt')
    if os.path.isfile(latest):
        args.resume = ''
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading resume checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 0
            args.start_epoch = epoch
            if not args.distributed:
                state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    state_dict[k.replace('module.', '')] = v
                result = model.load_state_dict(state_dict, strict=False)
            else:
                result = model.load_state_dict(checkpoint['state_dict'], strict=False)
            print(result)
            if 'optimizer' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
            if 'scaler' in checkpoint:
                scaler.load_state_dict(checkpoint['scaler'])
            best_acc1 = checkpoint['best_acc1']
            print("=> loaded resume checkpoint '{}' (epoch {}, best_metric = {})"
                  .format(args.resume, epoch, best_acc1))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        # auto-resume from the latest checkpoint in the output directory
        latest = os.path.join(args.output_dir, 'checkpoint.pt')
        if os.path.isfile(latest):
            print("=> loading latest checkpoint '{}'".format(latest))
            latest_checkpoint = torch.load(latest, map_location='cpu')
            args.start_epoch = latest_checkpoint['epoch']
            model.load_state_dict(latest_checkpoint['state_dict'])
            optimizer.load_state_dict(latest_checkpoint['optimizer'])
            scaler.load_state_dict(latest_checkpoint['scaler'])
            best_acc1 = latest_checkpoint['best_acc1']
            print("=> loaded latest checkpoint '{}' (epoch {})"
                  .format(latest, latest_checkpoint['epoch']))

    cudnn.benchmark = True

    # Data loading code
    print("=> creating dataset")
    if old_args.model.endswith('DISTILBERT_BASE'):
        tokenizer = MyDistilBertTokenizer('distilbert-base-uncased')
    elif old_args.model.endswith('BERT_BASE'):
        tokenizer = MyBertTokenizer('bert-base-uncased')
    elif old_args.model.endswith('BERT_LARGE'):
        tokenizer = MyBertTokenizer('bert-large-uncased')
    elif old_args.model.endswith('GPT2'):
        tokenizer = MyGPT2Tokenizer('gpt2')
    elif old_args.model.endswith('GPT2_MEDIUM'):
        tokenizer = MyGPT2Tokenizer('gpt2-medium')
    elif old_args.model.endswith('GPT2_LARGE'):
        tokenizer = MyGPT2Tokenizer('gpt2-large')
    elif old_args.model.endswith('GPT2_XL'):
        tokenizer = MyGPT2Tokenizer('gpt2-xl')
    else:
        print("Using SimpleTokenizer because of model '{}'. "
              "Please check if this is what you want".format(old_args.model))
        tokenizer = SimpleTokenizer()

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing).cuda(args.gpu)

    crop_size = 224 if '336PX' not in old_args.model else 336
    if 'OPENAI' in old_args.model:
        normalize = transforms_video.NormalizeVideo(
            mean=[108.3272985, 116.7460125, 104.09373615000001],
            std=[68.5005327, 66.6321579, 70.32316305])
    else:
        normalize = transforms_video.NormalizeVideo(
            mean=[123.675, 116.28, 103.53],
            std=[58.395, 57.12, 57.375])
    train_transform = transforms.Compose([
        Permute([3, 0, 1, 2]),    # T H W C -> C T H W
        transforms.RandomResizedCrop(crop_size, scale=(0.5, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        normalize,
    ])
    val_transform = transforms.Compose([
        Permute([3, 0, 1, 2]),    # T H W C -> C T H W
        transforms.Resize(crop_size),
        transforms.CenterCrop(crop_size),
        normalize,
        TemporalCrop(frames_per_clip=args.clip_length, stride=args.clip_length),
        SpatialCrop(crop_size=crop_size, num_crops=args.num_crops),
    ])
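    # TemporalCrop/SpatialCrop turn each val video into a list of clip crops,
    # which is why the validate* functions below accept a list of tensors
    # per sample.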

    # build dataset
    _, mapping_vn2act = generate_label_map(args.dataset)
    if args.dataset == 'ek100_cls':
        args.mapping_act2v = {i: int(vn.split(':')[0]) for (vn, i) in mapping_vn2act.items()}
        args.mapping_act2n = {i: int(vn.split(':')[1]) for (vn, i) in mapping_vn2act.items()}
        args.actions = pd.DataFrame.from_dict({'verb': args.mapping_act2v.values(), 'noun': args.mapping_act2n.values()})

    # train on a single clip per sample; restore --num-clips for validation only
    num_clips_at_val = args.num_clips
    args.num_clips = 1
    train_dataset = datasets.get_downstream_dataset(
        train_transform, tokenizer, args, subset='train', label_mapping=mapping_vn2act,
    )
    args.num_clips = num_clips_at_val
    val_dataset = datasets.get_downstream_dataset(
        val_transform, tokenizer, args, subset='val', label_mapping=mapping_vn2act,
    )

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.SequentialSampler(val_dataset)  # disable distributed
    else:
        train_sampler = None
        val_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True
    )
    print('len(train_loader) = {}'.format(len(train_loader)))
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False
    )
    print('len(val_loader) = {}'.format(len(val_loader)))

    if args.evaluate:
        if args.use_vn_classifier:
            _ = validate_multihead(val_loader, model, args)
        else:
            _ = validate(val_loader, model, args)
        return

    if args.fix_lr:
        lr_schedule = None
    else:
        lr_schedule = cosine_scheduler(
            args.lr, args.lr_end, args.epochs, len(train_loader) // args.update_freq,
            warmup_epochs=args.warmup_epochs, start_warmup_value=args.lr_start,
        )
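    # lr_schedule is a flat per-iteration array covering all epochs; train()
    # indexes it with the global optimizer-step counter
    # (iters_per_epoch * epoch + optim_iter).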

    if dist_utils.is_main_process() and args.wandb:
        wandb_id = os.path.split(args.output_dir)[-1]
        wandb.init(project='LaViLa', id=wandb_id, config=args, resume='allow')

    print(args)

    best_metric = 0.
    print("=> beginning training")
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        train_stats = train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, args)

        is_epoch = ((epoch + 1) % args.save_freq) == 0

        print('=> saving checkpoint')
        dist_utils.save_on_master({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict(),
            'best_acc1': 0,
            'args': args,
        }, False, args.output_dir, is_epoch=is_epoch)

        if ((epoch + 1) % args.eval_freq) == 0:
            if args.use_vn_classifier:
                val_stats = validate_multihead(val_loader, model, args)
            else:
                val_stats = validate(val_loader, model, args)
            if val_stats['acc1'] > best_metric:
                is_best = True
                best_metric = val_stats['acc1']
            else:
                is_best = False

            print('=> saving checkpoint')
            dist_utils.save_on_master({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scaler': scaler.state_dict(),
                'best_acc1': best_metric,
                'args': args,
            }, is_best, args.output_dir, is_epoch=is_epoch)

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in val_stats.items()},
                         'epoch': epoch}

            if dist_utils.is_main_process():
                if args.wandb:
                    wandb.log(log_stats)
                with open(os.path.join(args.output_dir, 'log.txt'), 'a') as f:
                    f.write(json.dumps(log_stats) + '\n')
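

# One "optimizer iteration" spans --update-freq mini-batches: train() scales
# the loss by 1/update_freq, accumulates gradients across those batches, and
# steps the optimizer (through the AMP scaler) only on the last one.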
def train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, args):
    batch_time = AverageMeter('Time', ':6.2f')
    data_time = AverageMeter('Data', ':6.2f')
    mem = AverageMeter('Mem (GB)', ':6.1f')
    iters_per_epoch = len(train_loader) // args.update_freq
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    top1_noun = AverageMeter('Noun Acc@1', ':6.2f')
    top1_verb = AverageMeter('Verb Acc@1', ':6.2f')
    progress = ProgressMeter(
        iters_per_epoch,
        [batch_time, data_time, mem, losses, top1, top5, top1_noun, top1_verb],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for data_iter, (images, target) in enumerate(train_loader):
        optim_iter = data_iter // args.update_freq

        # measure data loading time
        data_time.update(time.time() - end)

        # update the learning rate according to the schedule; the first two
        # param groups are the backbone, the last two the classifier head
        it = iters_per_epoch * epoch + optim_iter  # global training iteration
        if lr_schedule is not None:
            for k, param_group in enumerate(optimizer.param_groups):
                if k < 2:
                    param_group['lr'] = lr_schedule[it] * args.lr_multiplier_on_backbone
                else:
                    param_group['lr'] = lr_schedule[it]

        images = images.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        with amp.autocast(enabled=not args.disable_amp):
            output = model(images, use_checkpoint=args.use_checkpoint)
            if isinstance(output, list):
                assert len(output) == 3
                target_to_verb = torch.tensor([args.mapping_act2v[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True)
                loss = criterion(output[0], target_to_verb)
                target_to_noun = torch.tensor([args.mapping_act2n[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True)
                loss += criterion(output[1], target_to_noun)
                loss += criterion(output[2], target)
            else:
                loss = criterion(output, target)
            loss /= args.update_freq

        if not math.isfinite(loss.item()):
            print("Loss is {}, stopping training".format(loss.item()))
            sys.exit(1)

        scaler.scale(loss).backward()

        if (data_iter + 1) % args.update_freq != 0:
            continue
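
        # unscale gradients before clipping so the threshold applies to true
        # gradient magnitudes rather than AMP-scaled ones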
        if args.clip_grad_value is not None:
            scaler.unscale_(optimizer)
            if args.clip_grad_type == 'norm':
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.clip_grad_value, norm_type=2.
                )
            elif args.clip_grad_type == 'value':
                torch.nn.utils.clip_grad_value_(model.parameters(), args.clip_grad_value)
            else:
                assert False, f"Unknown clip mode ({args.clip_grad_type})."
        # optimizer step (with AMP loss-scale update)
        scaler.step(optimizer)
        scaler.update()
        model.zero_grad(set_to_none=True)

        if isinstance(output, list):
            target_to_verb = torch.tensor([args.mapping_act2v[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True)
            acc1_verb, _ = accuracy(output[0], target_to_verb, topk=(1, 5))
            top1_verb.update(acc1_verb.item(), images.size(0))
            target_to_noun = torch.tensor([args.mapping_act2n[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True)
            acc1_noun, _ = accuracy(output[1], target_to_noun, topk=(1, 5))
            top1_noun.update(acc1_noun.item(), images.size(0))
            acc1, acc5 = accuracy(output[2], target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
        else:
            output = torch.softmax(output, dim=1)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
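            # for EK100, also report verb/noun accuracies from the single
            # action head by marginalizing: sum the action probabilities that
            # share a verb (resp. noun) to get verb (resp. noun) scores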
            if args.dataset == 'ek100_cls':
                vi = get_marginal_indexes(args.actions, 'verb')
                ni = get_marginal_indexes(args.actions, 'noun')
                verb_scores = torch.tensor(marginalize(output.detach().cpu().numpy(), vi)).cuda(args.gpu, non_blocking=True)
                noun_scores = torch.tensor(marginalize(output.detach().cpu().numpy(), ni)).cuda(args.gpu, non_blocking=True)
                target_to_verb = torch.tensor([args.mapping_act2v[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True)
                target_to_noun = torch.tensor([args.mapping_act2n[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True)
                acc1_verb, _ = accuracy(verb_scores, target_to_verb, topk=(1, 5))
                acc1_noun, _ = accuracy(noun_scores, target_to_noun, topk=(1, 5))
                top1_verb.update(acc1_verb.item(), images.size(0))
                top1_noun.update(acc1_noun.item(), images.size(0))
            else:
                top1_verb.update(0., images.size(0))
                top1_noun.update(0., images.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        mem.update(torch.cuda.max_memory_allocated() // 1e9)

        if optim_iter % args.print_freq == 0:
            if dist_utils.is_main_process() and args.wandb:
                wandb.log({
                    'acc1': top1.avg, 'acc5': top5.avg, 'loss': losses.avg,
                    'acc1_verb': top1_verb.avg, 'acc1_noun': top1_noun.avg,
                })
            progress.display(optim_iter)
    progress.synchronize()
    return {
        'acc1': top1.avg, 'acc5': top5.avg, 'loss': losses.avg,
        'acc1_verb': top1_verb.avg, 'acc1_noun': top1_noun.avg,
        'lr': optimizer.param_groups[0]['lr'],
    }
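

# Single-head evaluation: when the val transform yields multiple
# spatio-temporal crops per sample, logits are averaged over crops before the
# softmax; for EK100, verb/noun accuracies are also derived by marginalizing
# the action-level probabilities.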
def validate(val_loader, model, args):
    batch_time = AverageMeter('Time', ':6.2f')
    data_time = AverageMeter('Data', ':6.2f')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, top1, top5],
        prefix='Test: '
    )

    # switch to eval mode
    model.eval()
    if args.use_half:
        model.half()

    all_outputs = []
    all_targets = []
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            if isinstance(images, list):
                logit_allcrops = []
                for crop in images:
                    crop = crop.cuda(args.gpu, non_blocking=True)
                    if args.use_half:
                        crop = crop.half()
                    logit = model(crop, use_checkpoint=args.use_checkpoint)
                    logit_allcrops.append(logit)
                logit_allcrops = torch.stack(logit_allcrops, 0)
                logit = logit_allcrops.mean(0)
                logit = torch.softmax(logit, dim=1)
                target = target.cuda(args.gpu, non_blocking=True)

                acc1, acc5 = accuracy(logit, target, topk=(1, 5))
                top1.update(acc1.item(), target.size(0))
                top5.update(acc5.item(), target.size(0))
            else:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)
                if args.use_half:
                    images = images.half()

                logit = model(images, use_checkpoint=args.use_checkpoint)
                logit = torch.softmax(logit, dim=1)

                acc1, acc5 = accuracy(logit, target, topk=(1, 5))
                top1.update(acc1.item(), images.size(0))
                top5.update(acc5.item(), images.size(0))

            all_outputs.append(logit)
            all_targets.append(target)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % args.print_freq == 0:
                progress.display(i)
    progress.synchronize()
    if args.dataset == 'ek100_cls':
        print('EK100 * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))
    else:
        print('EGTEA * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))

    all_outputs = torch.cat(all_outputs).cpu().numpy()
    all_targets = torch.cat(all_targets).cpu().numpy()
    cm = confusion_matrix(all_targets, all_outputs.argmax(axis=1))
    mean_acc, acc = get_mean_accuracy(cm)
    print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_acc, acc))

    if args.dataset == 'ek100_cls':
        vi = get_marginal_indexes(args.actions, 'verb')
        ni = get_marginal_indexes(args.actions, 'noun')
        verb_scores = marginalize(all_outputs, vi)
        noun_scores = marginalize(all_outputs, ni)
        target_to_verb = np.array([args.mapping_act2v[a] for a in all_targets.tolist()])
        target_to_noun = np.array([args.mapping_act2n[a] for a in all_targets.tolist()])
        cm = confusion_matrix(target_to_verb, verb_scores.argmax(axis=1))
        _, acc = get_mean_accuracy(cm)
        print('Verb Acc@1: {:.3f}'.format(acc))
        cm = confusion_matrix(target_to_noun, noun_scores.argmax(axis=1))
        _, acc = get_mean_accuracy(cm)
        print('Noun Acc@1: {:.3f}'.format(acc))
    return {'acc1': top1.avg, 'acc5': top5.avg, 'mean_acc': mean_acc}
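

# Multi-head evaluation for EK100: the classifier returns separate verb,
# noun, and action logits; per-crop logits are averaged, softmaxed, and
# scored against verb/noun targets derived from the action label.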
def validate_multihead(val_loader, model, args):
    batch_time = AverageMeter('Time', ':6.2f')
    data_time = AverageMeter('Data', ':6.2f')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    top1_verb = AverageMeter('Verb Acc@1', ':6.2f')
    top1_noun = AverageMeter('Noun Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, top1, top5, top1_verb, top1_noun],
        prefix='Test: '
    )

    # switch to eval mode
    model.eval()
    if args.use_half:
        model.half()

    all_verb_outputs = []
    all_noun_outputs = []
    all_action_outputs = []
    all_verb_targets = []
    all_noun_targets = []
    all_action_targets = []
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            if isinstance(images, torch.Tensor):
                images = [images, ]
            logit_verb_allcrops = []
            logit_noun_allcrops = []
            logit_action_allcrops = []
            for crop in images:
                crop = crop.cuda(args.gpu, non_blocking=True)
                if args.use_half:
                    crop = crop.half()
                logit = model(crop, use_checkpoint=args.use_checkpoint)
                logit_verb_allcrops.append(logit[0])
                logit_noun_allcrops.append(logit[1])
                logit_action_allcrops.append(logit[2])
            logit_verb_allcrops = torch.stack(logit_verb_allcrops, 0)
            logit_noun_allcrops = torch.stack(logit_noun_allcrops, 0)
            logit_action_allcrops = torch.stack(logit_action_allcrops, 0)
            logit_verb = torch.softmax(logit_verb_allcrops.mean(0), dim=1)
            logit_noun = torch.softmax(logit_noun_allcrops.mean(0), dim=1)
            logit_action = torch.softmax(logit_action_allcrops.mean(0), dim=1)
            target = target.cuda(args.gpu, non_blocking=True)
            target_to_verb = torch.tensor([args.mapping_act2v[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True)
            target_to_noun = torch.tensor([args.mapping_act2n[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True)

            acc1, acc5 = accuracy(logit_action, target, topk=(1, 5))
            acc1_verb, _ = accuracy(logit_verb, target_to_verb, topk=(1, 5))
            acc1_noun, _ = accuracy(logit_noun, target_to_noun, topk=(1, 5))
            top1.update(acc1.item(), target.size(0))
            top5.update(acc5.item(), target.size(0))
            top1_verb.update(acc1_verb.item(), target_to_verb.size(0))
            top1_noun.update(acc1_noun.item(), target_to_noun.size(0))

            all_verb_outputs.append(logit_verb)
            all_noun_outputs.append(logit_noun)
            all_action_outputs.append(logit_action)
            all_verb_targets.append(target_to_verb)
            all_noun_targets.append(target_to_noun)
            all_action_targets.append(target)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % args.print_freq == 0:
                progress.display(i)
    progress.synchronize()
    print('EK100 * Verb Acc@1 {top1.avg:.3f}'.format(top1=top1_verb))
    print('EK100 * Noun Acc@1 {top1.avg:.3f}'.format(top1=top1_noun))
    print('EK100 * Action Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))
    return {'acc1': top1.avg, 'acc5': top5.avg, 'acc1_verb': top1_verb.avg, 'acc1_noun': top1_noun.avg}


if __name__ == '__main__':
    parser = argparse.ArgumentParser('lavila finetune and evaluation', parents=[get_args_parser()])
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    main(args)
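
# Example evaluation run (hypothetical script name and checkpoint path):
#   torchrun --nproc_per_node=1 main_finetune_classification.py \
#       --dataset ek100_cls --evaluate --resume checkpoint_best.pt \
#       --num-crops 3 --num-clips 10 --use-half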