Harisreedhar
		
	committed on
		
		
					Commit 
							
							·
						
						7f475d2
	
1
								Parent(s):
							
							27c3130
								
Add soft erosion and fix face parsing video
Browse files
- app.py +40 -24
- face_parsing/__init__.py +1 -1
- face_parsing/swap.py +60 -17
- swapper.py +3 -3
    	
        app.py
    CHANGED
    
    | @@ -17,7 +17,7 @@ from moviepy.editor import VideoFileClip, ImageSequenceClip | |
| 17 |  | 
| 18 | 
             
            from face_analyser import detect_conditions, analyse_face
         | 
| 19 | 
             
            from utils import trim_video, StreamerThread, ProcessBar, open_directory
         | 
| 20 | 
            -
            from face_parsing import init_parser, swap_regions, mask_regions, mask_regions_to_list
         | 
| 21 | 
             
            from swapper import (
         | 
| 22 | 
             
                swap_face,
         | 
| 23 | 
             
                swap_face_with_condition,
         | 
| @@ -59,8 +59,9 @@ MASK_INCLUDE = [ | |
| 59 | 
             
                "L-Lip",
         | 
| 60 | 
             
                "U-Lip"
         | 
| 61 | 
             
            ]
         | 
| 62 | 
            -
             | 
| 63 | 
            -
             | 
|  | |
| 64 |  | 
| 65 | 
             
            FACE_SWAPPER = None
         | 
| 66 | 
             
            FACE_ANALYSER = None
         | 
| @@ -84,6 +85,8 @@ else: | |
| 84 | 
             
                USE_CUDA = False
         | 
| 85 | 
             
                print("\n********** Running on CPU **********\n")
         | 
| 86 |  | 
|  | |
|  | |
| 87 |  | 
| 88 | 
             
            ## ------------------------------ LOAD MODELS ------------------------------
         | 
| 89 |  | 
| @@ -114,7 +117,7 @@ def load_face_parser_model(name="./assets/pretrained_models/79999_iter.pth"): | |
| 114 | 
             
                global FACE_PARSER
         | 
| 115 | 
             
                path = os.path.join(os.path.abspath(os.path.dirname(__file__)), name)
         | 
| 116 | 
             
                if FACE_PARSER is None:
         | 
| 117 | 
            -
                    FACE_PARSER = init_parser(name,  | 
| 118 |  | 
| 119 |  | 
| 120 | 
             
            load_face_analyser_model()
         | 
| @@ -137,9 +140,10 @@ def process( | |
| 137 | 
             
                distance,
         | 
| 138 | 
             
                face_enhance,
         | 
| 139 | 
             
                enable_face_parser,
         | 
| 140 | 
            -
                 | 
| 141 | 
            -
                 | 
| 142 | 
            -
                 | 
|  | |
| 143 | 
             
                *specifics,
         | 
| 144 | 
             
            ):
         | 
| 145 | 
             
                global WORKSPACE
         | 
| @@ -196,14 +200,18 @@ def process( | |
| 196 |  | 
| 197 | 
             
                yield "### \n ⌛ Analysing Face...", *ui_before()
         | 
| 198 |  | 
| 199 | 
            -
                 | 
| 200 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
| 201 | 
             
                models = {
         | 
| 202 | 
             
                    "swap": FACE_SWAPPER,
         | 
| 203 | 
             
                    "enhance": FACE_ENHANCER,
         | 
| 204 | 
             
                    "enhance_sett": face_enhance,
         | 
| 205 | 
             
                    "face_parser": FACE_PARSER,
         | 
| 206 | 
            -
                    "face_parser_sett": (enable_face_parser,  | 
| 207 | 
             
                }
         | 
| 208 |  | 
| 209 | 
             
                ## ------------------------------ ANALYSE SOURCE & SPECIFIC ------------------------------
         | 
| @@ -301,9 +309,9 @@ def process( | |
| 301 |  | 
| 302 | 
             
                        if condition == "Specific Face":
         | 
| 303 | 
             
                            swapped = swap_specific(
         | 
| 304 | 
            -
                                frame,
         | 
| 305 | 
            -
                                analysed_target,
         | 
| 306 | 
             
                                analysed_source_specific,
         | 
|  | |
|  | |
| 307 | 
             
                                models,
         | 
| 308 | 
             
                                threshold=distance,
         | 
| 309 | 
             
                            )
         | 
| @@ -381,9 +389,9 @@ def process( | |
| 381 |  | 
| 382 | 
             
                        if condition == "Specific Face":
         | 
| 383 | 
             
                            swapped = swap_specific(
         | 
| 384 | 
            -
                                target,
         | 
| 385 | 
            -
                                analysed_target,
         | 
| 386 | 
             
                                analysed_source_specific,
         | 
|  | |
|  | |
| 387 | 
             
                                models,
         | 
| 388 | 
             
                                threshold=distance,
         | 
| 389 | 
             
                            )
         | 
| @@ -636,16 +644,23 @@ with gr.Blocks(css=css) as interface: | |
| 636 | 
             
                                        label="Include",
         | 
| 637 | 
             
                                        interactive=True,
         | 
| 638 | 
             
                                    )
         | 
| 639 | 
            -
                                     | 
| 640 | 
            -
                                         | 
| 641 | 
            -
                                        value= | 
| 642 | 
            -
                                         | 
| 643 | 
            -
                                        label="Exclude",
         | 
| 644 | 
             
                                        interactive=True,
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 645 | 
             
                                    )
         | 
| 646 | 
            -
                                     | 
| 647 | 
            -
                                        label="Blur | 
| 648 | 
            -
                                        value= | 
| 649 | 
             
                                        minimum=0,
         | 
| 650 | 
             
                                        interactive=True,
         | 
| 651 | 
             
                                    )
         | 
| @@ -827,8 +842,9 @@ with gr.Blocks(css=css) as interface: | |
| 827 | 
             
                    enable_face_enhance,
         | 
| 828 | 
             
                    enable_face_parser_mask,
         | 
| 829 | 
             
                    mask_include,
         | 
| 830 | 
            -
                     | 
| 831 | 
            -
                     | 
|  | |
| 832 | 
             
                    *src_specific_inputs,
         | 
| 833 | 
             
                ]
         | 
| 834 |  | 
|  | |
| 17 |  | 
| 18 | 
             
            from face_analyser import detect_conditions, analyse_face
         | 
| 19 | 
             
            from utils import trim_video, StreamerThread, ProcessBar, open_directory
         | 
| 20 | 
            +
            from face_parsing import init_parser, swap_regions, mask_regions, mask_regions_to_list, SoftErosion
         | 
| 21 | 
             
            from swapper import (
         | 
| 22 | 
             
                swap_face,
         | 
| 23 | 
             
                swap_face_with_condition,
         | 
|  | |
| 59 | 
             
                "L-Lip",
         | 
| 60 | 
             
                "U-Lip"
         | 
| 61 | 
             
            ]
         | 
| 62 | 
            +
            MASK_SOFT_KERNEL = 17
         | 
| 63 | 
            +
            MASK_SOFT_ITERATIONS = 7
         | 
| 64 | 
            +
            MASK_BLUR_AMOUNT = 20
         | 
| 65 |  | 
| 66 | 
             
            FACE_SWAPPER = None
         | 
| 67 | 
             
            FACE_ANALYSER = None
         | 
|  | |
| 85 | 
             
                USE_CUDA = False
         | 
| 86 | 
             
                print("\n********** Running on CPU **********\n")
         | 
| 87 |  | 
| 88 | 
            +
            device = "cuda" if USE_CUDA else "cpu"
         | 
| 89 | 
            +
             | 
| 90 |  | 
| 91 | 
             
            ## ------------------------------ LOAD MODELS ------------------------------
         | 
| 92 |  | 
|  | |
| 117 | 
             
                global FACE_PARSER
         | 
| 118 | 
             
                path = os.path.join(os.path.abspath(os.path.dirname(__file__)), name)
         | 
| 119 | 
             
                if FACE_PARSER is None:
         | 
| 120 | 
            +
                    FACE_PARSER = init_parser(name, mode=device)
         | 
| 121 |  | 
| 122 |  | 
| 123 | 
             
            load_face_analyser_model()
         | 
|  | |
| 140 | 
             
                distance,
         | 
| 141 | 
             
                face_enhance,
         | 
| 142 | 
             
                enable_face_parser,
         | 
| 143 | 
            +
                mask_includes,
         | 
| 144 | 
            +
                mask_soft_kernel,
         | 
| 145 | 
            +
                mask_soft_iterations,
         | 
| 146 | 
            +
                blur_amount,
         | 
| 147 | 
             
                *specifics,
         | 
| 148 | 
             
            ):
         | 
| 149 | 
             
                global WORKSPACE
         | 
|  | |
| 200 |  | 
| 201 | 
             
                yield "### \n ⌛ Analysing Face...", *ui_before()
         | 
| 202 |  | 
| 203 | 
            +
                includes = mask_regions_to_list(mask_includes)
         | 
| 204 | 
            +
                if mask_soft_iterations > 0:
         | 
| 205 | 
            +
                    smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=int(mask_soft_iterations)).to(device)
         | 
