Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	File size: 3,940 Bytes
			
			| 523fb10 b9be4e6 523fb10 b9be4e6 523fb10 b9be4e6 523fb10 b9be4e6 523fb10 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 | import os
import cv2
import numpy as np
from PIL import Image
import glob
import torch
import tqdm
import shutil
import argparse
from third_party.GPEN.face_enhancement import FaceEnhancement
make_abs_path = lambda fn: os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), fn))
class GPENImageInfer(object):
    """Thin wrapper around ``FaceEnhancement`` for images, arrays and tensors.

    The wrapped enhancer works on single ``(H,W,BGR)`` uint8 images; the
    methods below adapt it to stacked ndarrays and to normalized torch
    batches.
    """

    def __init__(self, device):
        """
        :param device: forwarded unchanged to ``FaceEnhancement(device=...)``
        """
        super().__init__()
        model = {
            "name": "GPEN-BFR-512",
            "in_size": 512,
            "out_size": 512,
            "channel_multiplier": 2,
            "narrow": 1,
        }
        # Model files are resolved relative to this source file's directory.
        self.faceenhancer = FaceEnhancement(
            base_dir=make_abs_path('./'),
            use_sr=True,
            in_size=model["in_size"],
            out_size=model["out_size"],
            model=model["name"],
            channel_multiplier=model["channel_multiplier"],
            narrow=model["narrow"],
            device=device,
        )

    def image_infer(self, in_img: np.ndarray):
        """
        Enhance one image and return it at the input resolution.

        :param in_img: np.ndarray, (H,W,BGR), in [0,255]
        :return: out_img: np.ndarray, (H,W,BGR), in [0,255]
        """
        h, w = in_img.shape[:2]
        # process() also returns the original / enhanced face crops; only the
        # full composited image is needed here.
        out_img, _, _ = self.faceenhancer.process(in_img)
        # The enhancer may return a different resolution (presumably when
        # use_sr upsamples -- TODO confirm); bring it back to the input size.
        # Skip the call when the size already matches: resizing to the same
        # size would only produce a copy.
        if out_img.shape[:2] != (h, w):
            out_img = cv2.resize(out_img, (w, h))
        return out_img

    def ndarray_infer(self, in_ndarray: np.ndarray,
                      save_folder: str = 'demo_images/out/',
                      save_name: str = 'reen.png',
                      ):
        """
        Enhance every image of a stacked uint8 array.

        :param in_ndarray: np.ndarray, (N,H,W,BGR), in [0,255]
        :param save_folder: not used
        :param save_name: not used
        :return: out_ndarray: np.ndarray, (N,H,W,BGR), in [0,255]
        """
        out_ndarray = np.zeros_like(in_ndarray, dtype=np.uint8)  # (N,H,W,BGR)
        for b_idx, single_img in enumerate(in_ndarray):
            out_ndarray[b_idx] = self.image_infer(single_img)
        return out_ndarray

    def batch_infer(self, in_batch: torch.Tensor,
                          save_folder: str = 'demo_images/out/',
                          save_name: str = 'reen.png',
                          save_batch_idx: int = 0,
                          ):
        """
        Enhance a normalized RGB tensor batch.

        :param in_batch: (N,RGB,H,W), in [-1,1]
        :param save_folder: folder for the optional preview image; created
            if it does not exist
        :param save_name: file name of the optional preview image
        :param save_batch_idx: index of the batch element that is also written
            to ``save_folder/save_name`` (as BGR); ``None`` disables saving
        :return: out_batch: (N,RGB,H,W), float32 in [-1,1], on the device of
            ``in_batch``
        :raises ValueError: if ``in_batch`` is not 4-dimensional
        """
        if in_batch.dim() != 4:
            raise ValueError(
                'in_batch must be (N,RGB,H,W), got shape %s' % (tuple(in_batch.shape),))
        device = in_batch.device

        # Quantize to uint8. Round and clamp BEFORE the cast: a plain cast
        # truncates and lets values slightly outside [-1,1] wrap around.
        pixels = ((in_batch.detach().float() + 1.) * 127.5).round().clamp(0, 255)
        rgb = pixels.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)  # (N,H,W,RGB)
        # The channel flip is a negative-stride view; hand the enhancer a
        # contiguous buffer instead.
        bgr = np.ascontiguousarray(rgb[:, :, :, ::-1])  # (N,H,W,BGR)

        out_rgb = np.zeros_like(rgb)  # (N,H,W,RGB)
        for b_idx, single_img in enumerate(bgr):
            out_img = self.image_infer(single_img)  # (H,W,BGR), in [0,255]
            out_rgb[b_idx] = out_img[:, :, ::-1]
            if save_batch_idx is not None and b_idx == save_batch_idx:
                # Make sure the target folder exists so the preview is
                # actually written.
                if save_folder:
                    os.makedirs(save_folder, exist_ok=True)
                cv2.imwrite(os.path.join(save_folder, save_name), out_img)

        out_batch = torch.from_numpy(out_rgb).to(device=device, dtype=torch.float32)
        out_batch = out_batch / 127.5 - 1.  # (N,H,W,RGB)
        out_batch = out_batch.permute(0, 3, 1, 2)  # (N,RGB,H,W)
        return out_batch.clamp(-1, 1)
def gpen_output_path(in_path: str) -> str:
    """Return ``in_path`` with ``_gpen`` inserted before the file extension.

    Works for any extension, so the result never equals ``in_path`` and the
    source image is never overwritten.
    """
    root, ext = os.path.splitext(in_path)
    return root + '_gpen' + ext


if __name__ == '__main__':
    # Demo: enhance every image in examples/imgs/ and write the result next
    # to it with a "_gpen" suffix.
    # GPENImageInfer requires a device; use the GPU when one is available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    gpen = GPENImageInfer(device)
    in_folder = 'examples/imgs/'
    for img_name in sorted(os.listdir(in_folder)):
        # Skip results of a previous run.
        if 'gpen' in img_name:
            continue
        in_path = os.path.join(in_folder, img_name)
        im = cv2.imread(in_path, cv2.IMREAD_COLOR)  # BGR
        if im is None:
            # cv2.imread returns None for files it cannot decode.
            continue
        cv2.imwrite(gpen_output_path(in_path), gpen.image_infer(im))
 |