# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
"""
Benchmark inference speed of Deformable DETR.
"""
import os
import time
import argparse

import torch

from main import get_args_parser as get_main_args_parser
from models import build_model
from datasets import build_dataset
from util.misc import nested_tensor_from_tensor_list
def get_benckmark_arg_parser():
    """Build the CLI parser for the benchmark-only options.

    Options not recognized here are forwarded to the main Deformable DETR
    argument parser by the caller.
    """
    parser = argparse.ArgumentParser('Benchmark inference speed of Deformable DETR.')
    # (flag, default, help) for every integer-valued option
    int_options = (
        ('--num_iters', 300, 'total iters to benchmark speed'),
        ('--warm_iters', 5, 'ignore first several iters that are very slow'),
        ('--batch_size', 1, 'batch size in inference'),
    )
    for flag, default, help_text in int_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)
    parser.add_argument('--resume', type=str, help='load the pre-trained checkpoint')
    return parser
def measure_average_inference_time(model, inputs, num_iters=100, warm_iters=5):
    """Return the average wall-clock latency (seconds) of one forward pass.

    Runs ``model(inputs)`` ``num_iters`` times and averages the per-iteration
    time, discarding the first ``warm_iters`` iterations (warm-up covers cudnn
    autotuning, allocator growth, etc.).

    Args:
        model: a CUDA-resident callable; ``model(inputs)`` is what gets timed.
        inputs: the (fixed) input batch passed to ``model`` on every iteration.
        num_iters: total number of forward passes to run; must exceed
            ``warm_iters`` or the average is undefined.
        warm_iters: number of leading iterations excluded from the average.

    Returns:
        Mean latency in seconds over the last ``num_iters - warm_iters`` runs.
    """
    ts = []
    for iter_ in range(num_iters):
        # Synchronize before starting and before stopping the clock so the
        # measurement covers the full GPU work, not just the async launches.
        torch.cuda.synchronize()
        t_ = time.perf_counter()
        model(inputs)
        torch.cuda.synchronize()
        t = time.perf_counter() - t_
        if iter_ >= warm_iters:
            ts.append(t)
    # Bug fix: the original printed the entire growing list on EVERY iteration
    # (quadratic console output inside the timed loop); report it once instead.
    print(ts)
    return sum(ts) / len(ts)
def benchmark():
    """Measure Deformable DETR inference throughput and return it in FPS.

    Parses benchmark-specific flags, forwards the remaining CLI arguments to
    the main Deformable DETR parser, builds the model and the validation
    dataset, optionally restores a checkpoint, then times repeated forward
    passes over one replicated validation sample.

    Returns:
        Frames per second: ``batch_size / avg_latency``.

    Raises:
        ValueError: if the iteration counts or batch size are invalid.
        FileNotFoundError: if ``--resume`` points to a missing file.
    """
    args, leftover = get_benckmark_arg_parser().parse_known_args()
    main_args = get_main_args_parser().parse_args(leftover)
    # Validate with explicit exceptions: `assert` is stripped under `python -O`.
    if not (0 <= args.warm_iters < args.num_iters):
        raise ValueError('require 0 <= warm_iters < num_iters, got '
                         f'warm_iters={args.warm_iters}, num_iters={args.num_iters}')
    if args.batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {args.batch_size}')
    if args.resume is not None and not os.path.exists(args.resume):
        raise FileNotFoundError(f'checkpoint not found: {args.resume}')
    dataset = build_dataset('val', main_args)
    model, _, _ = build_model(main_args)
    model.cuda()
    model.eval()
    if args.resume is not None:
        # map_location keeps checkpoint tensors on CPU; the model already
        # lives on the GPU, so load_state_dict moves the weights over.
        ckpt = torch.load(args.resume, map_location=lambda storage, loc: storage)
        model.load_state_dict(ckpt['model'])
    # Replicate the first validation image batch_size times to form one batch.
    inputs = nested_tensor_from_tensor_list(
        [dataset[0][0].cuda() for _ in range(args.batch_size)])
    t = measure_average_inference_time(model, inputs, args.num_iters, args.warm_iters)
    return args.batch_size / t
# Script entry point: run the benchmark and report throughput.
if __name__ == '__main__':
    frames_per_second = benchmark()
    print('Inference Speed: {:.1f} FPS'.format(frames_per_second))