Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	File size: 2,826 Bytes
			
			| d380b77 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 | #!/usr/bin/env python3
import os
import random
import cv2
import numpy as np
from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
from saicinpainting.evaluation.utils import load_yaml
from saicinpainting.training.visualizers.base import visualize_mask_and_images
def _dedup_fname(fname, seen_counts):
    """Return an output file name for *fname*, disambiguating repeats.

    The first time a name is seen it is returned unchanged. On every later
    occurrence the running counter is inserted between the stem and the
    extension, e.g. ``mask.png`` -> ``mask_2.png`` -> ``mask_3.png``.

    Args:
        fname: Base file name (no directory part).
        seen_counts: Dict mapping file name -> number of times it has been
            seen so far. Updated in place.

    Returns:
        The file name to use for the current occurrence.
    """
    if fname not in seen_counts:
        seen_counts[fname] = 1
        return fname
    seen_counts[fname] += 1
    # splitext (not split): we need stem/extension, not directory/basename.
    stem, ext = os.path.splitext(fname)
    return f'{stem}_{seen_counts[fname]}{ext}'


def main(args):
    """Render side-by-side comparison images for several prediction folders.

    One dataset is built per entry of ``args.predictdirs``; all of them must
    have the same length. For up to ``args.max_n`` randomly chosen indices a
    visualization is produced via ``visualize_mask_and_images`` and written
    to ``args.outpath`` under the (de-duplicated) base name of the mask file.
    A failure on a single image is reported on stdout and does not stop the
    remaining images from being processed.

    Args:
        args: Parsed command-line namespace. Attributes read here: ``config``,
            ``datadir``, ``predictdirs``, ``max_n``, ``outpath``, ``black``.
    """
    config = load_yaml(args.config)
    datasets = [PrecomputedInpaintingResultsDataset(args.datadir, cur_predictdir, **config.dataset_kwargs)
                for cur_predictdir in args.predictdirs]
    assert len({len(ds) for ds in datasets}) == 1, \
        'All prediction folders must yield datasets of the same length'

    first_ds = datasets[0]
    len_first = len(first_ds)
    indices = list(range(len_first))
    if len_first > args.max_n:
        # Sub-sample, but keep the selected indices in ascending order.
        indices = sorted(random.sample(indices, args.max_n))

    os.makedirs(args.outpath, exist_ok=True)

    filename2i = {}
    # Entries of the visualization dict to show: the 'image' entry followed
    # by one entry per prediction folder (keyed by its integer position).
    keys = ['image'] + list(range(len(datasets)))
    for img_i in indices:
        try:
            mask_fname = _dedup_fname(os.path.basename(first_ds.mask_filenames[img_i]),
                                      filename2i)

            cur_vis_dict = first_ds[img_i]
            for ds_i, ds in enumerate(datasets):
                cur_vis_dict[ds_i] = ds[img_i]['inpainted']

            vis_img = visualize_mask_and_images(cur_vis_dict, keys,
                                                last_without_mask=False,
                                                mask_only_first=True,
                                                black_mask=args.black)
            # NOTE(review): presumably vis_img is a float RGB image in [0, 1] -- confirm.
            vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8')

            out_fname = os.path.join(args.outpath, mask_fname)
            # cv2.imwrite expects BGR channel order.
            vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(out_fname, vis_img)
        except Exception as ex:
            # Deliberate best-effort: report the failure and move on.
            print(f'Could not process {img_i} due to {ex}')
if __name__ == '__main__':
    import argparse

    # Command-line interface: two optional switches followed by three fixed
    # positionals and a variable-length list of prediction folders.
    cli_parser = argparse.ArgumentParser()

    cli_parser.add_argument('--max-n', default=100, type=int,
                            help='Maximum number of images to print')
    cli_parser.add_argument('--black', action='store_true',
                            help='Whether to fill mask on GT with black')

    # Single-valued positionals, registered in the order they must be given.
    for positional, description in (
        ('config', 'Path to evaluation config (e.g. configs/eval1.yaml)'),
        ('outpath', 'Where to put results'),
        ('datadir', 'Path to folder with images and masks'),
    ):
        cli_parser.add_argument(positional, type=str, help=description)

    # One or more folders with predictions to compare side by side.
    cli_parser.add_argument('predictdirs', nargs='+', type=str,
                            help='Path to folders with predicts')

    cli_args = cli_parser.parse_args()
    main(cli_args)
 |