Spaces:
Runtime error
Runtime error
| import argparse | |
| import cv2 | |
| import numpy as np | |
| import os | |
| import sys | |
| from basicsr.utils import scandir | |
| from multiprocessing import Pool | |
| from os import path as osp | |
| from tqdm import tqdm | |
def main(args):
    """Translate parsed command-line arguments into an ``opt`` dict and run the extraction.

    The tool crops large images into sub-images for faster IO.

    Args:
        args: Parsed arguments object. The following attributes are read:
            n_thread (int): Size of the worker pool, stored as ``opt['n_thread']``.
            compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a
                smaller size and longer compression time. Use 0 for faster CPU decompression.
            input (str): Path to the input folder, stored as ``opt['input_folder']``.
            output (str): Path to the save folder, stored as ``opt['save_folder']``.
            crop_size (int): Crop size.
            step (int): Step for overlapped sliding window.
            thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be
                dropped.

    Usage:
        For each folder, run this script.
        Typically, there are GT folder and LQ folder to be processed for DIV2K dataset.
        After process, each sub_folder should have the same number of subimages.
        Remember to modify the options according to your settings.
    """
    # Note that two command-line names differ from the dict keys: input/output
    # become input_folder/save_folder.
    opt = {
        'n_thread': args.n_thread,
        'compression_level': args.compression_level,
        'input_folder': args.input,
        'save_folder': args.output,
        'crop_size': args.crop_size,
        'step': args.step,
        'thresh_size': args.thresh_size,
    }
    extract_subimages(opt)
def extract_subimages(opt):
    """Crop every image found in the input folder into sub-images.

    One ``worker`` task is submitted to a process pool per scanned path. Tasks
    that raise are recorded and reported once the pool has finished, instead
    of being silently discarded.

    Args:
        opt (dict): Configuration dict. It contains:
            input_folder (str): Path to the input folder.
            save_folder (str): Path to save folder. It must not exist yet.
            n_thread (int): Number of worker processes in the pool.
            The whole dict is also forwarded unchanged to ``worker``.

    Raises:
        SystemExit: With status 1 if ``save_folder`` already exists.
    """
    input_folder = opt['input_folder']
    save_folder = opt['save_folder']
    if not osp.exists(save_folder):
        os.makedirs(save_folder)
        print(f'mkdir {save_folder} ...')
    else:
        # Refuse to write into an existing folder.
        print(f'Folder {save_folder} already exists. Exit.')
        sys.exit(1)

    # scan all images
    img_list = list(scandir(input_folder, full_path=True))

    pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
    failures = []

    def _on_error(err, failed_path):
        # Without an error_callback, apply_async drops worker exceptions
        # silently; keep them so they can be reported, and still advance the
        # bar so it reaches its total.
        failures.append((failed_path, err))
        pbar.update(1)

    pool = Pool(opt['n_thread'])
    for path in img_list:
        pool.apply_async(
            worker,
            args=(path, opt),
            callback=lambda arg: pbar.update(1),
            # Bind ``path`` as a default so each callback reports its own file.
            error_callback=lambda err, failed_path=path: _on_error(err, failed_path))
    pool.close()
    pool.join()
    pbar.close()

    for failed_path, err in failures:
        print(f'Failed to process {failed_path}: {err!r}')
    print('All processes done.')
def _crop_starts(length, crop_size, step, thresh_size):
    """Return the start offsets of the sliding crop windows along one axis.

    Offsets advance by ``step`` while a full ``crop_size`` window still fits.
    If the part left uncovered after the last window is larger than
    ``thresh_size``, one extra window flush with the end of the axis is added.
    An empty array is returned when the axis is shorter than ``crop_size``.
    """
    starts = np.arange(0, length - crop_size + 1, step)
    if starts.size == 0:
        # No full-size window fits; indexing starts[-1] below would raise.
        return starts
    if length - (starts[-1] + crop_size) > thresh_size:
        starts = np.append(starts, length - crop_size)
    return starts


def worker(path, opt):
    """Worker for each process: crop one image into sub-images and save them.

    Sub-images are written to ``opt['save_folder']`` as
    ``{img_name}_s{index:03d}{extension}``, with ``index`` starting at 1 and
    increasing row by row. An image smaller than ``crop_size`` in either
    dimension produces no sub-images.

    Args:
        path (str): Image path.
        opt (dict): Configuration dict. It contains:
            crop_size (int): Crop size.
            step (int): Step for overlapped sliding window.
            thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
            save_folder (str): Path to save folder.
            compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.

    Returns:
        process_info (str): Process information displayed in progress bar.

    Raises:
        OSError: If the file at ``path`` cannot be read as an image.
    """
    crop_size = opt['crop_size']
    step = opt['step']
    thresh_size = opt['thresh_size']
    img_name, extension = osp.splitext(osp.basename(path))

    # remove the x2, x3, x4 and x8 in the filename for DIV2K
    # NOTE: this removes those substrings anywhere in the name, not only a
    # trailing scale suffix.
    img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals an unreadable file by returning None rather than
        # raising; fail with a clear message instead of an AttributeError.
        raise OSError(f'Failed to read image: {path}')

    h, w = img.shape[0:2]
    h_space = _crop_starts(h, crop_size, step, thresh_size)
    w_space = _crop_starts(w, crop_size, step, thresh_size)

    index = 0
    for top in h_space:
        for left in w_space:
            index += 1
            cropped_img = img[top:top + crop_size, left:left + crop_size, ...]
            cropped_img = np.ascontiguousarray(cropped_img)
            cv2.imwrite(
                osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img,
                [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
    process_info = f'Processing {img_name} ...'
    return process_info
if __name__ == '__main__':
    # Command-line options as (flag, type, default, help) entries, registered
    # in this order.
    cli_options = (
        ('--input', str, 'datasets/DF2K/DF2K_HR', 'Input folder'),
        ('--output', str, 'datasets/DF2K/DF2K_HR_sub', 'Output folder'),
        ('--crop_size', int, 480, 'Crop size'),
        ('--step', int, 240, 'Step for overlapped sliding window'),
        ('--thresh_size', int, 0,
         'Threshold size. Patches whose size is lower than thresh_size will be dropped.'),
        ('--n_thread', int, 20, 'Thread number.'),
        ('--compression_level', int, 3, 'Compression level'),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    args = parser.parse_args()
    main(args)