| 206 | 
            +
                else:
         | 
| 207 | 
            +
                    smooth_mask = None
         | 
| 208 | 
            +
             | 
| 209 | 
             
                models = {
         | 
| 210 | 
             
                    "swap": FACE_SWAPPER,
         | 
| 211 | 
             
                    "enhance": FACE_ENHANCER,
         | 
| 212 | 
             
                    "enhance_sett": face_enhance,
         | 
| 213 | 
             
                    "face_parser": FACE_PARSER,
         | 
| 214 | 
            +
                    "face_parser_sett": (enable_face_parser, includes, smooth_mask, int(blur_amount))
         | 
| 215 | 
             
                }
         | 
| 216 |  | 
| 217 | 
             
                ## ------------------------------ ANALYSE SOURCE & SPECIFIC ------------------------------
         | 
|  | |
| 309 |  | 
| 310 | 
             
                        if condition == "Specific Face":
         | 
| 311 | 
             
                            swapped = swap_specific(
         | 
|  | |
|  | |
| 312 | 
             
                                analysed_source_specific,
         | 
| 313 | 
            +
                                analysed_target,
         | 
| 314 | 
            +
                                frame,
         | 
| 315 | 
             
                                models,
         | 
| 316 | 
             
                                threshold=distance,
         | 
| 317 | 
             
                            )
         | 
