Spaces:
Build error
Build error
| import numpy as np | |
| import os | |
| import matplotlib.image as mpimage | |
| import argparse | |
| import functools | |
| from utils import add_arguments, print_arguments | |
| from dask.distributed import LocalCluster | |
| from dask import bag as dbag | |
| from dask.diagnostics import ProgressBar | |
| from typing import Tuple | |
| from PIL import Image | |
# Dataset statistics gathered during development.
# They are used to filter out plate crops with poor perceptual quality:
# a crop is compared (after dividing pixels by 255) against the expected
# mean brightness and standard deviation (contrast) below.
IMAGE_MEAN, IMAGE_MEAN_STD = 0.5, 0.028  # expected mean intensity and its spread
IMG_STD, IMG_STD_STD = 0.28, 0.01        # expected pixel std (contrast) and its spread
def readImage(fileName: str) -> np.ndarray:
    """Load the image file at ``fileName`` and return its pixel array.

    Thin wrapper around ``matplotlib.image.imread``.
    NOTE(review): matplotlib documents PNG input as float data in [0, 1],
    while formats such as JPEG are typically returned as uint8 — callers
    that divide by 255 assume the latter; confirm for non-JPEG inputs.
    """
    return mpimage.imread(fileName)
def parseLabel(label: str) -> Tuple[np.ndarray, np.ndarray]:
    """Extract the plate corner coordinates encoded in a CCPD-style file name.

    The name is split on '-'; the fourth field lists corner points separated
    by '_', each written as 'x&y'. Only the first four points are used, and
    fewer than four raises IndexError.

    Returns:
        (coor, center): ``coor`` holds one integer row per corner; ``center``
        is the per-axis mean of the corners, truncated to int.
    """
    points = label.split('-')[3].split('_')
    corners = np.array([[int(value) for value in points[k].split('&')] for k in range(4)])
    return corners, corners.mean(axis=0).astype(int)
def cropImage(image: np.ndarray, coor: np.ndarray, center: np.ndarray) -> np.ndarray:
    """Crop a fixed-size window around a licence plate.

    The window is the smallest of 64x32, 128x64, 192x96 and 256x128
    (width x height) whose half-extents strictly exceed the plate's
    half-extents around ``center``.

    Args:
        image: pixel array indexed as image[row (y), column (x), ...].
        coor: array of (x, y) corner points, one row per corner.
        center: integer (x, y) centre of the crop window.

    Returns:
        The cropped window, or an empty array when the plate is too large for
        every window or the window would extend past the image border.
    """
    # Half-extents of the plate around the centre. Absolute offsets are
    # required: the centre is the corner mean, so the plate can reach further
    # to the left/top than to the right/bottom.
    maxW = np.max(np.abs(coor[:, 0] - center[0]))
    maxH = np.max(np.abs(coor[:, 1] - center[1]))
    for w, h in zip((64, 128, 192, 256), (32, 64, 96, 128)):
        if maxW < w // 2 and maxH < h // 2:
            halfW, halfH = w // 2, h // 2
            break
    else:
        # Plate too large for the biggest window: discard.
        return np.array([])
    top, bottom = center[1] - halfH, center[1] + halfH
    left, right = center[0] - halfW, center[0] + halfW
    # image.shape[0] counts rows (y) and image.shape[1] counts columns (x).
    # Slice ends are exclusive, so an end equal to the size is still in bounds.
    if top < 0 or left < 0 or bottom > image.shape[0] or right > image.shape[1]:
        return np.array([])
    return image[top:bottom, left:right]
def saveImage(image: np.ndarray, fileName: str, outDir: str) -> int:
    """Save a cropped plate under ``outDir`` in the sub-directory for its size.

    Crops 64, 128 or 192 pixels wide go to '64_32', '128_64' and '192_96'
    respectively. Any other non-empty crop (e.g. a 256-wide one) is resized
    to 192x96 (width x height) and saved under '192_96'.

    Returns:
        1 if an image was written, 0 if ``image`` was empty.
    """
    if image.shape[0] == 0:
        return 0
    subDirs = {64: '64_32', 128: '128_64', 192: '192_96'}
    width = image.shape[1]
    if width not in subDirs:
        # Downscale oversized crops; PIL's resize takes (width, height).
        image = np.asarray(Image.fromarray(image).resize((192, 96)))
        width = 192
    mpimage.imsave(os.path.join(outDir, subDirs[width], fileName), image)
    return 1
