Spaces:
Build error
Build error
| """Data manipulation helpers""" | |
| import os.path | |
| import pickle | |
| from cirtorch.datasets.datahelpers import cid2filename | |
| from cirtorch.datasets.testdataset import configdataset | |
def load_dataset(dataset, data_root=''):
    """Return tuple (image list, query list, bounding boxes, gnd dictionary).

    ``dataset`` selects where the data comes from:

    * a ``dict`` -- custom dataset with keys ``'image_root'``,
      ``'database_list'`` and ``'query_list'``. Each list entry that is not
      None is a text file with one image name per line. Each name is resolved
      against ``data_root/image_root`` with :func:`path_join`. Bounding boxes
      and gnd are None.
    * ``'train'`` -- retrieval-SfM-120k training images. The query list is
      empty, and bounding boxes and gnd are None.
    * ``'val_eccv20'`` -- retrieval-SfM-120k ECCV 2020 validation split.
    * any other string containing ``'/'`` -- path to a pickle holding the keys
      ``'imlist'``, ``'qimlist'`` and ``'gnd'`` (``data_root`` is ignored).
    * any other string -- test dataset name passed to ``configdataset``.

    :param dataset: dataset specification (see above)
    :param data_root: directory prepended to the dataset-specific paths
    :return: ``(images, qimages, bbxs, gnd)``
    """
    if isinstance(dataset, dict):
        return _load_custom(dataset, data_root)
    if dataset == 'train':
        return _load_sfm120k_train(data_root)
    if dataset == 'val_eccv20':
        return _load_sfm120k_val_eccv20(data_root)
    if "/" in dataset:
        return _load_pickled(dataset)
    return _load_test(dataset, data_root)


def _read_image_list(list_path, root):
    """Read one image name per line from *list_path* and resolve each against *root*."""
    # Context manager closes the list file; the old one-liner leaked the handle.
    with open(list_path) as handle:
        return [path_join(root, line.strip("\n")) for line in handle]


def _load_custom(dataset, data_root):
    """Load a dataset described by a dict of image root and optional list files."""
    root = os.path.join(data_root, dataset['image_root'])
    images, qimages = None, None
    if dataset['database_list'] is not None:
        images = _read_image_list(dataset['database_list'], root)
    if dataset['query_list'] is not None:
        qimages = _read_image_list(dataset['query_list'], root)
    return images, qimages, None, None


def _load_sfm120k_train(data_root):
    """Load the retrieval-SfM-120k training images; no queries, boxes or gnd."""
    training_set = 'retrieval-SfM-120k'
    db_root = os.path.join(data_root, 'train', training_set)
    ims_root = os.path.join(db_root, 'ims')
    db_fn = os.path.join(db_root, '{}.pkl'.format(training_set))
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(db_fn, 'rb') as f:
        db = pickle.load(f)['train']
    images = [cid2filename(cid, ims_root) for cid in db['cids']]
    return images, [], None, None


def _load_sfm120k_val_eccv20(data_root):
    """Load the retrieval-SfM-120k ECCV 2020 validation split."""
    db_root = os.path.join(data_root, 'train', 'retrieval-SfM-120k')
    # Per the original note: positives all have >= 3 and <= 10 inliers.
    fn_val_proper = os.path.join(db_root, 'retrieval-SfM-120k-val-eccv2020.pkl')
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(fn_val_proper, 'rb') as f:
        db = pickle.load(f)
    ims_root = os.path.join(db_root, 'ims')
    images = [cid2filename(cid, ims_root) for cid in db['cids']]
    # Queries are a subset of the database images, selected by index.
    qimages = [images[x] for x in db['qidx']]
    return images, qimages, None, db['gnd']


def _load_pickled(db_path):
    """Load image list, query list and gnd from a pickle file at *db_path*."""
    # NOTE(review): unpickling a caller-supplied path can execute arbitrary
    # code; only pass files from a trusted source.
    with open(db_path, 'rb') as handle:
        db = pickle.load(handle)
    return db['imlist'], db['qimlist'], None, db['gnd']


def _load_test(dataset, data_root):
    """Load a named test dataset through ``configdataset``."""
    cfg = configdataset(dataset, os.path.join(data_root, 'test'))
    images = [cfg['im_fname'](cfg, i) for i in range(cfg['n'])]
    qimages = [cfg['qim_fname'](cfg, i) for i in range(cfg['nq'])]
    gnd = cfg['gnd']
    bbxs = None
    # Guard against an empty gnd list before probing its first entry for boxes.
    if len(gnd) > 0 and 'bbx' in gnd[0]:
        bbxs = [tuple(gnd[i]['bbx']) for i in range(cfg['nq'])]
    return images, qimages, bbxs, gnd
def path_join(root, name):
    """Join *name* onto *root*, or fill it into a filename template.

    If the last ``/``-separated component of ``root`` contains an asterisk,
    ``root`` is treated as a template. Every ``*`` in it is replaced by
    ``name``. Otherwise the result is plain :func:`os.path.join`.

    >>> path_join('/data/img_*.jpg', '001')
    '/data/img_001.jpg'
    """
    _, _, last_component = root.rpartition("/")
    if "*" not in last_component:
        return os.path.join(root, name)
    return root.replace("*", name)
class AverageMeter:
    """Compute and store the average and last value.

    Tracks the most recent value (``val``), the weighted running sum
    (``sum``), the total weight (``count``) and their ratio (``avg``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all statistics to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Update the counter by a new value.

        :param val: the new value; stored as ``val``
        :param n: weight of ``val`` in the running sum and count
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Skip the division while the total weight is zero (e.g. a first update
        # with n=0), which previously raised ZeroDivisionError.
        if self.count:
            self.avg = self.sum / self.count