|  | |
| 389 |  | 
| 390 | 
             
                        if condition == "Specific Face":
         | 
| 391 | 
             
                            swapped = swap_specific(
         | 
|  | |
|  | |
| 392 | 
             
                                analysed_source_specific,
         | 
| 393 | 
            +
                                analysed_target,
         | 
| 394 | 
            +
                                target,
         | 
| 395 | 
             
                                models,
         | 
| 396 | 
             
                                threshold=distance,
         | 
| 397 | 
             
                            )
         | 
|  | |
| 644 | 
             
                                        label="Include",
         | 
| 645 | 
             
                                        interactive=True,
         | 
| 646 | 
             
                                    )
         | 
| 647 | 
            +
                                    mask_soft_kernel = gr.Number(
         | 
| 648 | 
            +
                                        label="Soft Erode Kernel",
         | 
| 649 | 
            +
                                        value=MASK_SOFT_KERNEL,
         | 
| 650 | 
            +
                                        minimum=3,
         | 
|  | |
| 651 | 
             
                                        interactive=True,
         | 
| 652 | 
            +
                                        visible = False
         | 
| 653 | 
            +
                                    )
         | 
| 654 | 
            +
                                    mask_soft_iterations = gr.Number(
         | 
| 655 | 
            +
                                        label="Soft Erode Iterations",
         | 
| 656 | 
            +
                                        value=MASK_SOFT_ITERATIONS,
         | 
| 657 | 
            +
                                        minimum=0,
         | 
| 658 | 
            +
                                        interactive=True,
         | 
| 659 | 
            +
             | 
| 660 | 
             
                                    )
         | 
| 661 | 
            +
                                    blur_amount = gr.Number(
         | 
| 662 | 
            +
                                        label="Mask Blur",
         | 
| 663 | 
            +
                                        value=MASK_BLUR_AMOUNT,
         | 
| 664 | 
             
                                        minimum=0,
         | 
| 665 | 
             
                                        interactive=True,
         | 
| 666 | 
             
                                    )
         | 
|  | |
| 842 | 
             
                    enable_face_enhance,
         | 
| 843 | 
             
                    enable_face_parser_mask,
         | 
| 844 | 
             
                    mask_include,
         | 
| 845 | 
            +
                    mask_soft_kernel,
         | 
| 846 | 
            +
                    mask_soft_iterations,
         | 
| 847 | 
            +
                    blur_amount,
         | 
| 848 | 
             
                    *src_specific_inputs,
         | 
| 849 | 
             
                ]
         | 
| 850 |  | 
    	
        face_parsing/__init__.py
    CHANGED
    
    | @@ -1 +1 @@ | |
| 1 | 
            -
            from .swap import init_parser, swap_regions, mask_regions, mask_regions_to_list
         | 
|  | |
| 1 | 
            +
            from .swap import init_parser, swap_regions, mask_regions, mask_regions_to_list, SoftErosion
         | 
    	
        face_parsing/swap.py
    CHANGED
    
    | @@ -1,4 +1,6 @@ | |
| 1 | 
             
            import torch
         | 
|  | |
|  | |
| 2 | 
             
            import torchvision.transforms as transforms
         | 
| 3 | 
             
            import cv2
         | 