def processImage(file: str, inputDir: str, outputDir: str, subFolder: str) -> int:
    """Crop, quality-filter and save the plate from a single CCPD image.

    Wrapped as one function so each dataset sub-folder can be processed
    with the same call.

    Args:
        file: image file name; the plate corners are parsed from it.
        inputDir: dataset root directory.
        outputDir: root directory for the saved crops.
        subFolder: sub-directory of ``inputDir`` that contains ``file``.

    Returns:
        The result of ``saveImage`` (1 when a crop was written), or 0 when no
        crop could be taken or it failed the brightness/contrast filters.
    """
    coor, center = parseLabel(file)
    image = readImage(os.path.join(inputDir, subFolder, file))
    plate = cropImage(image, coor, center)
    if plate.shape[0] == 0:
        return 0
    # Normalise once; assumes 8-bit pixel values (e.g. JPEG input) — confirm
    # for other formats.
    normalized = plate / 255.0
    mean = np.mean(normalized)
    std = np.std(normalized)
    # Reject poorly lit crops: mean more than 10 spreads from the dataset mean.
    if mean <= IMAGE_MEAN - 10 * IMAGE_MEAN_STD or mean >= IMAGE_MEAN + 10 * IMAGE_MEAN_STD:
        return 0
    # Reject low-contrast crops.
    if std <= IMG_STD - 10 * IMG_STD_STD:
        return 0
    return saveImage(plate, file, outputDir)
def main(argv):
    """Crop plates from every CCPD sub-folder of ``argv.inputDir``.

    Args:
        argv: parsed command-line namespace providing ``jobNum``,
            ``inputDir`` and ``outputDir``.
    """
    jobNum = int(argv.jobNum)
    outputDir = argv.outputDir
    inputDir = argv.inputDir
    # Create each size sub-directory independently (parents included), so an
    # already existing output root no longer skips creating its children.
    for shape in ['64_32', '128_64', '192_96']:
        os.makedirs(os.path.join(outputDir, shape), exist_ok=True)
    # NOTE(review): no dask.distributed Client is attached to this cluster, so
    # Bag.compute() below runs on dask's default local scheduler, not on these
    # workers. Kept as-is because a distributed Client would also stop the
    # local ProgressBar from reporting; the `with` guarantees the cluster is
    # closed even if processing raises.
    with LocalCluster(n_workers=jobNum, threads_per_worker=5):
        for subFolder in ['ccpd_base', 'ccpd_db', 'ccpd_fn', 'ccpd_rotate', 'ccpd_tilt', 'ccpd_weather']:
            fileList = os.listdir(os.path.join(inputDir, subFolder))
            print('* {} images found in {}. Start processing ...'.format(len(fileList), subFolder))
            toDo = dbag.from_sequence(fileList, npartitions=jobNum * 30)
            toDo = toDo.map(processImage, inputDir, outputDir, subFolder)
            # Scope the progress bar to this one computation; registering a new
            # global bar on every iteration would stack callbacks.
            with ProgressBar(minimum=2.0):
                result = toDo.compute()
            print('* image cropped: {}. Done ...'.format(sum(result)))
if __name__ == "__main__":
    # Command-line entry point: declare options, echo them, then run.
    parser = argparse.ArgumentParser(description=__doc__)
    add_option = functools.partial(add_arguments, argparser=parser)
    # (name, type, default, help). Help texts: number of worker threads,
    # input image directory, output image directory.
    for option in (
        ('jobNum', int, 4, '处理图片的线程数'),
        ('inputDir', str, 'datasets/CCPD2019', '输入图片目录'),
        ('outputDir', str, 'datasets/CCPD2019_new', '保存图片目录'),
    ):
        add_option(*option)
    args = parser.parse_args()
    print_arguments(args)
    main(args)