Spaces:
Sleeping
Sleeping
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import argparse | |
| import math | |
| import os.path as osp | |
| import mmcv | |
| import mmengine | |
| from mmocr.utils import dump_ocr_data | |
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: ``root_path`` (str, positional) and
        ``val_ratio`` (float, from ``--val-ratio``, default 0.0).
    """
    cli = argparse.ArgumentParser(
        description='Generate training and validation set of LSVT ')
    cli.add_argument('root_path', help='Root dir path of LSVT')
    cli.add_argument(
        '--val-ratio',
        type=float,
        default=0.0,
        help='Split ratio for val set',
    )
    return cli.parse_args()
def _split_lsvt_prefixes(prefixes, ratio):
    """Deterministically split image prefixes into ``(train, val)`` lists.

    Every ``floor(1 / ratio)``-th prefix (indices 0, k, 2k, ...) goes to the
    val list and the rest go to the training list, so the val fraction is
    ``1 / floor(1 / ratio)`` -- at least ``ratio``. A non-positive ``ratio``
    puts every prefix in the training list. Callers must ensure
    ``ratio <= 1``.
    """
    prefixes = list(prefixes)
    if not ratio > 0:
        return prefixes, []
    step = math.floor(1 / ratio)
    train, val = [], []
    for i, prefix in enumerate(prefixes):
        if i % step:
            train.append(prefix)
        else:
            val.append(prefix)
    return train, val


def _lsvt_instance_to_anno(ann):
    """Convert one LSVT text instance into a COCO-style annotation dict.

    Negative polygon coordinates are clamped to 0. The bbox is the
    axis-aligned ``[x, y, w, h]`` box around the clamped polygon, and
    illegible instances are marked with ``iscrowd=1``.
    """
    segmentation = []
    for x, y in ann['points']:
        segmentation.append(max(0, x))
        segmentation.append(max(0, y))
    xs, ys = segmentation[::2], segmentation[1::2]
    x_min, y_min = min(xs), min(ys)
    w, h = max(xs) - x_min, max(ys) - y_min
    return dict(
        iscrowd=1 if ann['illegibility'] else 0,
        category_id=1,
        bbox=[x_min, y_min, w, h],
        area=w * h,
        segmentation=[segmentation])


def collect_lsvt_info(root_path, split, ratio, print_every=1000):
    """Collect the annotation information.

    The annotation file (``annotations/train_full_labels.json``) maps each
    image prefix to its list of text instances:

        {
            'gt_1234':  # 'gt_1234' is the image file name without '.jpg'
            [
                {
                    'transcription': '一站式购物中心',
                    'points': [[45, 272], [215, 273], [212, 296], [45, 290]],
                    'illegibility': False
                }, ...
            ],
            ...
        }

    Images are read from ``<root_path>/imgs/<prefix>.jpg``. Prefixes whose
    image is missing or cannot be decoded are skipped.

    Args:
        root_path (str): Root path to the dataset.
        split (str): Dataset split, which should be 'train' or 'val'.
        ratio (float): Split ratio for the val set. Every
            ``floor(1 / ratio)``-th image goes to val. A value <= 0 puts
            all images in train.
        print_every (int): Print progress every ``print_every`` images.
            A non-positive value disables progress logging.

    Returns:
        list[dict]: One dict per image with ``file_name``, ``height``,
        ``width``, ``segm_file`` and ``anno_info`` (a list of COCO-style
        instance dicts).

    Raises:
        NotImplementedError: If ``split`` is neither 'train' nor 'val'.
        ValueError: If ``ratio > 1``, which would make the split step 0.
        FileNotFoundError: If the annotation file does not exist.
    """
    # Validate cheap arguments before loading the (large) annotation JSON.
    if split not in ('train', 'val'):
        raise NotImplementedError(
            f"Unsupported split '{split}', expected 'train' or 'val'.")
    if ratio > 1:
        # floor(1 / ratio) would be 0 and the modulo split would divide by 0.
        raise ValueError(f'val ratio must be <= 1, got {ratio}.')

    annotation_path = osp.join(root_path, 'annotations/train_full_labels.json')
    if not osp.exists(annotation_path):
        raise FileNotFoundError(
            f'{annotation_path} not exists, please check and try again.')
    annotation = mmengine.load(annotation_path)

    trn_files, val_files = _split_lsvt_prefixes(annotation.keys(), ratio)
    print(f'training #{len(trn_files)}, val #{len(val_files)}')
    img_prefixes = trn_files if split == 'train' else val_files

    segm_file = osp.basename(annotation_path)
    img_infos = []
    for i, prefix in enumerate(img_prefixes):
        if print_every > 0 and i > 0 and i % print_every == 0:
            print(f'{i}/{len(img_prefixes)}')
        img_file = osp.join(root_path, 'imgs', prefix + '.jpg')
        # Skip images listed in the annotation but absent on disk.
        if not osp.exists(img_file):
            continue
        img = mmcv.imread(img_file)
        if img is None:
            # Undecodable/corrupt file: skip it like a missing image instead
            # of crashing on ``img.shape``.
            print(f'Warning: failed to read {img_file}, skipped.')
            continue
        img_info = dict(
            file_name=osp.basename(img_file),
            height=img.shape[0],
            width=img.shape[1],
            segm_file=segm_file)
        img_info.update(anno_info=[
            _lsvt_instance_to_anno(ann) for ann in annotation[prefix]
        ])
        img_infos.append(img_info)
    return img_infos
def main():
    """Convert LSVT annotations into MMOCR textdet JSON files.

    Always writes ``instances_training.json`` under ``root_path``. It also
    writes ``instances_val.json`` when ``--val-ratio`` is positive.
    """
    args = parse_args()
    root = args.root_path
    val_ratio = args.val_ratio

    # (split name, output file, banner) for each split to generate.
    jobs = [('train', 'instances_training.json', 'Processing training set...')]
    if val_ratio > 0:
        jobs.append(
            ('val', 'instances_val.json', 'Processing validation set...'))

    for split_name, out_name, banner in jobs:
        print(banner)
        infos = collect_lsvt_info(root, split_name, val_ratio)
        dump_ocr_data(infos, osp.join(root, out_name), 'textdet')
    print('Finish')
if __name__ == '__main__':
    # Run the conversion only when executed as a script, not on import.
    main()