Spaces:
				
			
			
	
			
			
		Build error
		
	
	
	
			
			
	
	
	
	
		
		
		Build error
		
	| import os | |
| import cv2 | |
| import argparse | |
| import numpy as np | |
| import torch | |
| import torchvision | |
| from torchvision import datasets, transforms | |
| from torch.autograd import Variable | |
| from network_v0.model import PointModel | |
| from datasets.hp_loader import PatchesDataset | |
| from torch.utils.data import DataLoader | |
| from evaluation.evaluate import evaluate_keypoint_net | |
def _make_loader(root_dir, output_shape):
    """Build a sequential, one-sample-per-batch DataLoader over the test set.

    Args:
        root_dir: Path handed to ``PatchesDataset`` as ``root_dir``.
        output_shape: Two-element shape forwarded to ``PatchesDataset`` as
            ``output_shape`` (the script uses ``(320, 240)`` and ``(640, 480)``).

    Returns:
        A ``DataLoader`` with ``batch_size=1``, no shuffling and 4 workers.
    """
    dataset = PatchesDataset(
        root_dir=root_dir, use_color=True, output_shape=output_shape, type="all"
    )
    return DataLoader(
        dataset,
        batch_size=1,
        pin_memory=False,
        shuffle=False,  # keep sample order fixed so runs are comparable
        num_workers=4,
        worker_init_fn=None,
        sampler=None,
    )


def _evaluate_and_report(data_loader, model, output_shape, top_k):
    """Run ``evaluate_keypoint_net`` once and print the six values it returns.

    Args:
        data_loader: Loader produced for ``output_shape``.
        model: Network passed straight through to ``evaluate_keypoint_net``.
        output_shape: Two-element shape; also used in the printed banner.
        top_k: Value forwarded as ``top_k`` (number of points in the banner).
    """
    print(
        "Evaluating in {}x{}, {} points".format(
            output_shape[0], output_shape[1], top_k
        )
    )
    rep, loc, c1, c3, c5, mscore = evaluate_keypoint_net(
        data_loader, model, output_shape=output_shape, top_k=top_k
    )
    print("Repeatability: {0:.3f}".format(rep))
    print("Localization Error: {0:.3f}".format(loc))
    print("H-1 Accuracy: {:.3f}".format(c1))
    print("H-3 Accuracy: {:.3f}".format(c3))
    print("H-5 Accuracy: {:.3f}".format(c5))
    print("Matching Score: {:.3f}".format(mscore))
    print("\n")


def main():
    """Evaluate the PointModel checkpoint at two resolutions and print metrics.

    Command-line arguments:
        --device: Integer GPU index to select when CUDA is available (default 0).
        --test_dir: Required path to the test data.

    The checkpoint is read from ``./checkpoints/PointModel_v0.pth`` and must
    contain a ``"model_state"`` entry.
    """
    parser = argparse.ArgumentParser(description="Testing")
    parser.add_argument("--device", default=0, type=int, help="which gpu to run on.")
    parser.add_argument("--test_dir", required=True, type=str, help="Test data path.")
    opt = parser.parse_args()

    torch.manual_seed(0)
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        torch.cuda.set_device(opt.device)

    # Build both loaders up front so a bad --test_dir fails before the model loads.
    data_loader_320x240 = _make_loader(opt.test_dir, (320, 240))
    data_loader_640x480 = _make_loader(opt.test_dir, (640, 480))

    # Load model
    model = PointModel(is_test=True)
    # Deserialize onto the CPU first: without map_location a checkpoint saved
    # from CUDA tensors cannot be loaded on a CPU-only machine, even though the
    # rest of this script supports running without a GPU.
    ckpt = torch.load("./checkpoints/PointModel_v0.pth", map_location="cpu")
    model.load_state_dict(ckpt["model_state"])
    model = model.eval()
    if use_gpu:
        model = model.cuda()

    _evaluate_and_report(data_loader_320x240, model, output_shape=(320, 240), top_k=300)
    _evaluate_and_report(data_loader_640x480, model, output_shape=(640, 480), top_k=1000)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
 
			