| 4 | 
             
            import numpy as np
         | 
| @@ -27,15 +29,44 @@ mask_regions = { | |
| 27 | 
             
                "Hat":18
         | 
| 28 | 
             
            }
         | 
| 29 |  | 
| 30 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 31 |  | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
| 35 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 36 | 
             
                n_classes = 19
         | 
| 37 | 
             
                net = BiSeNet(n_classes=n_classes)
         | 
| 38 | 
            -
                if  | 
| 39 | 
             
                    net.cuda()
         | 
| 40 | 
             
                    net.load_state_dict(torch.load(pth_path))
         | 
| 41 | 
             
                else:
         | 
| @@ -55,8 +86,7 @@ def image_to_parsing(img, net): | |
| 55 | 
             
                img = torch.unsqueeze(img, 0)
         | 
| 56 |  | 
| 57 | 
             
                with torch.no_grad():
         | 
| 58 | 
            -
                     | 
| 59 | 
            -
                        img = img.cuda()
         | 
| 60 | 
             
                    out = net(img)[0]
         | 
| 61 | 
             
                    parsing = out.squeeze(0).cpu().numpy().argmax(0)
         | 
| 62 | 
             
                    return parsing
         | 
| @@ -68,20 +98,33 @@ def get_mask(parsing, classes): | |
| 68 | 
             
                    res += parsing == val
         | 
| 69 | 
             
                return res
         | 
| 70 |  | 
| 71 | 
            -
            def swap_regions(source, target, net, includes=[1,2,3,4,5,10,11,12,13],  | 
| 72 | 
             
                parsing = image_to_parsing(source, net)
         | 
|  | |
| 73 | 
             
                if len(includes) == 0:
         | 
| 74 | 
             
                    return source, np.zeros_like(source)
         | 
|  | |
| 75 | 
             
                include_mask = get_mask(parsing, includes)
         | 
| 76 | 
            -
                 | 
| 77 | 
            -
             | 
| 78 | 
            -
             | 
| 79 | 
            -
                     | 
| 80 | 
            -
                     | 
| 81 | 
            -
             | 
| 82 | 
            -
             | 
| 83 | 
            -
             | 
| 84 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 85 |  | 
| 86 | 
             
            def mask_regions_to_list(values):
         | 
| 87 | 
             
                out_ids = []
         | 
|  | |
| 1 | 
             
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
             
            import torchvision.transforms as transforms
         | 
| 5 | 
             
            import cv2
         | 
| 6 | 
             
            import numpy as np
         | 
|  | |
| 29 | 
             
                "Hat":18
         | 
| 30 | 
             
            }
         | 
| 31 |  | 
| 32 | 
            +
            # Borrowed from simswap
         | 
| 33 | 
            +
            # https://github.com/neuralchen/SimSwap/blob/26c84d2901bd56eda4d5e3c5ca6da16e65dc82a6/util/reverse2original.py#L30
         | 
| 34 | 
            +
            class SoftErosion(nn.Module):
         | 
| 35 | 
            +
                def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
         | 
| 36 | 
            +
                    super(SoftErosion, self).__init__()
         | 
| 37 | 
            +
                    r = kernel_size // 2
         | 
| 38 | 
            +
                    self.padding = r
         | 
| 39 | 
            +
                    self.iterations = iterations
         | 
| 40 | 
            +
                    self.threshold = threshold
         | 
| 41 |  | 
| 42 | 
            +
                    # Create kernel
         | 
| 43 | 
            +
                    y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size))
         | 
| 44 | 
            +
                    dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
         | 
| 45 | 
            +
                    kernel = dist.max() - dist
         | 
| 46 | 
            +
                    kernel /= kernel.sum()
         | 
| 47 | 
            +
                    kernel = kernel.view(1, 1, *kernel.shape)
         | 
| 48 | 
            +
                    self.register_buffer('weight', kernel)
         | 
| 49 |  | 
| 50 | 
            +
                def forward(self, x):
         | 
| 51 | 
            +
                    x = x.float()
         | 
| 52 | 
            +
                    for i in range(self.iterations - 1):
         | 
| 53 | 
            +
                        x = torch.min(x, F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding))
         | 
| 54 | 
            +
                    x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                    mask = x >= self.threshold
         | 
| 57 | 
            +
                    x[mask] = 1.0
         | 
| 58 | 
            +
                    x[~mask] /= x[~mask].max()
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    return x, mask
         | 
| 61 | 
            +
             | 
| 62 | 
            +
            device = "cpu"
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            def init_parser(pth_path, mode="cpu"):
         | 
| 65 | 
            +
                global device
         | 
| 66 | 
            +
                device = mode
         | 
| 67 | 
             
                n_classes = 19
         | 
| 68 | 
             
                net = BiSeNet(n_classes=n_classes)
         | 
| 69 | 
            +
                if device == "cuda":
         | 
| 70 | 
             
                    net.cuda()
         | 
| 71 | 
             
                    net.load_state_dict(torch.load(pth_path))
         | 
| 72 | 
             
                else:
         | 
|  | |
| 86 | 
             
                img = torch.unsqueeze(img, 0)
         | 
| 87 |  | 
| 88 | 
             
                with torch.no_grad():
         | 
| 89 | 
            +
                    img = img.to(device)
         | 
|  | |
| 90 | 
             
                    out = net(img)[0]
         | 
| 91 | 
             
                    parsing = out.squeeze(0).cpu().numpy().argmax(0)
         | 
| 92 | 
             
                    return parsing
         | 
|  | |
| 98 | 
             
                    res += parsing == val
         | 
| 99 | 
             
                return res
         | 
| 100 |  | 
| 101 | 
            +
            def swap_regions(source, target, net, smooth_mask, includes=[1,2,3,4,5,10,11,12,13], blur=10):
         | 
| 102 | 
             
                parsing = image_to_parsing(source, net)
         | 
| 103 | 
            +
             | 
| 104 | 
             
                if len(includes) == 0:
         | 
| 105 | 
             
                    return source, np.zeros_like(source)
         | 
| 106 | 
            +
             | 
| 107 | 
             
                include_mask = get_mask(parsing, includes)
         | 
| 108 | 
            +
                mask = np.repeat(include_mask[:, :, np.newaxis], 3, axis=2).astype("float32")
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                if smooth_mask is not None:
         | 
| 111 | 
            +
                    mask_tensor = torch.from_numpy(mask.copy().transpose((2, 0, 1))).float().to(device)
         | 
| 112 | 
            +
                    face_mask_tensor = mask_tensor[0] + mask_tensor[1]
         | 
| 113 | 
            +
                    soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0))
         | 
| 114 | 
            +
                    soft_face_mask_tensor.squeeze_()
         | 
| 115 | 
            +
                    mask = np.repeat(soft_face_mask_tensor.cpu().numpy()[:, :, np.newaxis], 3, axis=2)
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                if blur > 0:
         | 
| 118 | 
            +
                    mask = cv2.GaussianBlur(mask, (0, 0), blur)
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                resized_source = cv2.resize((source/255).astype("float32"), (512, 512))
         | 
| 121 | 
            +
                resized_target = cv2.resize((target/255).astype("float32"), (512, 512))
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                result = mask * resized_source + (1 - mask) * resized_target
         | 
| 124 | 
            +
                normalized_result = (result - np.min(result)) / (np.max(result) - np.min(result))
         | 
| 125 | 
            +
                result = cv2.resize((result*255).astype("uint8"), (source.shape[1], source.shape[0]))
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                return result
         | 
| 128 |  | 
| 129 | 
             
            def mask_regions_to_list(values):
         | 
| 130 | 
             
                out_ids = []
         | 
    	
        swapper.py
    CHANGED
    
    | @@ -25,10 +25,10 @@ def swap_face(whole_img, target_face, source_face, models): | |
| 25 | 
             
                aimg, _ = face_align.norm_crop2(whole_img, target_face.kps, image_size=image_size)
         | 
| 26 |  | 
| 27 | 
             
                if face_parser is not None:
         | 
| 28 | 
            -
                    fp_enable,  | 
| 29 | 
             
                    if fp_enable:
         | 
| 30 | 
            -
                        bgr_fake | 
| 31 | 
            -
                            bgr_fake, aimg, face_parser,  | 
| 32 | 
             
                        )
         | 
| 33 |  | 
| 34 | 
             
                if fe_enable:
         | 
|  | |
| 25 | 
             
                aimg, _ = face_align.norm_crop2(whole_img, target_face.kps, image_size=image_size)
         | 
| 26 |  | 
| 27 | 
             
                if face_parser is not None:
         | 
| 28 | 
            +
                    fp_enable, includes, smooth_mask, blur_amount = models.get("face_parser_sett")
         | 
| 29 | 
             
                    if fp_enable:
         | 
| 30 | 
            +
                        bgr_fake = swap_regions(
         | 
| 31 | 
            +
                            bgr_fake, aimg, face_parser, smooth_mask, includes=includes, blur=blur_amount
         | 
| 32 | 
             
                        )
         | 
| 33 |  | 
| 34 | 
             
                if fe_enable:
         | 
